Compare commits: 84d65cb505...main (38 commits)
Commits: 6b13586154, 7192f4bc18, 36cfeffe9c, 7f5296b2ef, e5b3ea8b57, ff7c362cfe,
b035bee682, c2fb041285, 4369611610, 93a01b792b, 3d5f71ec84, 20f05ca991,
859e1c2f0b, d9a9eba4c7, b89051a37f, 72a876410c, 46a5bce956, e639e1edd3,
9bd15d45c5, 105ab4a04b, 5be7f9aadb, 20764d5d19, 158529a2bd, b583094e20,
d7c4f9b0cb, 035252c146, 892f41f78a, 0af7dc7699, 2bcddedca5, b881f04acc,
670066100f, 52c455b20c, a4e078bc19, 1925a77d85, e573ecb553, f7feb12946,
66330bfc73, 41f7d0e210
ai/compute.py (112 lines changed)
@@ -1,18 +1,99 @@
import pyautogui
import threading
import pytesseract
import time, io, base64
import sys
from objects.inputs import MouseInput, KeyboardInput, ButtonType
from PIL import ImageGrab, ImageDraw  # type: ignore
from objects import logger as logger_module
import logging

logger: logging.Logger = logger_module.get_logger(__name__)


def take_screenshot(cross_position: list[tuple[int, int]] | None = None) -> bytes:
    """Take a screenshot of the current screen and return it as bytes."""
    screenshot = ImageGrab.grab()
    buf = io.BytesIO()

    # Optionally draw a crosshair at each specified position
    if cross_position:
        for pos in cross_position:
            x, y = pos
            draw = ImageDraw.Draw(screenshot)
            size = 20  # half-length of each arm
            color = (255, 0, 0)
            width = 2
            # horizontal line
            draw.line((x - size, y, x + size, y), fill=color, width=width)
            # vertical line
            draw.line((x, y - size, x, y + size), fill=color, width=width)

    screenshot.save(buf, format='PNG')
    # also save to a file
    screenshot.save("screenshot.png", format='PNG')
    return buf.getvalue()

def perform_ocr(screenshot: bytes) -> list[dict]:
    """Perform OCR on screenshot bytes and return a list of text blocks with positions."""
    from PIL import Image  # type: ignore
    import io
    # open image from bytes
    img = Image.open(io.BytesIO(screenshot))
    # perform OCR, get the data dictionary
    data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
    results = []
    n = len(data.get('level', []))
    for i in range(n):
        text = data['text'][i]
        if text and text.strip():
            # center point of each box: offset plus half the extent
            results.append({
                'text': text,
                'x': data['left'][i] + data['width'][i] // 2,
                'y': data['top'][i] + data['height'][i] // 2
            })

    # when debug logging is enabled, save an annotated copy of the screenshot
    if logger.isEnabledFor(logging.DEBUG):
        # draw a blue crosshair at (x, y) for each detected text block
        screenshot_with_circles = Image.open(io.BytesIO(screenshot))
        draw = ImageDraw.Draw(screenshot_with_circles)
        for result in results:
            x, y = result['x'], result['y']
            size = 10
            color = (0, 0, 255)  # blue
            width = 2
            # horizontal line
            draw.line((x - size, y, x + size, y), fill=color, width=width)
            # vertical line
            draw.line((x, y - size, x, y + size), fill=color, width=width)
        screenshot_with_circles.save("screenshot_with_circles.png", format='PNG')
        logger.debug("Debug enabled, saving OCR results screenshot with crosshairs")
        screenshot_with_circles.save("ocr_results.png", format='PNG')
    return results

def screenshot_to_base64(screenshot: bytes) -> str:
    """Convert screenshot bytes to a base64 encoded string."""
    return base64.b64encode(screenshot).decode('utf-8')


def show_click_indicator(x: int, y: int, duration: float = 2.0, size: int = 50) -> None:
    """Display a red circle at (x, y) for the given duration; should be click-through. Not implemented yet."""
    pass
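The indicator is still a stub. Below is a minimal sketch of one way to fill it in with tkinter; the overlay window, geometry math, and auto-destroy timer are assumptions, not code from this repo. True click-through on Windows would additionally need the WS_EX_TRANSPARENT extended window style (e.g. via pywin32), and Tk is not thread-safe, so calling this from the daemon thread started in press_mouse below would need extra care.

import tkinter as tk

def show_click_indicator_sketch(x: int, y: int, duration: float = 2.0, size: int = 50) -> None:
    """Hypothetical implementation: draw a red ring around (x, y), then self-destruct."""
    root = tk.Tk()
    root.overrideredirect(True)                # borderless window
    root.attributes("-topmost", True)          # keep above other windows
    root.geometry(f"{size}x{size}+{x - size // 2}+{y - size // 2}")
    canvas = tk.Canvas(root, width=size, height=size, highlightthickness=0)
    canvas.pack()
    canvas.create_oval(3, 3, size - 3, size - 3, outline="red", width=3)
    root.after(int(duration * 1000), root.destroy)  # close after `duration` seconds
    root.mainloop()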
def press_mouse(mouse_input: MouseInput) -> None:
    """Presses mouse buttons at the given position."""
    x, y = mouse_input.x, mouse_input.y
    button = mouse_input.click_type
-    if button == ButtonType.LEFT:
+    if button == "left":
        pyautogui.click(x, y, button='left')
-    elif button == ButtonType.DOUBLE_LEFT:
+    elif button == "double_left":
        pyautogui.doubleClick(x, y)
-    elif button == ButtonType.RIGHT:
+    elif button == "right":
        pyautogui.click(x, y, button='right')
-    elif button == ButtonType.MIDDLE:
+    elif button == "middle":
        pyautogui.click(x, y, button='middle')
    # Show red circle indicator at click position for 2 seconds
    threading.Thread(target=show_click_indicator, args=(x, y), daemon=True).start()

def press_keyboard(keyboard_input: KeyboardInput) -> None:
    """Types the given sequence of keys."""
@@ -22,8 +103,29 @@ def press_keyboard(keyboard_input: KeyboardInput) -> None:
    if keyboard_input.press_enter:
        pyautogui.press('enter')

-def _execute(name, args):

def wait(duration: float) -> None:
    """Waits for the specified duration in seconds."""
    time.sleep(duration)


def search_pc(query: str) -> None:
    """Opens the start menu with the Windows key, waits, then types the search query."""
    pyautogui.hotkey('win')
    wait(4)
    press_keyboard(KeyboardInput(text=query))


def reprompt(nextsteps: str, processor) -> None:
    """Re-execute GPT and take a new screenshot."""
    scr = screenshot_to_base64(take_screenshot())
    return processor.process(nextsteps, img_data=scr)


+def _execute(name, args={}, processor=None):  # dict default: args is unpacked with **
    if name == "click_button":
        press_mouse(MouseInput(**args))
    elif name == "type_text":
        press_keyboard(KeyboardInput(**args))
    elif name == "wait":
        wait(**args)
    elif name == "search_pc":
        search_pc(**args)
    elif name == "reprompt":
        reprompt(**args, processor=processor)
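For reference, a quick sketch of how the dispatcher above is driven. The payloads mirror the tool schemas defined later in this diff; the exact keyword arguments accepted by MouseInput and KeyboardInput are assumptions, since their definitions are not part of the change.

# Hypothetical calls mirroring the tool-call payloads the model emits:
_execute("click_button", {"x": 960, "y": 540, "click_type": "left"})
_execute("type_text", {"text": "hello", "press_enter": True})
_execute("wait", {"duration": 1.5})
_execute("search_pc", {"query": "notepad"})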
@@ -1,8 +1,14 @@
import traceback
-import json # new
+import json
import openai
import base64
from flask import jsonify
from objects import aic
import ai.compute
from objects import logger as logger_module
import logging

logger: logging.Logger = logger_module.get_logger(__name__)

class AIProcessor:
    def __init__(self, api_key: str, model: str = "gpt-4.1"):
@@ -34,11 +40,17 @@ class AIProcessor:
            return f"Error executing {name}: {e}"

    # -------------------------- main entry -------------------------- #
-    def process(self, prompt: str, img_data: str | bytes | None = None) -> str | list[dict]:
+    def process(self, prompt: str, img_data: str | bytes | None = None) -> list[str | dict]:
        outputs = []  # type: list[str | dict]
        reexec = True
        click_positions = []  # crosshair positions for the next screenshot
        nextsteps = ""
        try:
            # append the user prompt, with an optional image
            self.session.messages.append(
                aic.Message(role="user", content=prompt, image=img_data)
            )
            response = self.oai.chat.completions.create(
                model=self.model,
                messages=self.session.messages_dict(),
@@ -49,25 +61,84 @@ class AIProcessor:
            tool_calls = getattr(response.choices[0].message, "tool_calls", None)
            if tool_calls:
                for tc in tool_calls:
-                    ai.compute._execute(
-                        name=tc.function.name,
-                        args=json.loads(tc.function.arguments)
-                    )
+                    ags = json.loads(tc.function.arguments)
                    logger.debug(
                        "Processing tool call: %s with arguments: %s",
                        tc.function.name,
                        tc.function.arguments,
                    )
                    if tc.function.name == "confirm":
                        reexec = False
                        try:
                            nextsteps = ags.get("goal", "")
                        except Exception:  # ags may not be a dict
                            nextsteps = str(tc.function.arguments)
                            print('ERROR NEXT STEPS IS STR, ', nextsteps)
                    if tc.function.name == "click_button":
                        # record the click position for the screenshot crosshair
                        click_positions.append((ags.get("x", 0), ags.get("y", 0)))
                    r = ai.compute._execute(
                        name=tc.function.name,
                        args=json.loads(tc.function.arguments),
                        processor=self,
                    )
                    if r:
                        outputs.append(r)
                # Drop images from all but the last three messages to keep the context small
                for msg in self.session.messages[:-3]:
                    if msg.image and not msg.disable_image:
                        msg.image = None
                # copy of self.session.messages with content and image data truncated, for readable debug logs
                cps = [
                    aic.Message(
                        role=msg.role,
                        content=msg.content[:80],
                        image=msg.image[:20] if isinstance(msg.image, str) else None,  # type: ignore
                        disable_image=msg.disable_image,
                        name=msg.name,
                    )
                    for msg in self.session.messages
                ]
                logger.debug(
                    "Shortened message copies for processing: %s", cps
                )
                if reexec:
                    img_bytes = ai.compute.take_screenshot(cross_position=click_positions)
                    img = ai.compute.screenshot_to_base64(img_bytes)

                    ocr_results = []
                    try:
                        ocr_results = ai.compute.perform_ocr(img_bytes)
                    except Exception as e:
                        traceback.print_exc()
                        logger.debug("OCR failed: %s", e)

                    self.session.messages.append(
                        aic.Message(
                            role="assistant",
                            # list comprehension (not a generator) so str() renders the calls
                            content=str([(tc.function.name, tc.function.arguments) for tc in tool_calls]),
                        )
                    )

                    outputs.extend(self.process(nextsteps + f"\nOCR Positions: {ocr_results}", img))
                return [
                    {
                        "name": tc.function.name,
                        "arguments": json.loads(tc.function.arguments),
                    }
                    for tc in tool_calls
-                ]
+                ] + outputs  # type: ignore

            # otherwise return the final assistant content
            print(f"Response: {json.dumps(response.to_dict(), indent=4)}")  # debug
            output_text: str = response.choices[0].message.content  # type: ignore
            outputs.append(output_text)
            self.session.messages.append(
-                aic.Message(role="assistant", content=output_text)
+                aic.Message(role="assistant", content="Executed: " + ", ".join(map(str, outputs)))  # summarize all outputs
            )
-            return output_text
+            return [*outputs]
        except Exception as e:
            traceback.print_exc()
-            return f"Error processing request: {str(e)}"
+            return [f"Error processing request: {str(e)}"]

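End to end, the processor is used roughly like the sketch below. The module path ai.processor is inferred from `import os, ai.processor` in the web server hunk further down, and the environment-variable name for the key is an assumption; only the constructor signature (api_key, model) is visible in this diff.

import os
import ai.processor

# OPENAI_API_KEY is an assumed variable name, not shown in this diff
aip = ai.processor.AIProcessor(api_key=os.environ["OPENAI_API_KEY"], model="gpt-4.1")
# process() now returns a list of executed tool calls and/or assistant text
for item in aip.process("Open Notepad and type hello"):
    print(item)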
main.py (2 lines changed)
@@ -12,7 +12,7 @@ def main():
        model=os.getenv("OPENAI_MODEL", "gpt-4.1")
    )
    server = webserver.web.WebServerApp(aip)
-    server.run()
+    server.run(host="0.0.0.0", port=int(os.getenv("PORT", 5000)), debug=int(os.getenv("DEBUG", 0)) > 0)


if __name__ == "__main__":
    main()

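main.py and objects/logger.py now read several environment variables via python-dotenv. A sample .env under those names (values illustrative; only the variable names visible in this diff are certain, and the API-key variable name is not shown here):

OPENAI_MODEL=gpt-4.1
PORT=5000
DEBUG=1
LOG_LEVEL=DEBUG
LOG_DIR=./logs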
@@ -8,7 +8,8 @@ SYSTEM_PROMPT = """
You are CopeAI Windows Agent. You are currently controlling a Windows 11 machine. \
You are capable of seeing the screen, clicking buttons, typing text, and interacting with the system. \
You will use the functions provided. The resolution of the machine is 1920x1080. \
-Your text response must indicate what you are doing."""
+Your text response must indicate what you are doing. If the place where you clicked seems incorrect, \
+use everything you can to find the correct position of the target and click again. You will see a red cross where you previously clicked."""

FUNCTIONS = [
    {
@@ -30,7 +31,7 @@ FUNCTIONS = [
                "click_type": {
                    "type": "string",
                    "enum": ["left", "double_left", "middle", "right"],
-                    "description": "The type of mouse click to perform."
+                    "description": "The type of mouse click to perform. `double_left` is a double click, used to open apps or files."
                }
            },
            "required": ["click_type", "x", "y"],
@@ -58,7 +59,75 @@ FUNCTIONS = [
                "required": ["text", "press_enter"],
            }
        }
-    }
+    },
    {
        "type": "function",
        "function": {
            "name": "wait",
            "description": "Wait for a specified amount of time.",
            "parameters": {
                "type": "object",
                "properties": {
                    "duration": {
                        "type": "number",
                        "description": "The duration to wait in seconds."
                    }
                },
                "required": ["duration"],
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "reprompt",
            "description": "After doing what you had to do, re-execute once again with a new screenshot.",
            "parameters": {
                "type": "object",
                "properties": {
                    "nextsteps": {
                        "type": "string",
                        "description": "The new steps to perform."
                    }
                },
                "required": ["nextsteps"],
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "confirm",
            "description": "Confirm that the task is completed and no further actions are needed. ONLY execute this when you have fulfilled the user's request. This can be the only function called.",
            "parameters": {
                "type": "object",
                "properties": {
                    "goal": {
                        "type": "string",
                        "description": "The goal that was achieved."
                    }
                },
                "required": ["goal"],
            }
        }
    },
    {
        "type": "function",
        "function": {
            "name": "search_pc",
            "description": "Opens the start menu, then searches for content. Use it to open apps, open the file explorer, or search the web. Use this first whenever possible!",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query to perform."
                    }
                },
                "required": ["query"],
            }
        }
    },
]

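The create() call in the processor hunk above is truncated before its remaining arguments, but a schema list like FUNCTIONS is typically wired into the Chat Completions API as in this sketch (not the repo's exact call):

import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model="gpt-4.1",
    messages=[{"role": "user", "content": "Open Notepad"}],
    tools=FUNCTIONS,       # the schema list defined above
    tool_choice="auto",    # let the model decide which tool to call
)
for tc in response.choices[0].message.tool_calls or []:
    print(tc.function.name, tc.function.arguments)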
objects/logger.py (new file, 44 lines)
@@ -0,0 +1,44 @@
import logging
import os
from logging.handlers import RotatingFileHandler
from dotenv import load_dotenv

load_dotenv()

# Configuration values
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
LOG_DIR = os.getenv("LOG_DIR", os.path.join(os.getcwd(), "logs"))

# Ensure log directory exists
os.makedirs(LOG_DIR, exist_ok=True)

# Log file path
LOG_FILE = os.path.join(LOG_DIR, "app.log")

# Create root logger
logger = logging.getLogger("gpt-agent")
logger.setLevel(LOG_LEVEL)

# Formatter
formatter = logging.Formatter(LOG_FORMAT)

# Console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(LOG_LEVEL)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

# Rotating file handler
file_handler = RotatingFileHandler(LOG_FILE, maxBytes=5 * 1024 * 1024, backupCount=5)
file_handler.setLevel(LOG_LEVEL)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)


def get_logger(name: str | None = None) -> logging.Logger:
    """
    Retrieve a configured logger instance. If name is provided,
    returns a child logger of the configured root logger.
    """
    if name:
        return logger.getChild(name)
    return logger

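Modules obtain child loggers of the configured "gpt-agent" root, as ai/compute.py does above:

from objects import logger as logger_module

log = logger_module.get_logger(__name__)
log.info("visible at the default INFO level")   # goes to console and logs/app.log
log.debug("only emitted when LOG_LEVEL=DEBUG")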
@@ -5,4 +5,7 @@ python-dotenv
# libraries to control mouse+keyboard+see screen
pyautogui
pynput
-Pillow
+pillow
pytesseract

# --index-url https://mirrors.sustech.edu.cn/pypi/simple

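Note that pytesseract is only a wrapper: perform_ocr in ai/compute.py will fail at runtime unless the Tesseract OCR engine itself is installed and on PATH (on Windows, typically via the UB-Mannheim build or similar).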
@@ -2,7 +2,7 @@ from flask import Flask, request, jsonify
import os, ai.processor
from dotenv import load_dotenv
import io
-from PIL import ImageGrab
+from PIL import ImageGrab  # type: ignore

load_dotenv()

@@ -27,8 +27,6 @@ class WebServerApp:
        # Process the data as needed
        prompt = data.get('prompt', '')
-
-
        if not prompt:
            return jsonify({"error": "No prompt provided"}), 400
        img_data = None
@@ -40,6 +38,7 @@ class WebServerApp:
            img_data = None
        else:
            if 'host_screenshot' in data:
                print('Taking screenshot...')
                # take a screenshot right here: capture the full screen
                screenshot_img = ImageGrab.grab()
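A client-side sketch of exercising this endpoint; the Flask route path is not visible in this capture, so "/prompt" below is a placeholder, while the "prompt" and "host_screenshot" keys come from the handler code above:

import requests

resp = requests.post(
    "http://localhost:5000/prompt",  # placeholder path; check WebServerApp's route
    json={"prompt": "Open Notepad", "host_screenshot": True},
)
print(resp.json())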