144 lines
5.9 KiB
Python
144 lines
5.9 KiB
Python
import traceback
|
|
import json
|
|
import openai
|
|
import base64
|
|
from flask import jsonify
|
|
from objects import aic
|
|
import ai.compute
|
|
from objects import logger as logger_module
|
|
import logging
|
|
|
|
# Module-level logger obtained from the project's logging helper
# (objects.logger), named after this module.
logger: logging.Logger = logger_module.get_logger(__name__)
|
|
|
|
class AIProcessor:
|
|
def __init__(self, api_key: str, model: str = "gpt-4.1"):
|
|
self.oai = openai.Client(api_key=api_key)
|
|
self.model = model
|
|
self.session = aic.Session(messages=[aic.Message(role="system", content=aic.SYSTEM_PROMPT)], model=model) # type: ignore
|
|
self._tools_map = { # local binding of python callables
|
|
"click_button": self._click_button,
|
|
"type_text": self._type_text,
|
|
}
|
|
|
|
# --------------------- tool implementations --------------------- #
|
|
def _click_button(self, x: int, y: int, click_type: str) -> str:
|
|
# TODO: integrate real mouse automation.
|
|
return f"Performed {click_type} click at ({x}, {y})."
|
|
|
|
def _type_text(self, text: str) -> str:
|
|
# TODO: integrate real typing automation.
|
|
return f'Typed text: "{text}"'
|
|
|
|
def _execute_tool(self, name: str, arguments: dict) -> str:
|
|
func = self._tools_map.get(name)
|
|
if not func:
|
|
return f"Unknown tool: {name}"
|
|
try:
|
|
return func(**arguments)
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
return f"Error executing {name}: {e}"
|
|
|
|
# -------------------------- main entry -------------------------- #
|
|
def process(self, prompt: str, img_data: str | bytes | None = None) -> list[str | dict]:
|
|
outputs = [] # type: list[str | dict]
|
|
reexec = True
|
|
click_positions = [] # used for screenshot crosshair position
|
|
nextsteps = ""
|
|
try:
|
|
# append user prompt with optional image
|
|
self.session.messages.append(
|
|
aic.Message(role="user", content=prompt, image=img_data)
|
|
)
|
|
# if image provided, perform OCR and include text positions
|
|
response = self.oai.chat.completions.create(
|
|
model=self.model,
|
|
messages=self.session.messages_dict(),
|
|
tools=aic.FUNCTIONS, # type: ignore
|
|
)
|
|
|
|
# return tool call requests if any
|
|
tool_calls = getattr(response.choices[0].message, "tool_calls", None)
|
|
if tool_calls:
|
|
for tc in tool_calls:
|
|
ags = json.loads(tc.function.arguments)
|
|
logger.debug(
|
|
"Processing tool call: %s with arguments: %s",
|
|
tc.function.name,
|
|
tc.function.arguments,
|
|
)
|
|
if tc.function.name == "confirm":
|
|
reexec = False
|
|
try:
|
|
nextsteps = ags.get("goal", "")
|
|
except:
|
|
nextsteps = str(tc.function.arguments)
|
|
print('ERROR NEXT STEPS IS STR, ', nextsteps)
|
|
if tc.function.name == "click_button":
|
|
# extract click position for screenshot crosshair
|
|
click_positions.append((ags.get("x", 0), ags.get("y", 0)))
|
|
r = ai.compute._execute(
|
|
name=tc.function.name,
|
|
args=json.loads(tc.function.arguments),
|
|
processor=self,
|
|
)
|
|
outputs.append(r) if r else None
|
|
# Make sure every images except the two last are removed
|
|
for msg in self.session.messages[:-3]:
|
|
if msg.image and not msg.disable_image:
|
|
msg.image = None
|
|
# copy of self.session.messages, but shorten the image data for better debugging
|
|
cps = [
|
|
aic.Message(
|
|
role=msg.role,
|
|
content=msg.content[:80],
|
|
image=msg.image[:20] if isinstance(msg.image, str) else None, # type: ignore
|
|
disable_image=msg.disable_image,
|
|
name=msg.name,
|
|
)
|
|
for msg in self.session.messages
|
|
]
|
|
logger.debug(
|
|
"Shortened message copies for processing: %s", cps
|
|
)
|
|
if reexec:
|
|
img_bytes = ai.compute.take_screenshot(cross_position=click_positions)
|
|
img = ai.compute.screenshot_to_base64(
|
|
img_bytes
|
|
)
|
|
|
|
ocr_results = []
|
|
try:
|
|
ocr_results = ai.compute.perform_ocr(img_bytes)
|
|
except Exception as e:
|
|
logger.debug("OCR failed: %s", e)
|
|
|
|
self.session.messages.append(
|
|
aic.Message(
|
|
role="assistant",
|
|
content=str(tool_calls),
|
|
)
|
|
)
|
|
|
|
outputs.extend( self.process(nextsteps+f"\nOCR Positions: {ocr_results}", img) )
|
|
return [
|
|
{
|
|
"name": tc.function.name,
|
|
"arguments": json.loads(tc.function.arguments),
|
|
}
|
|
for tc in tool_calls
|
|
] + outputs # type: ignore
|
|
|
|
# otherwise return final assistant content
|
|
print(f"Response: {json.dumps(response.to_dict(), indent=4)}") # debug
|
|
output_text: str = response.choices[0].message.content # type: ignore
|
|
outputs.append(output_text)
|
|
self.session.messages.append(
|
|
aic.Message(role="assistant", content=output_text)
|
|
)
|
|
|
|
return [*outputs]
|
|
except Exception as e:
|
|
traceback.print_exc()
|
|
return [f"Error processing request: {str(e)}"]
|