from dataclasses import dataclass
from enum import Enum
import typing

import tiktoken
from openai import AsyncStream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from copeai_backend.exception import ConversationLockedException

from . import models

encoding = tiktoken.get_encoding("cl100k_base")

# System prompt prepended to every new Conversation (empty by default).
BASE_PROMPT = ""


def text_to_tokens(string_or_messages: str | list[str | dict | list]) -> int:
    """Return the number of tokens used by a text string or a list of chat messages."""
    num_tokens = 0

    if isinstance(string_or_messages, str):
        messages = [{"role": "user", "content": string_or_messages}]
    else:
        messages = string_or_messages

    for message in messages:
        # every message follows <im_start>{role/name}\n{content}<im_end>\n
        num_tokens += 4

        if isinstance(message, dict):
            for key, value in message.items():
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens -= 1  # role is always required and always 1 token
        elif isinstance(message, list):
            # multipart content: only the text parts contribute countable tokens here
            for item in message:
                if item["type"] == "text":
                    num_tokens += len(encoding.encode(item["text"]))
        elif isinstance(message, str):
            num_tokens += len(encoding.encode(message))

    num_tokens += 2  # every reply is primed with <im_start>assistant
    return num_tokens
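
# Illustrative calls (a sketch, not executed on import). Counts follow the
# ChatML-style accounting above and are approximations for newer models:
#
#     text_to_tokens("Hello!")                                   # plain string
#     text_to_tokens([{"role": "user", "content": "Hello!"}])    # message list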


class Role(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


@dataclass
class GeneratingResponseChunk:
    """A chunk of a streamed response from the model, received while generation is **still in progress**."""

    text: str
    raw: ChatCompletionChunk
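
# How a consumer might drain these chunks (illustrative; `generate` is a
# hypothetical streaming helper that lives outside this module):
#
#     async for chunk in generate(conversation, model):
#         print(chunk.text, end="")   # chunk.raw carries the ChatCompletionChunk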


class Conversation:
    def __init__(self, add_base_prompt: bool = True, storage: dict | None = None) -> None:
        self.messages = []
        self.last_used_model: models.Model | None = None
        self.locked = False
        self.interrupted = False
        # Create the store per instance to avoid sharing a mutable default dict.
        self.store = storage if storage is not None else {}

        if add_base_prompt and BASE_PROMPT:
            self.messages.append({"role": Role.SYSTEM.value, "content": BASE_PROMPT})

    def add_message(self, role: Role, message, username: str | None = None):
        """Append a message to the history. Raises ConversationLockedException if the conversation is locked."""
        if not self.locked:
            d = {"role": role.value, "content": message}
            if username:
                d["name"] = username
            self.messages.append(d)
        else:
            raise ConversationLockedException()

    def interrupt(self):
        """Interrupt any generation currently in progress for this conversation."""
        self.interrupted = True

    def get_tokens(self):
        """Return the approximate token count of the whole message history."""
        return text_to_tokens(self.messages)

    def last_role(self):
        """Return the role of the most recent message."""
        return Role(self.messages[-1]["role"])

    def last_message(self):
        """Return the content of the most recent message."""
        return self.messages[-1]["content"]
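
# Locking behaviour (illustrative sketch, not executed on import): while
# `locked` is True, add_message refuses new entries.
#
#     convo.locked = True
#     convo.add_message(Role.USER, "more")   # raises ConversationLockedException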


@dataclass
class ConversationResponse:
    """A full response from the model, received when generation is **done** or for non-streamed requests."""

    conversation: Conversation
    text: str | list[str]
    raw_response: list[ChatCompletion] | list[ChatCompletionChunk]
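

if __name__ == "__main__":
    # Minimal smoke test (a sketch; assumes tiktoken can load cl100k_base).
    convo = Conversation(storage={"user_id": 123})
    convo.add_message(Role.USER, "Hello!", username="alice")
    print(convo.last_role(), convo.last_message(), convo.get_tokens())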