# copeai-ai-backend/copeai_backend/generate.py
import json
import traceback
from typing import Any, AsyncGenerator, Coroutine, Generator
import requests
import openai
import asyncio
from dotenv import load_dotenv
import os
from .conversation import (
Conversation,
Role,
ConversationResponse,
GeneratingResponseChunk,
)
from .models import Model
from .exception import ConversationLockedException
# Load OPENAI_KEY (and any other settings) from a local .env file, if present.
load_dotenv()
# Shared async OpenAI client used by every generation helper in this module.
oclient = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_KEY"))
async def simple_process_text(
    conversation: Conversation,
    model: Model,
    new_message: str,
    additional_args: dict | None = None,
) -> ConversationResponse:
    """Run a single non-streaming chat completion for *new_message*.

    Appends the user message to *conversation*, requests a completion from
    the OpenAI API, records the assistant's reply on the conversation, and
    returns a ConversationResponse wrapping the reply text and raw response.

    Args:
        conversation: Conversation to append the exchange to (mutated).
        model: Model descriptor; ``model.id`` is passed to the API.
        new_message: User message text to send.
        additional_args: Extra keyword arguments forwarded verbatim to
            ``chat.completions.create`` (e.g. temperature). Defaults to none.
    """
    # Avoid a mutable default argument; normalize None to an empty mapping.
    if additional_args is None:
        additional_args = {}
    conversation.add_message(Role.USER, new_message)
    conversation.last_used_model = model
    r = await oclient.chat.completions.create(
        model=model.id, messages=conversation.messages, **additional_args
    )
    conversation.add_message(Role.ASSISTANT, r.choices[0].message.content)
    return ConversationResponse(conversation, r.choices[0].message.content, r)
async def process_text_streaming(
    conversation: Conversation,
    model: Model,
    new_message: str,
    additional_args: dict | None = None,
) -> AsyncGenerator[GeneratingResponseChunk | ConversationResponse, None]:
    """Stream a chat completion for *new_message* on *conversation*.

    Yields a GeneratingResponseChunk for every streamed content delta, then a
    final ConversationResponse carrying the accumulated text parts and raw
    chunks. The conversation is locked for the duration of generation and is
    always unlocked again, even on error or interruption.

    Args:
        conversation: Conversation to append the exchange to (mutated).
        model: Model descriptor; only ``service == "openai"`` is handled here.
        new_message: User message text to send.
        additional_args: Optional extras; ``additional_args["userid"]`` is
            recorded with the user message when present. Defaults to none.

    Raises:
        ConversationLockedException: if the conversation is already generating.
    """
    if conversation.locked:
        raise ConversationLockedException()
    # Avoid a mutable default argument; normalize None to an empty mapping.
    if additional_args is None:
        additional_args = {}
    text_parts: list[str] = []
    resp_parts = []
    conversation.add_message(
        Role.USER,
        new_message,
        additional_args.get("userid", "unknown"),
    )
    conversation.last_used_model = model
    conversation.locked = True
    try:
        if model.service == "openai":
            response = await oclient.chat.completions.create(
                model=model.id,
                messages=conversation.messages,
                temperature=0.9,
                top_p=1.0,
                presence_penalty=0.6,
                frequency_penalty=0.0,
                max_tokens=4096,
                stream=True,
            )
            async for chunk in response:
                partition = chunk.choices[0].delta
                # On openai>=1.x the delta model always exposes `content`;
                # a None check covers role-only/finish chunks, so the old
                # json.loads(chunk.model_dump_json()) round-trip is not needed.
                if partition.content is not None:
                    text_parts.append(partition.content)
                    resp_parts.append(chunk)
                    yield GeneratingResponseChunk(partition.content, chunk)
                # NOTE(review): attribute is spelled "interruput" on the
                # Conversation class elsewhere in the project — kept as-is.
                if conversation.interruput:
                    # Caller asked us to stop: stop consuming the stream and
                    # fall through to emit the single final response below
                    # (the original kept streaming and yielded a final
                    # response twice on interrupt).
                    break
        conversation.locked = False
        conversation.add_message(Role.ASSISTANT, "".join(text_parts))
        yield ConversationResponse(conversation, text_parts, resp_parts)
    finally:
        # Safety net: guarantee the lock is released on every exit path,
        # including exceptions and early generator close.
        conversation.locked = False