feat: wip, v2

copeai_backend/__init__.py (new file, +3)
@@ -0,0 +1,3 @@
from .conversation import Conversation, ConversationResponse, Role
from .generate import process_text_streaming, simple_process_text
from .models import Model, Service, GPT_3, GPT_4

copeai_backend/conversation.py (new file, +102)
@@ -0,0 +1,102 @@
from dataclasses import dataclass
from enum import Enum
import typing

import tiktoken
from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletion

from copeai_backend.exception import ConversationLockedException

from . import models

encoding = tiktoken.get_encoding("cl100k_base")

BASE_PROMPT = ""

class Role(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"


@dataclass
class GeneratingResponseChunk:
    """A chunk of a streamed response. Yielded while the generation is still in progress."""

    text: str
    raw: ChatCompletionChunk


class Conversation:
    def __init__(self, add_base_prompt: bool = True, storage: dict | None = None) -> None:
        self.messages = []
        self.last_used_model: models.Model | None = None
        self.locked = False
        self.interrupted = False
        # per-instance dict: avoids sharing a mutable default between conversations
        self.store = storage if storage is not None else {}

        if add_base_prompt and BASE_PROMPT:
            self.messages.append({"role": Role.SYSTEM.value, "content": BASE_PROMPT})

    def add_message(self, role: Role, message, username: str | None = None):
        if self.locked:
            raise ConversationLockedException()
        d = {"role": role.value, "content": message}
        if username:
            d["name"] = username
        self.messages.append(d)

    def interrupt(self):
        """Interrupts any generation going on in this conversation."""
        self.interrupted = True

    def get_tokens(self):
        return text_to_tokens(self.messages)

    def last_role(self):
        return Role(self.messages[-1]["role"])

    def last_message(self):
        return self.messages[-1]["content"]


@dataclass
class ConversationResponse:
    """The final response from a generation. Returned when the generation is done, or for non-streamed requests."""

    conversation: Conversation
    text: str | list[str]
    raw_response: list[ChatCompletion] | list[ChatCompletionChunk]

def text_to_tokens(string_or_messages: str | list[str | dict | list] | Conversation) -> int:
    """Returns the number of tokens in a string, a message list, or a Conversation."""
    num_tokens = 0

    if isinstance(string_or_messages, str):
        messages = [{"role": "user", "content": string_or_messages}]
    elif isinstance(string_or_messages, Conversation):
        # a Conversation is measured through its message history
        messages = string_or_messages.messages
    else:
        messages = string_or_messages

    for message in messages:
        # every message follows <im_start>{role/name}\n{content}<im_end>\n
        num_tokens += 4

        if isinstance(message, dict):
            for key, value in message.items():
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":  # if there's a name, the role is omitted
                    num_tokens += 1  # role is always required and always 1 token
        elif isinstance(message, list):
            for item in message:
                if item["type"] == "text":
                    num_tokens += len(encoding.encode(item["text"]))
        elif isinstance(message, str):
            num_tokens += len(encoding.encode(message))

    num_tokens += 2  # every reply is primed with <im_start>assistant

    return num_tokens
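
A minimal usage sketch of the API above (hypothetical values; assumes the package is importable):

    from copeai_backend.conversation import Conversation, Role, text_to_tokens

    conv = Conversation(add_base_prompt=False)
    conv.add_message(Role.USER, "Hello there!", username="alice")  # "name" is attached to the message
    print(conv.last_role())      # Role.USER
    print(conv.last_message())   # Hello there!
    print(conv.get_tokens())     # token estimate for the whole history
    print(text_to_tokens(conv))  # same estimate, passing the Conversation directly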

copeai_backend/exception/LockedConversationException.py (new file, +7)
@@ -0,0 +1,7 @@
class ConversationLockedException(Exception):
    """Raised when a conversation is locked by an ongoing generation."""

    def __init__(self):
        super().__init__(
            "There is already an ongoing conversation. Please wait until it is finished."
        )

copeai_backend/exception/__init__.py (new file, +1)
@@ -0,0 +1 @@
from .LockedConversationException import ConversationLockedException

copeai_backend/generate.py (new file, +91)
@@ -0,0 +1,91 @@
import asyncio
import json
import os
import traceback
from typing import Any, AsyncGenerator, Coroutine, Generator

import openai
import requests
from dotenv import load_dotenv

from .conversation import (
    Conversation,
    Role,
    ConversationResponse,
    GeneratingResponseChunk,
)
from .models import Model
from .exception import ConversationLockedException

load_dotenv()

oclient = openai.AsyncOpenAI(api_key=os.environ.get("OPENAI_KEY"))


async def simple_process_text(
    conversation: Conversation,
    model: Model,
    new_message: str,
    additional_args: dict | None = None,
) -> ConversationResponse:
    conversation.add_message(Role.USER, new_message)
    conversation.last_used_model = model
    r = await oclient.chat.completions.create(
        model=model.id, messages=conversation.messages, **(additional_args or {})
    )
    conversation.add_message(Role.ASSISTANT, r.choices[0].message.content)
    return ConversationResponse(conversation, r.choices[0].message.content, r)


async def process_text_streaming(
    conversation: Conversation,
    model: Model,
    new_message: str,
    additional_args: dict | None = None,
) -> AsyncGenerator[GeneratingResponseChunk | ConversationResponse, None]:
    if conversation.locked:
        raise ConversationLockedException()

    additional_args = additional_args or {}

    try:
        text_parts = []
        resp_parts = []

        conversation.add_message(
            Role.USER,
            new_message,
            additional_args.get("userid", "unknown"),
        )
        conversation.last_used_model = model
        conversation.locked = True
        if model.service == "openai":
            response = await oclient.chat.completions.create(
                model=model.id,
                messages=conversation.messages,
                temperature=0.9,
                top_p=1.0,
                presence_penalty=0.6,
                frequency_penalty=0.0,
                max_tokens=4096,
                stream=True,
            )

            async for chunk in response:
                partition = chunk.choices[0].delta
                # deltas without content (e.g. the role-only first chunk) are skipped
                if partition.content is not None:
                    text_parts.append(partition.content)
                    resp_parts.append(chunk)
                    yield GeneratingResponseChunk(partition.content, chunk)

                if conversation.interrupted:
                    conversation.locked = False
                    yield ConversationResponse(conversation, text_parts, resp_parts)
                    return

        conversation.locked = False

        conversation.add_message(Role.ASSISTANT, ''.join(text_parts))
        yield ConversationResponse(conversation, text_parts, resp_parts)
    except Exception as e:
        conversation.locked = False
        raise e
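
process_text_streaming is an async generator: it yields GeneratingResponseChunk objects while tokens stream in, then a final ConversationResponse. A consumption sketch (assumes OPENAI_KEY is set; the printed format is illustrative):

    import asyncio
    from copeai_backend.conversation import Conversation, GeneratingResponseChunk
    from copeai_backend.generate import process_text_streaming
    from copeai_backend.models import GPT_3

    async def main():
        conv = Conversation()
        async for part in process_text_streaming(conv, GPT_3, "Hi!"):
            if isinstance(part, GeneratingResponseChunk):
                print(part.text, end="", flush=True)  # partial text as it arrives
            else:
                print(f"\n[done: {len(part.raw_response)} chunks]")  # final ConversationResponse

    asyncio.run(main())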

copeai_backend/models.py (new file, +17)
@@ -0,0 +1,17 @@
from dataclasses import dataclass
from typing import Literal


Service = Literal["openai", "bard"]


@dataclass
class Model:
    id: str
    usage_name: str
    service: Service


GPT_3 = Model(id="gpt-3.5-turbo-16k-0613", usage_name="GPT-3", service="openai")

GPT_4 = Model(id="gpt-4-16k-0613", usage_name="GPT-4", service="openai")
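
Since Model is a plain dataclass and Service is a Literal, wiring in another OpenAI model is just one more instance (the id below is only an example):

    GPT_4_TURBO = Model(id="gpt-4-turbo", usage_name="GPT-4 Turbo", service="openai")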

main.py (69 changes)
							@@ -8,6 +8,9 @@ import openai
 import sqlite3
 import tiktoken
 from dotenv import load_dotenv
+from typing import Dict
+import copeai_backend
+import views.GenerationState
 
 load_dotenv()
 
@@ -19,6 +22,8 @@ intents.members = True
 intents.presences = True
 intents.dm_messages = True
 
+cached_conversations: Dict[discord.User, copeai_backend.conversation.Conversation] = {}
+
 class App(discord.Client):
     def __init__(self):
         super().__init__(intents=intents)
@@ -55,11 +60,6 @@ async def on_message(message: discord.Message):
     if not isinstance(message.channel, discord.DMChannel): return
     if message.author.id == app.user.id: return
     try:
-        c = db.cursor()
-        c.execute('SELECT * FROM message_history WHERE user_id = ? ORDER BY timestamp DESC', (message.author.id,))
-        msgs = c.fetchall()
-        message_token_usage = num_tokens_from_string(message.content)
-        max_token = int(os.environ['MAX_TOKEN_PER_REQUEST'])
         with open('base-prompt.txt', 'r', encoding='utf-8') as f:
             bprompt = f.read()
 
@@ -85,31 +85,52 @@ async def on_message(message: discord.Message):
 
         for arg in arguments.keys(): bprompt = bprompt.replace(f'|{arg}|', arguments[arg])
 
-        previous_tokens = 200+len(bprompt)+message_token_usage
-        # (message_id, user_id, content, token, role, timestamp)
-        # order by timestamp (most recent to least recent)
-        usable_messages = []
-        for msg in msgs:
-            d = previous_tokens + msg[3]
-            if d >= max_token:
-                break
-            previous_tokens += msg[3]
-            usable_messages.append(msg)
-
-        usable_messages.reverse()
-
-        messages = [{"role": "system", "content": bprompt}]
-        for v in usable_messages: messages.append({"role": v[4], "content": v[2]})
-        messages.append({"role": "user", "content": message.content})
+        if message.author not in cached_conversations:
+            cached_conversations[message.author] = copeai_backend.conversation.Conversation()
+            c = db.cursor()
+            c.execute('SELECT * FROM message_history WHERE user_id = ? ORDER BY timestamp DESC', (message.author.id,))
+            msgs = c.fetchall()
+            message_token_usage = num_tokens_from_string(message.content)
+            max_token = int(os.environ['MAX_TOKEN_PER_REQUEST'])
+
+            previous_tokens = 200+len(bprompt)+message_token_usage
+            # (message_id, user_id, content, token, role, timestamp)
+            # order by timestamp (most recent to least recent)
+            usable_messages = []
+            for msg in msgs:
+                d = previous_tokens + msg[3]
+                if d >= max_token:
+                    break
+                previous_tokens += msg[3]
+                usable_messages.append(msg)
+
+            usable_messages.reverse()
+
+            messages = [{"role": "system", "content": bprompt}]
+            for v in usable_messages: messages.append({"role": v[4], "content": v[2]})
+        else:
+            total_tokens = copeai_backend.conversation.text_to_tokens(cached_conversations[message.author])
+            while total_tokens > int(os.environ['MAX_TOKEN_PER_REQUEST']) - 400:
+                cached_conversations[message.author].messages.pop(0)
+                total_tokens = copeai_backend.conversation.text_to_tokens(cached_conversations[message.author])
+
+        cached_conversations[message.author].add_message(
+            role=copeai_backend.conversation.Role.USER,
+            message=message.content
+        )
 
         await message.channel.typing()
         typing.append(message.channel)
-        req = await openai.ChatCompletion.acreate(
-            model="gpt-3.5-turbo",
-            temperature=0.5,
-            max_tokens=max_token-(previous_tokens-200),
-            messages=messages
-        )
+        req = copeai_backend.generate.process_text_streaming(
+            conversation=cached_conversations[message.author],
+            model=copeai_backend.models.GPT_3,
+            new_message=message.content,
+            additional_args={
+                "max_tokens": int(os.environ['MAX_TOKEN_PER_REQUEST']),
+            }
+        )
 
         typing.remove(message.channel)
         response = req['choices'][0]['message']['content']
         prompt_used_tokens = req['usage']['prompt_tokens']

views/GenerationState.py (new file, +21)
@@ -0,0 +1,21 @@
from enum import Enum

import discord

class GenerationState(Enum):
    GENERATING = "generating"
    FINISHED = "finished"

class GenerationStateButton(discord.ui.Button):
    def __init__(self, label: str, style: discord.ButtonStyle, emoji: str | discord.Emoji | discord.PartialEmoji | None = None, disabled: bool = False):
        super().__init__(label=label, style=style, emoji=emoji, disabled=disabled)

class GenerationStateView(discord.ui.View):
    def __init__(self, state: GenerationState):
        super().__init__()
        self.state = state

        if state == GenerationState.GENERATING:
            self.add_item(GenerationStateButton(label="Generating...", style=discord.ButtonStyle.grey, emoji="✨", disabled=True))
        elif state == GenerationState.FINISHED:
            pass
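
The view is not wired into main.py yet; a sketch of the intended use inside a message handler (reply and final_text are assumed names):

    from views.GenerationState import GenerationState, GenerationStateView

    reply = await message.channel.send(
        "Thinking...", view=GenerationStateView(GenerationState.GENERATING)
    )
    # once generation finishes, FINISHED currently renders a view with no buttons
    await reply.edit(content=final_text, view=GenerationStateView(GenerationState.FINISHED))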