feat: wip, v2

This commit is contained in:
Showdown76 2024-01-16 21:07:39 +01:00
parent a71d8bfad8
commit 21d4ad64f3
8 changed files with 289 additions and 26 deletions

View File

@ -0,0 +1,3 @@
from .conversation import Conversation, ConversationResponse, Role
from .generate import process_text_streaming, simple_process_text
from .models import Model, Service, GPT_3, GPT_4

View File

@ -0,0 +1,102 @@
from dataclasses import dataclass
import typing
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion
import tiktoken
from enum import Enum
from copeai_backend.exception import ConversationLockedException
from . import models
# Tokenizer shared by text_to_tokens() for every token count in this module.
encoding = tiktoken.get_encoding("cl100k_base")
# System prompt inserted as the first message of a new Conversation when it is
# created with add_base_prompt=True; an empty string means no prompt is added.
BASE_PROMPT = ""
class Role(Enum):
    """Author of a chat message.

    Each member's value is the string stored under the "role" key of a
    message dict.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
@dataclass
class GeneratingResponseChunk:
    """One streamed piece of a model reply, delivered while the **generation is still running**."""

    # Text fragment carried by this chunk.
    text: str
    # The streaming chunk object the fragment was taken from.
    raw: ChatCompletionChunk
class Conversation:
    """Chat history plus the bookkeeping flags used while a reply is generated.

    Attributes:
        messages: Chronological list of message dicts with "role" and
            "content" keys and an optional "name" key.
        last_used_model: Model most recently used with this conversation,
            or None if none has been used yet.
        locked: While True, ``add_message`` refuses to append.
        interruput: Set to True by ``interrupt()``. The spelling is kept
            as-is because code outside this class reads the attribute by
            this exact name.
        store: Dict for arbitrary data attached to the conversation.
    """

    def __init__(self, add_base_prompt: bool = True, storage: dict | None = None) -> None:
        """Create an empty conversation.

        Args:
            add_base_prompt: When True and BASE_PROMPT is non-empty, start
                the history with BASE_PROMPT as a system message.
            storage: Dict to use as ``store``. When omitted, every
                conversation gets its own fresh dict (a shared ``{}``
                default would leak data between conversations).
        """
        self.messages = []
        self.last_used_model: models.Model | None = None
        self.locked = False
        self.interruput = False
        self.store = {} if storage is None else storage

        if add_base_prompt and BASE_PROMPT:
            self.messages.append({"role": Role.SYSTEM.value, "content": BASE_PROMPT})

    def add_message(self, role: Role, message, username: str | None = None):
        """Append a message to the history.

        Args:
            role: Author of the message.
            message: Message content, stored unchanged under "content".
            username: Optional value stored under "name"; omitted when falsy.

        Raises:
            ConversationLockedException: If the conversation is locked.
        """
        if self.locked:
            raise ConversationLockedException()

        entry = {"role": role.value, "content": message}
        if username:
            entry["name"] = username
        self.messages.append(entry)

    def interrupt(self):
        """Interrupts any conversations going on."""
        self.interruput = True

    def get_tokens(self):
        """Return the token count of the current history, as computed by text_to_tokens."""
        return text_to_tokens(self.messages)

    def last_role(self):
        """Return the Role of the most recent message (IndexError if the history is empty)."""
        return Role(self.messages[-1]["role"])

    def last_message(self):
        """Return the content of the most recent message (IndexError if the history is empty)."""
        return self.messages[-1]["content"]
@dataclass
class ConversationResponse:
    """Final result of a generation, delivered when the **generation has finished** or for non-streamed requests."""

    # Conversation the generation ran against.
    conversation: Conversation
    # Generated text: one string, or the list of streamed text fragments.
    text: str | list[str]
    # Raw API response object(s) the text was taken from.
    raw_response: list[ChatCompletion] | list[ChatCompletionChunk]
def text_to_tokens(string_or_messages: str | list[str | dict | list] | Conversation) -> int:
    """Return the estimated number of tokens used by the given input.

    Args:
        string_or_messages: One of
            * a plain string, counted as a single user message;
            * a list of messages, each being a message dict, a list of
              content parts (dicts with "type"/"text"), or a plain string;
            * a Conversation, counted exactly like its ``messages`` list.

    Returns:
        The token estimate: 4 tokens of overhead per message, the encoded
        length of each counted value, 1 extra token per "name" entry, and
        2 tokens for the reply priming.
    """
    if isinstance(string_or_messages, str):
        messages = [{"role": "user", "content": string_or_messages}]
    elif isinstance(string_or_messages, Conversation):
        # A Conversation is not iterable, so it has to be unwrapped before
        # the loop below rather than handled inside it.
        messages = string_or_messages.messages
    else:
        messages = string_or_messages

    num_tokens = 0
    for message in messages:
        # every message follows <im_start>{role/name}\n{content}<im_end>\n
        num_tokens += 4
        if isinstance(message, dict):
            for key, value in message.items():
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    # a "name" entry is charged one additional token
                    num_tokens += 1
        elif isinstance(message, list):
            # multi-part content: only the text parts are counted
            for item in message:
                if item["type"] == "text":
                    num_tokens += len(encoding.encode(item["text"]))
        elif isinstance(message, str):
            num_tokens += len(encoding.encode(message))

    num_tokens += 2  # every reply is primed with <im_start>assistant
    return num_tokens

View File

@ -0,0 +1,7 @@
class ConversationLockedException(Exception):
    """Signals that a conversation is busy and cannot accept new input yet."""

    def __init__(self):
        message = "There is already an ongoing conversation. Please wait until it is finished."
        super().__init__(message)

View File

@ -0,0 +1 @@
from .LockedConversationException import ConversationLockedException

View File

@ -0,0 +1,91 @@
import json
import traceback
from typing import Any, AsyncGenerator, Coroutine, Generator
import requests
import openai
import asyncio
from dotenv import load_dotenv
import os
from .conversation import (
Conversation,
Role,
ConversationResponse,
GeneratingResponseChunk,
)
from .models import Model
from .exception import ConversationLockedException
# Load variables from a .env file before the API key is looked up below.
load_dotenv()
# Module-wide async OpenAI client; the key comes from the OPENAI_KEY
# environment variable (None is passed when it is unset).
oclient = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_KEY"))
async def simple_process_text(
    conversation: Conversation,
    model: Model,
    new_message: str,
    additional_args: dict | None = None,
) -> ConversationResponse:
    """Send one user message and wait for the complete (non-streamed) reply.

    The user message and the assistant reply are both appended to
    ``conversation``.

    Args:
        conversation: Conversation to extend and send.
        model: Model whose ``id`` is passed to the chat completions API.
        new_message: Text of the new user message.
        additional_args: Extra keyword arguments forwarded unchanged to
            ``chat.completions.create``. Defaults to none.

    Returns:
        A ConversationResponse holding the conversation, the reply text of
        the first choice, and the raw completion object.

    Raises:
        ConversationLockedException: If the conversation is locked.
    """
    # None instead of a shared ``{}`` default, so no state leaks between calls.
    extra_args = {} if additional_args is None else additional_args

    conversation.add_message(Role.USER, new_message)
    conversation.last_used_model = model
    completion = await oclient.chat.completions.create(
        model=model.id, messages=conversation.messages, **extra_args
    )
    reply = completion.choices[0].message.content
    conversation.add_message(Role.ASSISTANT, reply)
    return ConversationResponse(conversation, reply, completion)
async def process_text_streaming(
    conversation: Conversation,
    model: Model,
    new_message: str,
    additional_args: dict | None = None,
) -> AsyncGenerator[GeneratingResponseChunk | ConversationResponse, None]:
    """Send one user message and stream the reply as it is generated.

    Yields a GeneratingResponseChunk for every streamed text fragment and
    finally exactly one ConversationResponse holding all fragments. The
    conversation is locked for the duration of the generation. Calling
    ``conversation.interrupt()`` stops the stream after the current chunk;
    the text received so far is stored as the assistant reply.

    Args:
        conversation: Conversation to extend and send.
        model: Model to use; only the "openai" service is streamed here.
        new_message: Text of the new user message.
        additional_args: Optional dict. Only its "userid" key is read (used
            as the message name, "unknown" when absent); the sampling
            parameters are fixed in this function.

    Raises:
        ConversationLockedException: If the conversation is already locked.
    """
    if conversation.locked:
        raise ConversationLockedException()

    # None instead of a shared ``{}`` default, so no state leaks between calls.
    extra_args = {} if additional_args is None else additional_args
    text_parts = []
    resp_parts = []

    try:
        conversation.add_message(
            Role.USER,
            new_message,
            extra_args.get("userid", "unknown"),
        )
        conversation.last_used_model = model
        # Drop an interrupt request left over from before this generation.
        conversation.interruput = False
        conversation.locked = True

        if model.service == "openai":
            response = await oclient.chat.completions.create(
                model=model.id,
                messages=conversation.messages,
                temperature=0.9,
                top_p=1.0,
                presence_penalty=0.6,
                frequency_penalty=0.0,
                max_tokens=4096,
                stream=True,
            )
            async for chunk in response:
                # Skip chunks without choices instead of failing on [0].
                if chunk.choices:
                    content = chunk.choices[0].delta.content
                    if content is not None:
                        text_parts.append(content)
                        resp_parts.append(chunk)
                        yield GeneratingResponseChunk(content, chunk)
                if conversation.interruput:
                    # Stop consuming; the partial text is finalised below.
                    break

        # add_message refuses to append while locked, so unlock first.
        conversation.locked = False
        conversation.add_message(Role.ASSISTANT, ''.join(text_parts))
        yield ConversationResponse(conversation, text_parts, resp_parts)
    finally:
        # Runs on errors, cancellation and early generator close alike, so
        # the conversation can never stay locked or interrupted.
        conversation.locked = False
        conversation.interruput = False

17
copeai_backend/models.py Normal file
View File

@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import Literal
# Names of the backends a model can be served by.
Service = Literal["openai", "bard"]


@dataclass
class Model:
    """Description of one selectable language model."""

    # Identifier passed to the backing service's API.
    id: str
    # Presumably the short human-readable label for the model; confirm with callers.
    usage_name: str
    # Backend that serves this model.
    service: Service


GPT_3 = Model(
    service="openai",
    usage_name="GPT-3",
    id="gpt-3.5-turbo-16k-0613",
)
# NOTE(review): "gpt-4-16k-0613" does not look like a published OpenAI model
# id — confirm before relying on it.
GPT_4 = Model(
    service="openai",
    usage_name="GPT-4",
    id="gpt-4-16k-0613",
)

73
main.py
View File

@ -8,6 +8,9 @@ import openai
import sqlite3 import sqlite3
import tiktoken import tiktoken
from dotenv import load_dotenv from dotenv import load_dotenv
from typing import Dict
import copeai_backend
import views.GenerationState
load_dotenv() load_dotenv()
@ -19,6 +22,8 @@ intents.members = True
intents.presences = True intents.presences = True
intents.dm_messages = True intents.dm_messages = True
cached_conversations: Dict[discord.User, copeai_backend.conversation.Conversation] = {}
class App(discord.Client): class App(discord.Client):
def __init__(self): def __init__(self):
super().__init__(intents=intents) super().__init__(intents=intents)
@ -55,11 +60,6 @@ async def on_message(message: discord.Message):
if not isinstance(message.channel, discord.DMChannel): return if not isinstance(message.channel, discord.DMChannel): return
if message.author.id == app.user.id: return if message.author.id == app.user.id: return
try: try:
c = db.cursor()
c.execute('SELECT * FROM message_history WHERE user_id = ? ORDER BY timestamp DESC', (message.author.id,))
msgs = c.fetchall()
message_token_usage = num_tokens_from_string(message.content)
max_token = int(os.environ['MAX_TOKEN_PER_REQUEST'])
with open('base-prompt.txt', 'r', encoding='utf-8') as f: with open('base-prompt.txt', 'r', encoding='utf-8') as f:
bprompt = f.read() bprompt = f.read()
@ -84,32 +84,53 @@ async def on_message(message: discord.Message):
} }
for arg in arguments.keys(): bprompt = bprompt.replace(f'|{arg}|', arguments[arg]) for arg in arguments.keys(): bprompt = bprompt.replace(f'|{arg}|', arguments[arg])
previous_tokens = 200+len(bprompt)+message_token_usage
# (message_id, user_id, content, token, role, timestamp)
# order by timestamp (most recent to least recent)
usable_messages = []
for msg in msgs:
d = previous_tokens + msg[3]
if d >= max_token:
break
previous_tokens += msg[3]
usable_messages.append(msg)
usable_messages.reverse()
if message.author not in cached_conversations:
cached_conversations[message.author] = copeai_backend.conversation.Conversation()
c = db.cursor()
c.execute('SELECT * FROM message_history WHERE user_id = ? ORDER BY timestamp DESC', (message.author.id,))
msgs = c.fetchall()
message_token_usage = num_tokens_from_string(message.content)
max_token = int(os.environ['MAX_TOKEN_PER_REQUEST'])
previous_tokens = 200+len(bprompt)+message_token_usage
# (message_id, user_id, content, token, role, timestamp)
# order by timestamp (most recent to least recent)
usable_messages = []
for msg in msgs:
d = previous_tokens + msg[3]
if d >= max_token:
break
previous_tokens += msg[3]
usable_messages.append(msg)
usable_messages.reverse()
messages = [{"role": "system", "content": bprompt}]
for v in usable_messages: messages.append({"role": v[4], "content": v[2]})
else:
total_tokens = copeai_backend.conversation.text_to_tokens(cached_conversations[message.author])
while total_tokens > int(os.environ['MAX_TOKEN_PER_REQUEST']) - 400:
cached_conversations[message.author].messages.pop(0)
total_tokens = copeai_backend.conversation.text_to_tokens(cached_conversations[message.author])
messages = [{"role": "system", "content": bprompt}] cached_conversations[message.author].add_message(
for v in usable_messages: messages.append({"role": v[4], "content": v[2]}) role=copeai_backend.conversation.Role.user,
messages.append({"role": "user", "content": message.content}) content=message.content
)
await message.channel.typing() await message.channel.typing()
typing.append(message.channel) typing.append(message.channel)
req = await openai.ChatCompletion.acreate( req = copeai_backend.generate.process_text_streaming(
model="gpt-3.5-turbo", conversation=cached_conversations[message.author],
temperature=0.5, model=copeai_backend.models.GPT_3,
max_tokens=max_token-(previous_tokens-200), new_message=message.content,
messages=messages additional_args={
"max_tokens": int(os.environ['MAX_TOKEN_PER_REQUEST']),
}
) )
typing.remove(message.channel) typing.remove(message.channel)
response = req['choices'][0]['message']['content'] response = req['choices'][0]['message']['content']
prompt_used_tokens = req['usage']['prompt_tokens'] prompt_used_tokens = req['usage']['prompt_tokens']

21
views/GenerationState.py Normal file
View File

@ -0,0 +1,21 @@
from enum import Enum
import discord
class GenerationState(Enum):
    """Phase of a reply generation, used to pick what a GenerationStateView shows."""

    GENERATING = "generating"
    FINISHED = "finished"
class GenerationStateButton(discord.ui.Button):
    """Button that forwards its four options unchanged to discord.ui.Button."""

    def __init__(
        self,
        label: str,
        style: discord.ButtonStyle,
        emoji: str | discord.Emoji | discord.PartialEmoji | None = None,
        disabled: bool = False,
    ):
        button_options = {
            "label": label,
            "style": style,
            "emoji": emoji,
            "disabled": disabled,
        }
        super().__init__(**button_options)
class GenerationStateView(discord.ui.View):
    """View whose components depend on the generation state it is built for."""

    def __init__(self, state: GenerationState):
        super().__init__()
        self.state = state

        if state != GenerationState.GENERATING:
            # FINISHED (and any other state) shows no components.
            return

        # NOTE(review): the emoji is an empty string here; possibly a
        # character lost in transit — confirm the intended value.
        placeholder = GenerationStateButton(
            label="Generating...",
            style=discord.ButtonStyle.grey,
            emoji="",
            disabled=True,
        )
        self.add_item(placeholder)