diff --git a/extract_thinker/batch_job.py b/extract_thinker/batch_job.py
new file mode 100644
index 0000000..947c550
--- /dev/null
+++ b/extract_thinker/batch_job.py
@@ -0,0 +1,212 @@
+import asyncio
+from typing import Any, List, Type, Iterator, Optional
+from pydantic import BaseModel
+from openai import OpenAI
+from instructor.batch import BatchJob as InstructorBatchJob
+import json
+import os
+
+SLEEP_TIME = 60
+
+class BatchJob:
+    def __init__(
+        self,
+        messages_batch: Iterator[List[dict]],
+        model: str,
+        response_model: Type[BaseModel],
+        file_path: str,
+        output_path: str,
+        api_key: str = os.getenv("OPENAI_API_KEY")
+    ):
+        self.response_model = response_model
+        self.output_path = output_path
+        self.file_path = file_path
+        self.model = model
+        self.client = OpenAI(api_key=api_key)
+        self.batch_id = None
+        self.file_id = None
+
+        # Create the batch job input file (.jsonl)
+        InstructorBatchJob.create_from_messages(
+            messages_batch=messages_batch,
+            model=model,
+            file_path=file_path,
+            response_model=response_model
+        )
+
+        self._add_method_to_file()
+
+        # Upload file and create batch job
+        self.file_id = self._upload_file()
+        if not self.file_id:
+            raise ValueError("Failed to upload file")
+
+        self.batch_id = self._create_batch_job()
+        if not self.batch_id:
+            raise ValueError("Failed to create batch job")
+
+    def _add_method_to_file(self) -> None:
+        """Transform the JSONL file to match OpenAI's batch request format."""
+        with open(self.file_path, 'r') as file:
+            lines = file.readlines()
+
+        with open(self.file_path, 'w') as file:
+            for line in lines:
+                data = json.loads(line)
+
+                new_data = {
+                    "custom_id": data["custom_id"],
+                    "method": "POST",
+                    "url": "/v1/chat/completions",
+                    "body": {
+                        "model": data["params"]["model"],
+                        "messages": data["params"]["messages"],
+                        "max_tokens": data["params"]["max_tokens"],
+                        "temperature": data["params"]["temperature"],
+                        "tools": data["params"]["tools"],
+                        "tool_choice": data["params"]["tool_choice"]
+                    }
+                }
+                file.write(json.dumps(new_data) + '\n')
+
+    def _upload_file(self) -> Optional[str]:
+        """Upload the JSONL file to OpenAI."""
+        try:
+            with open(self.file_path, "rb") as file:
+                response = self.client.files.create(
+                    file=file,
+                    purpose="batch"
+                )
+            return response.id
+        except Exception as e:
+            print(f"Error uploading file: {e}")
+            return None
+
+    def _create_batch_job(self) -> Optional[str]:
+        """Create a batch job via the OpenAI API."""
+        try:
+            batch = self.client.batches.create(
+                input_file_id=self.file_id,
+                endpoint="/v1/chat/completions",
+                completion_window="24h"
+            )
+            return batch.id
+        except Exception as e:
+            print(f"Error creating batch job: {e}")
+            return None
+
+    async def get_status(self) -> str:
+        """
+        Get the current status of the batch job.
+        Returns: queued, processing, completed, or failed
+        """
+        try:
+            batch = await asyncio.to_thread(
+                self.client.batches.retrieve,
+                self.batch_id
+            )
+            return self._map_status(batch.status)
+        except Exception as e:
+            print(f"Error getting batch status: {e}")
+            return "failed"
+
+    def _map_status(self, api_status: str) -> str:
+        """Map the OpenAI API status to a simplified status."""
+        status_mapping = {
+            'validating': 'queued',
+            'in_progress': 'processing',
+            'finalizing': 'processing',
+            'completed': 'completed',
+            'failed': 'failed',
+            'expired': 'failed',
+            'cancelling': 'processing',
+            'cancelled': 'failed'
+        }
+        return status_mapping.get(api_status, 'failed')
+
+    async def get_result(self) -> BaseModel:
+        """
+        Wait for job completion and return the parsed result using Instructor.
+
+        Instructor's parse_from_file returns a tuple of
+        (parsed_results, unparsed_results); this method returns the first
+        successfully parsed object matching response_model.
+        """
+        try:
+            # Wait until the batch is complete
+            while True:
+                status = await self.get_status()
+                if status == 'completed':
+                    break
+                elif status == 'failed':
+                    raise ValueError("Batch job failed")
+                await asyncio.sleep(SLEEP_TIME)
+
+            # Get batch details
+            batch = await asyncio.to_thread(
+                self.client.batches.retrieve,
+                self.batch_id
+            )
+
+            if not batch.output_file_id:
+                raise ValueError("No output file ID found")
+
+            # Download the output file
+            response = await asyncio.to_thread(
+                self.client.files.content,
+                batch.output_file_id
+            )
+
+            # Save the output file
+            with open(self.output_path, 'w') as f:
+                f.write(response.text)
+
+            # Use Instructor to parse the results
+            parsed, unparsed = InstructorBatchJob.parse_from_file(
+                file_path=self.output_path,
+                response_model=self.response_model
+            )
+
+            return parsed[0]
+
+        except Exception as e:
+            raise ValueError(f"Failed to process output file: {e}")
+        finally:
+            self._cleanup_files()
+
+    async def cancel(self) -> bool:
+        """Cancel the current batch job and confirm cancellation."""
+        if not self.batch_id:
+            print("No batch job to cancel.")
+            return False
+
+        try:
+            await asyncio.to_thread(
+                self.client.batches.cancel,
+                self.batch_id
+            )
+            print("Batch job cancelled successfully.")
+            self._cleanup_files()
+            return True
+        except Exception as e:
+            print(f"Error cancelling batch: {e}")
+            return False
+
+    def _cleanup_files(self):
+        """Remove temporary files and the batch directory if it is empty."""
+        try:
+            if os.path.exists(self.file_path):
+                os.remove(self.file_path)
+            if os.path.exists(self.output_path):
+                os.remove(self.output_path)
+
+            # Try to remove the parent directory if empty
+            batch_dir = os.path.dirname(self.file_path)
+            if os.path.exists(batch_dir) and not os.listdir(batch_dir):
+                os.rmdir(batch_dir)
+        except Exception as e:
+            print(f"Warning: Failed to cleanup batch files: {e}")
+
+    def __del__(self):
+        """Clean up files when the object is destroyed."""
+        self._cleanup_files()
\ No newline at end of file
diff --git a/extract_thinker/extractor.py b/extract_thinker/extractor.py
index 62bba31..2bd5614 100644
--- a/extract_thinker/extractor.py
+++ b/extract_thinker/extractor.py
@@ -1,8 +1,9 @@
 import asyncio
 import base64
 from io import BytesIO
-from typing import Any, Dict, List, Optional, IO, Union, get_origin, get_args
-
+from typing import Any, Dict, List, Optional, IO, Type, Union, get_origin, get_args
+from instructor.batch import BatchJob
+import uuid
 import litellm
 from pydantic import BaseModel
 from extract_thinker.document_loader.document_loader import DocumentLoader
@@ -15,9 +16,10 @@
 from extract_thinker.document_loader.loader_interceptor import LoaderInterceptor
 from extract_thinker.document_loader.llm_interceptor import LlmInterceptor
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from extract_thinker.batch_job import BatchJob
+
 from extract_thinker.utils import (
-    get_file_extension,
     encode_image,
     json_to_formatted_string,
     num_tokens_from_string,
@@ -26,6 +28,13 @@
 from copy import deepcopy
 
 class Extractor:
+    BATCH_SUPPORTED_MODELS = [
+        "gpt-4o-mini",
+        "gpt-4o",
+        "gpt-4o-2024-08-06",
+        "gpt-4",
+    ]
+
     def __init__(
         self, document_loader: Optional[DocumentLoader] = None, llm: Optional[LLM] = None
     ):
@@ -198,11 +207,6 @@ def classify_from_excel(
         content = self.document_loader.load_content_from_stream(path)
         return self._classify(content, classifications)
 
-    # def classify_with_image(self, messages: List[Dict[str, Any]]):
-    #     resp = litellm.completion(self.llm.model, messages)
-
-    #     return ClassificationResponse(**resp.choices[0].message.content)
-
     def _add_classification_structure(self, classification: Classification) -> str:
         content = ""
         if classification.contract:
@@ -444,6 +448,154 @@ def aggregate_results(self, results: List[Any], response_model: type[BaseModel])
         # Create an instance of the response_model with the aggregated_dict
         return response_model(**aggregated_dict)
 
+    def extract_batch(
+        self,
+        source: Union[str, IO, List[Union[str, IO]]],
+        response_model: Type[BaseModel],
+        vision: bool = False,
+        content: Optional[str] = None,
+        output_file_path: Optional[str] = None,
+        batch_file_path: Optional[str] = None,
+    ) -> BatchJob:
+        """
+        Extracts information from a source or list of sources using batch processing.
+
+        Args:
+            source: A single source (file path or IO stream) or a list of sources.
+            response_model: The Pydantic model to parse the response into.
+            vision: Whether to use vision capabilities (processing images).
+            content: Additional content to include in the extraction.
+            output_file_path: Optional path for the batch output (.jsonl) file.
+            batch_file_path: Optional path for the batch input (.jsonl) file.
+
+        Returns:
+            A BatchJob object to monitor and retrieve batch processing results.
+        """
+        if not self.can_handle_batch():
+            raise ValueError(
+                f"Model {self.llm.model} does not support batch processing. "
+                f"Supported models: {', '.join(self.BATCH_SUPPORTED_MODELS)}"
+            )
+
+        # Create the batch directory if it doesn't exist
+        batch_dir = os.path.join(os.getcwd(), "extract_thinker_batch")
+        os.makedirs(batch_dir, exist_ok=True)
+
+        # Generate unique paths if not provided
+        unique_id = str(uuid.uuid4())
+        if output_file_path is None:
+            new_output_file_path = os.path.join(batch_dir, f"output_{unique_id}.jsonl")
+        else:
+            new_output_file_path = output_file_path
+
+        if batch_file_path is None:
+            new_batch_file_path = os.path.join(batch_dir, f"input_{unique_id}.jsonl")
+        else:
+            new_batch_file_path = batch_file_path
+
+        # Check if the provided paths already exist
+        for path in [new_output_file_path, new_batch_file_path]:
+            if os.path.exists(path):
+                raise ValueError(f"File already exists: {path}")
+
+        self.extra_content = content
+
+        # Ensure that sources is a list
+        if not isinstance(source, list):
+            sources = [source]
+        else:
+            sources = source
+
+        def get_messages():
+            for idx, src in enumerate(sources):
+                # Prepare content for each source
+                if vision:
+                    # Handle vision content
+                    if isinstance(src, str):
+                        if os.path.exists(src):
+                            with open(src, "rb") as f:
+                                image_data = f.read()
+                        else:
+                            raise ValueError(f"File {src} does not exist.")
+                    elif isinstance(src, IO):
+                        image_data = src.read()
+                    else:
+                        raise ValueError("Invalid source type for vision data.")
+
+                    encoded_image = base64.b64encode(image_data).decode("utf-8")
+                    image_content = f"data:image/jpeg;base64,{encoded_image}"
+                    message_content = [
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": image_content
+                            }
+                        }
+                    ]
+                    if self.extra_content:
+                        message_content.insert(0, {"type": "text", "text": self.extra_content})
+
+                    messages = [
+                        {
+                            "role": "system",
+                            "content": "You are a server API that receives document information and returns specific fields in JSON format.",
+                        },
+                        {
+                            "role": "user",
+                            "content": message_content,
+                        },
+                    ]
+                else:
+                    if isinstance(src, str):
+                        if os.path.exists(src):
+                            content_data = self.document_loader.load_content_from_file(src)
+                        else:
+                            content_data = src  # Assume src is the text content
+                    elif isinstance(src, IO):
+                        content_data = self.document_loader.load_content_from_stream(src)
+                    else:
+                        raise ValueError("Invalid source type.")
+
+                    message_content = f"##Content\n\n{content_data}"
+                    if self.extra_content:
+                        message_content = f"##Extra Content\n\n{self.extra_content}\n\n" + message_content
+
+                    messages = [
+                        {
+                            "role": "system",
+                            "content": "You are a server API that receives document information and returns specific fields in JSON format.",
+                        },
+                        {
+                            "role": "user",
+                            "content": message_content,
+                        },
+                    ]
+                yield messages
+
+        # Create the batch job with the message generator
+        batch_job = BatchJob(
+            messages_batch=get_messages(),
+            model=self.llm.model,
+            response_model=response_model,
+            file_path=new_batch_file_path,
+            output_path=new_output_file_path
+        )
+
+        return batch_job
+
+    def can_handle_batch(self) -> bool:
+        """
+        Checks if the current LLM model supports batch processing.
+
+        Returns:
+            bool: True if batch processing is supported, False otherwise.
+        """
+        if not self.llm or not self.llm.model:
+            return False
+
+        return any(
+            supported_model in self.llm.model.lower()
+            for supported_model in self.BATCH_SUPPORTED_MODELS
+        )
+
     def _extract(
         self,
         content,
diff --git a/extract_thinker/llm.py b/extract_thinker/llm.py
index 1b413d6..c9124ef 100644
--- a/extract_thinker/llm.py
+++ b/extract_thinker/llm.py
@@ -1,15 +1,14 @@
-from typing import List, Dict, Any
+from typing import List, Dict, Any, Optional
 import instructor
 import litellm
-from extract_thinker.utils import num_tokens_from_string
+from extract_thinker.models.batch_result import BatchResult
 from litellm import Router
-
 
 class LLM:
-    def __init__(self,
-                 model: str,
-                 api_base: str = None,
-                 api_key: str = None,
+    def __init__(self,
+                 model: str,
+                 api_base: str = None,
+                 api_key: str = None,
                  api_version: str = None,
                  token_limit: int = None):
         self.client = instructor.from_litellm(litellm.completion, mode=instructor.Mode.MD_JSON)
@@ -47,3 +46,9 @@ def request(self, messages: List[Dict[str, str]], response_model: str) -> Any:
         )
 
         return response
+
+    def batch_request(self, batch_requests: List[Dict[str, Any]]) -> str:
+        return self.batch_client.create_batch(batch_requests)
+
+    def retrieve_batch_results(self, batch_id: str) -> BatchResult:
+        return self.batch_client.get_batch_results(batch_id)
diff --git a/extract_thinker/models/batch_result.py b/extract_thinker/models/batch_result.py
new file mode 100644
index 0000000..4f0aa02
--- /dev/null
+++ b/extract_thinker/models/batch_result.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+from typing import Any, List
+
+class BatchResult(BaseModel):
+    id: str
+    results: List[Any]
\ No newline at end of file
diff --git a/extract_thinker/models/batch_status.py b/extract_thinker/models/batch_status.py
new file mode 100644
index 0000000..32159f0
--- /dev/null
+++ b/extract_thinker/models/batch_status.py
@@ -0,0 +1,7 @@
+from typing import Optional
+from pydantic import BaseModel
+
+class BatchStatus(BaseModel):
+    id: str
+    status: str
+    output_file_id: Optional[str] = None
\ No newline at end of file
diff --git a/tests/models/ChartWithContent.py b/tests/models/ChartWithContent.py
index fc66481..32d1d21 100644
--- a/tests/models/ChartWithContent.py
+++ b/tests/models/ChartWithContent.py
@@ -2,7 +2,6 @@
 from tests.models.Chart import Chart
 from pydantic import Field
-
 
 class ChartWithContent(Contract):
     content: str = Field(description="The content of the page without the chart")
     chart: Chart = Field(description="The chart of the page")
\ No newline at end of file
diff --git a/tests/test_extractor.py b/tests/test_extractor.py
index 851c943..030f64c 100644
--- a/tests/test_extractor.py
+++ b/tests/test_extractor.py
@@ -1,13 +1,13 @@
+import asyncio
 import os
-
+import time
 from dotenv import load_dotenv
 from extract_thinker.extractor import Extractor
 from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract
 from extract_thinker.document_loader.document_loader_pypdf import DocumentLoaderPyPdf
-from tests.models import ChartWithContent
 from tests.models.invoice import InvoiceContract
+from tests.models.ChartWithContent import ChartWithContent
 from extract_thinker.document_loader.document_loader_azure_document_intelligence import DocumentLoaderAzureForm
-import json
 
 load_dotenv()
 cwd = os.getcwd()
@@ -32,7 +32,6 @@ def test_extract_with_tessaract_and_gpt4o_mini():
     assert result.invoice_number == "0000001"
     assert result.invoice_date == "2014-05-07"
 
-
 def test_extract_with_azure_di_and_gpt4o_mini():
     subscription_key = os.getenv("AZURE_SUBSCRIPTION_KEY")
     endpoint = os.getenv("AZURE_ENDPOINT")
@@ -95,4 +94,58 @@ def test_vision_content_pdf():
 
     # Assert
     assert result is not None
-    # TODO: For now is sanity to test for errors
\ No newline at end of file
+    # TODO: For now this is just a sanity check for errors
+
+def test_batch_extraction_single_source():
+    # Arrange
+    load_dotenv()
+    tesseract_path = os.getenv("TESSERACT_PATH")
+    test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "invoice.png")
+
+    extractor = Extractor()
+    extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
+    extractor.load_llm("gpt-4o-mini")
+
+    # Act
+    batch_job = extractor.extract_batch(test_file_path, InvoiceContract)
+
+    # Assert batch status
+    status = asyncio.run(batch_job.get_status())
+    assert status in ["queued", "processing", "completed"]
+    print(f"Batch status: {status}")
+
+    # Get results and verify
+    result = asyncio.run(batch_job.get_result())
+    assert result.invoice_number == "0000001"
+    assert result.invoice_date == "2014-05-07"
+
+def test_cancel_batch_extraction():
+    # Arrange
+    tesseract_path = os.getenv("TESSERACT_PATH")
+    test_file_path = os.path.join(os.getcwd(), "tests", "test_images", "invoice.png")
+    batch_file_path = os.path.join(os.getcwd(), "tests", "batch_input.jsonl")
+    output_file_path = os.path.join(os.getcwd(), "tests", "batch_output.jsonl")
+
+    extractor = Extractor()
+    extractor.load_document_loader(DocumentLoaderTesseract(tesseract_path))
+    extractor.load_llm("gpt-4o-mini")
+
+    # Act
+    batch_job = extractor.extract_batch(
+        test_file_path,
+        InvoiceContract,
+        batch_file_path=batch_file_path,
+        output_file_path=output_file_path
+    )
+
+    # Cancel the batch job
+    cancel_success = asyncio.run(batch_job.cancel())
+    assert cancel_success, "Batch job cancellation failed"
+
+    # Add a small delay to ensure cleanup has time to complete
+    time.sleep(1)
+
+    # Check if files were removed
+    assert not os.path.exists(batch_job.file_path), f"Batch input file was not removed: {batch_job.file_path}"
+    assert not os.path.exists(batch_job.output_path), f"Batch output file was not removed: {batch_job.output_path}"
\ No newline at end of file
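
Usage sketch for the new batch extraction API, following the tests above. The Tesseract binary path and the invoice image path are placeholders; the flow (load a document loader, pick a batch-capable model, call extract_batch(), then run get_status() and get_result()) mirrors test_batch_extraction_single_source.

    import asyncio
    from extract_thinker.extractor import Extractor
    from extract_thinker.document_loader.document_loader_tesseract import DocumentLoaderTesseract
    from tests.models.invoice import InvoiceContract

    extractor = Extractor()
    extractor.load_document_loader(DocumentLoaderTesseract("/usr/bin/tesseract"))  # placeholder path
    extractor.load_llm("gpt-4o-mini")  # must be one of Extractor.BATCH_SUPPORTED_MODELS

    # Writes the .jsonl batch input, uploads it, and creates the OpenAI batch job.
    batch_job = extractor.extract_batch("tests/test_images/invoice.png", InvoiceContract)

    # Status is mapped to one of: queued, processing, completed, failed.
    print(asyncio.run(batch_job.get_status()))

    # Polls until the job finishes, then returns the parsed InvoiceContract instance.
    invoice = asyncio.run(batch_job.get_result())
    print(invoice.invoice_number, invoice.invoice_date)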