feat: Add OCR functionality to process method; integrate Tesseract for text extraction from screenshots
This commit is contained in:
		@@ -1,5 +1,6 @@
 | 
				
			|||||||
import pyautogui
 | 
					import pyautogui
 | 
				
			||||||
import threading
 | 
					import threading
 | 
				
			||||||
 | 
					import pytesseract
 | 
				
			||||||
import time, io, base64
 | 
					import time, io, base64
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from objects.inputs import MouseInput, KeyboardInput, ButtonType
 | 
					from objects.inputs import MouseInput, KeyboardInput, ButtonType
 | 
				
			||||||
@@ -28,6 +29,28 @@ def take_screenshot(cross_position: list[tuple[int, int]] | None = None) -> byte
 | 
				
			|||||||
    screenshot.save("screenshot.png", format='PNG')
 | 
					    screenshot.save("screenshot.png", format='PNG')
 | 
				
			||||||
    return buf.getvalue()
 | 
					    return buf.getvalue()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def perform_ocr(screenshot: bytes) -> list[dict]:
 | 
				
			||||||
 | 
					    """Perform OCR on screenshot bytes and return list of text blocks with positions."""
 | 
				
			||||||
 | 
					    from PIL import Image # type: ignore
 | 
				
			||||||
 | 
					    import io
 | 
				
			||||||
 | 
					    # open image from bytes
 | 
				
			||||||
 | 
					    img = Image.open(io.BytesIO(screenshot))
 | 
				
			||||||
 | 
					    # perform OCR, get data dictionary
 | 
				
			||||||
 | 
					    data = pytesseract.image_to_data(img, output_type=pytesseract.Output.DICT)
 | 
				
			||||||
 | 
					    results = []
 | 
				
			||||||
 | 
					    n = len(data.get('level', []))
 | 
				
			||||||
 | 
					    for i in range(n):
 | 
				
			||||||
 | 
					        text = data['text'][i]
 | 
				
			||||||
 | 
					        if text and text.strip():
 | 
				
			||||||
 | 
					            results.append({
 | 
				
			||||||
 | 
					                'text': text,
 | 
				
			||||||
 | 
					                'left': data['left'][i],
 | 
				
			||||||
 | 
					                'top': data['top'][i],
 | 
				
			||||||
 | 
					                'width': data['width'][i],
 | 
				
			||||||
 | 
					                'height': data['height'][i]
 | 
				
			||||||
 | 
					            })
 | 
				
			||||||
 | 
					    return results
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def screenshot_to_base64(screenshot: bytes) -> str:
    """Encode raw screenshot bytes as a base64 text string."""
    encoded: bytes = base64.b64encode(screenshot)
    return encoded.decode('utf-8')
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
import traceback
 | 
					import traceback
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import openai
 | 
					import openai
 | 
				
			||||||
 | 
					import base64
 | 
				
			||||||
from flask import jsonify
 | 
					from flask import jsonify
 | 
				
			||||||
from objects import aic
 | 
					from objects import aic
 | 
				
			||||||
import ai.compute
 | 
					import ai.compute
 | 
				
			||||||
@@ -45,9 +46,22 @@ class AIProcessor:
 | 
				
			|||||||
        click_positions = []  # used for screenshot crosshair position
 | 
					        click_positions = []  # used for screenshot crosshair position
 | 
				
			||||||
        nextsteps = ""
 | 
					        nextsteps = ""
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
 | 
					            # append user prompt with optional image
 | 
				
			||||||
            self.session.messages.append(
 | 
					            self.session.messages.append(
 | 
				
			||||||
                aic.Message(role="user", content=prompt, image=img_data)
 | 
					                aic.Message(role="user", content=prompt, image=img_data)
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					            # if image provided, perform OCR and include text positions
 | 
				
			||||||
 | 
					            if img_data is not None:
 | 
				
			||||||
 | 
					                # decode base64 if needed
 | 
				
			||||||
 | 
					                try:
 | 
				
			||||||
 | 
					                    img_bytes = base64.b64decode(img_data) if isinstance(img_data, str) else img_data
 | 
				
			||||||
 | 
					                    ocr_results = ai.compute.perform_ocr(img_bytes)
 | 
				
			||||||
 | 
					                    # append OCR results as a tool message
 | 
				
			||||||
 | 
					                    self.session.messages.append(
 | 
				
			||||||
 | 
					                        aic.Message(role="tool", name="ocr", content=json.dumps(ocr_results))
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                except Exception as e:
 | 
				
			||||||
 | 
					                    logger.debug("OCR failed: %s", e)
 | 
				
			||||||
            response = self.oai.chat.completions.create(
 | 
					            response = self.oai.chat.completions.create(
 | 
				
			||||||
                model=self.model,
 | 
					                model=self.model,
 | 
				
			||||||
                messages=self.session.messages_dict(),
 | 
					                messages=self.session.messages_dict(),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,5 +6,6 @@ python-dotenv
 | 
				
			|||||||
pyautogui
 | 
					pyautogui
 | 
				
			||||||
pynput
 | 
					pynput
 | 
				
			||||||
pillow
 | 
					pillow
 | 
				
			||||||
 | 
pytesseract
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# --index-url https://mirrors.sustech.edu.cn/pypi/simple
 | 
					# --index-url https://mirrors.sustech.edu.cn/pypi/simple
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user