copeai-ai-backend/copeai_backend/conversation.py

from dataclasses import dataclass
import typing
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
import tiktoken
from enum import Enum
from copeai_backend.exception import ConversationLockedException
from . import models

encoding = tiktoken.get_encoding("cl100k_base")

BASE_PROMPT = ""


def text_to_tokens(string_or_messages: str | list[str | dict | list]) -> int:
    """Returns the number of tokens in a text string or in a list of chat messages."""
    num_tokens = 0
    messages = []

    if isinstance(string_or_messages, str):
        messages = [{"role": "user", "content": string_or_messages}]
    else:
        messages = string_or_messages

    for message in messages:
        # every message follows <im_start>{role/name}\n{content}<im_end>\n
        num_tokens += 4
        if isinstance(message, dict):
            for key, value in message.items():
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += -1  # role is always required and always 1 token
        elif isinstance(message, list):
            for item in message:
                if item["type"] == "text":
                    num_tokens += len(encoding.encode(item["text"]))
        elif isinstance(message, str):
            num_tokens += len(encoding.encode(message))

    num_tokens += 2  # every reply is primed with <im_start>assistant
    return num_tokens
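
# Illustrative calls (a sketch; counts are approximate and depend on the tokenizer
# build -- the helper appears to follow OpenAI's per-message overhead heuristic):
#   text_to_tokens("Hello there!")
#   text_to_tokens([{"role": "user", "content": "Hello there!", "name": "alice"}])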


class Role(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


@dataclass
class GeneratingResponseChunk:
    """A chunk of a response from the model. You receive this while the **generation is still in progress** and the response is being streamed."""

    text: str
    raw: ChatCompletionChunk


class Conversation:
    def __init__(self, add_base_prompt: bool = True, storage: dict | None = None) -> None:
        self.messages = []
        self.last_used_model: models.Model | None = None
        self.locked = False  # while True, add_message() raises ConversationLockedException
        self.interrupted = False
        self.store = storage if storage is not None else {}
        if add_base_prompt and BASE_PROMPT:
            self.messages.append({"role": Role.SYSTEM.value, "content": BASE_PROMPT})

    def add_message(self, role: Role, message, username: str | None = None):
        if not self.locked:
            d = {"role": role.value, "content": message}
            if username:
                d["name"] = username
            self.messages.append(d)
        else:
            raise ConversationLockedException()

    def interrupt(self):
        """Interrupts any generation currently running for this conversation."""
        self.interrupted = True

    def get_tokens(self):
        return text_to_tokens(self.messages)

    def last_role(self):
        return Role(self.messages[-1]["role"])

    def last_message(self):
        return self.messages[-1]["content"]
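
# Minimal usage sketch (hypothetical values; the actual generation call lives
# elsewhere in this module and is not shown here):
#   conv = Conversation()
#   conv.add_message(Role.USER, "Hello!", username="alice")
#   conv.get_tokens()   # rough token estimate for the history so far
#   conv.interrupt()    # flags any in-flight streamed generation to stop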


@dataclass
class ConversationResponse:
    """A response from the generation. You receive this when the **generation is done**, or for non-streamed requests."""

    conversation: Conversation
    response: str | list[str]
    raw_response: list[ChatCompletion] | list[ChatCompletionChunk]