Merge pull request #22 from enoch3712/classificationStrategies
Classification strategies
enoch3712 authored Jun 17, 2024
2 parents d8e4aa4 + 828d5de commit 3d49019
Showing 10 changed files with 314 additions and 94 deletions.
3 changes: 2 additions & 1 deletion extract_thinker/__init__.py
@@ -8,7 +8,7 @@
from .document_loader.document_loader_pypdf import DocumentLoaderPyPdf
from .document_loader.document_loader_text import DocumentLoaderText
from .models import classification, classification_response
from .process import Process
from .process import Process, ClassificationStrategy
from .splitter import Splitter
from .image_splitter import ImageSplitter
from .models.classification import Classification
@@ -27,6 +27,7 @@
'classification',
'classification_response',
'Process',
'ClassificationStrategy',
'Splitter',
'ImageSplitter',
'Classification',
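Not part of the diff — a minimal usage sketch of the newly exported name (the enum value comes from the ClassificationStrategy definition added in process.py further down):

# ClassificationStrategy is now importable from the package root.
from extract_thinker import Process, ClassificationStrategy

assert ClassificationStrategy.CONSENSUS.value == "consensus"
process = Process()  # strategies are consumed by Process.classify, shown below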
108 changes: 93 additions & 15 deletions extract_thinker/extractor.py
@@ -1,4 +1,6 @@
import asyncio
import base64
from io import BytesIO
from typing import Any, Dict, List, Optional, IO, Union

from pydantic import BaseModel
@@ -15,7 +17,7 @@

from extract_thinker.utils import get_file_extension, encode_image
import yaml

import litellm

SUPPORTED_IMAGE_FORMATS = ["jpeg", "png", "bmp", "tiff"]
SUPPORTED_EXCEL_FORMATS = ['.xls', '.xlsx', '.xlsm', '.xlsb', '.odf', '.ods', '.odt', '.csv']
@@ -31,7 +33,7 @@ def __init__(
self.document_loaders_by_file_type: Dict[str, DocumentLoader] = {}
self.loader_interceptors: List[LoaderInterceptor] = []
self.llm_interceptors: List[LlmInterceptor] = []
self.extra_content: Optional[str] = None
self.is_classify_image: bool = False

def add_interceptor(
self, interceptor: Union[LoaderInterceptor, LlmInterceptor]
@@ -113,12 +115,19 @@ def extract_from_stream(
content = self.document_loader.load(stream)
return self._extract(content, stream, response_model, vision, is_stream=True)

def classify_from_image(self, image: Any, classifications: List[Classification]):
# requires no content extraction from loader
content = {
"image": image,
}
return self._classify(content, classifications, image)

def classify_from_path(self, path: str, classifications: List[Classification]):
content = self.document_loader.load_content_from_file(path)
content = self.document_loader.load_content_from_file_list(path) if self.is_classify_image else self.document_loader.load_content_from_file(path)
return self._classify(content, classifications)

def classify_from_stream(self, stream: IO, classifications: List[Classification]):
content = self.document_loader.load_content_from_stream(stream)
content = self.document_loader.load_content_from_stream_list(stream) if self.is_classify_image else self.document_loader.load_content_from_stream(stream)
self._classify(content, classifications)

def classify_from_excel(self, path: Union[str, IO], classifications: List[Classification]):
@@ -128,28 +137,98 @@ def classify_from_excel(self, path: Union[str, IO], classifications: List[Classi
content = self.document_loader.load_content_from_stream(path)
return self._classify(content, classifications)

def _classify(self, content: str, classifications: List[Classification]):
# def classify_with_image(self, messages: List[Dict[str, Any]]):
# resp = litellm.completion(self.llm.model, messages)

# return ClassificationResponse(**resp.choices[0].message.content)

def _add_classification_structure(self, classification: Classification) -> str:
content = ""
if classification.contract:
content = "\tContract Structure:\n"
# Iterate over the fields of the contract attribute if it's not None
for name, field in classification.contract.model_fields.items():
# Extract the type and required status from the field's string representation
field_str = str(field)
field_type = field_str.split('=')[1].split(' ')[0] # Extracts the type
required = 'required' in field_str # Checks if 'required' is in the string
# Creating a string representation of the field attributes
attributes = f"required={required}"
# Append each field's details to the content string
field_details = f"\t\tName: {name}, Type: {field_type}, Attributes: {attributes}"
content += field_details + "\n"
return content

def _classify(self, content: Any, classifications: List[Classification], image: Optional[Any] = None):
messages = [
{
"role": "system",
"content": "You are a server API that receives document information "
"and returns specific fields in JSON format.",
"and returns specific fields in JSON format.\n",
},
]

input_data = (
f"##Content\n{content}\n##Classifications\n"
+ "\n".join([f"{c.name}: {c.description}" for c in classifications])
+ "\n\n##JSON Output\n"
)
if self.is_classify_image:
input_data = (
f"##Take the first image, and compare to the several images provided. Then classificationaccording to the classifcation attached to the image\n"
+ "Output Example: \n"
+ "{\r\n\t\"name\": \"DMV Form\",\r\n\t\"confidence\": 8\r\n}"
+ "\n\n##ClassificationResponse JSON Output\n"
)

messages.append({"role": "user", "content": input_data})
else:
input_data = (
f"##Content\n{content}\n##Classifications\n#if contract present, each field present increase confidence level\n"
+ "\n".join([f"{c.name}: {c.description} \n{self._add_classification_structure(c)}" for c in classifications])
+ "#Dont use contract structure, just to help on the ClassificationResponse\nOutput Example: \n"
+ "{\r\n\t\"name\": \"DMV Form\",\r\n\t\"confidence\": 8\r\n}"
+ "\n\n##ClassificationResponse JSON Output\n"
)

response = self.llm.request(messages, ClassificationResponse)
#messages.append({"role": "user", "content": input_data})

if self.is_classify_image:
messages.append(
{
"role": "user",
"content": [
{
"type": "text",
"text": input_data,
},
],
}
)
for classification in classifications:
if classification.image:
messages.append({
"role": "user",
"content": [
{"type": "text", "text": "{classification.name}: {classification.description}"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64," + encode_image(classification.image)
},
},
],
})
else:
raise ValueError(f"Image required for classification '{classification.name}' but not found.")

response = self.llm.request(messages, ClassificationResponse)
else:
messages.append({"role": "user", "content": input_data})
response = self.llm.request(messages, ClassificationResponse)

return response

def classify(self, input: Union[str, IO], classifications: List[Classification]):
def classify(self, input: Union[str, IO], classifications: List[Classification], image: bool = False):
self.is_classify_image = image

if image:
return self.classify_from_image(input, classifications)

if isinstance(input, str):
# Check if the input is a valid file path
if os.path.isfile(input):
@@ -209,7 +288,6 @@ def _extract(self,
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
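Not part of the diff — a hedged sketch of how the new image-aware classification path might be called. The file paths, the model name, and the load_llm-style setup are assumptions (the diff itself only adds the classify(..., image=True) entry point and the multi-part vision prompt), and the exact type classify_from_image expects for the input image is not pinned down here, so a file path is assumed:

from extract_thinker import Extractor, Classification

# Reference images are attached per classification (set_image is added in
# models/classification.py below); paths here are illustrative.
invoice = Classification(name="Invoice", description="A billing document with totals")
invoice.set_image("reference_images/invoice_example.png")

dmv_form = Classification(name="DMV Form", description="A driver licensing form")
dmv_form.set_image("reference_images/dmv_form_example.png")

extractor = Extractor()
# Assumption: an LLM is attached via the library's existing setup helper;
# that part of the API is not shown in this diff.
extractor.load_llm("gpt-4o")

# image=True routes through classify_from_image and sends the reference
# images as data-URL image_url parts alongside the text prompt.
result = extractor.classify("input/scanned_page.png", [invoice, dmv_form], image=True)
print(result.name, result.confidence)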
7 changes: 7 additions & 0 deletions extract_thinker/image_splitter.py
@@ -7,9 +7,16 @@
from extract_thinker.splitter import Splitter
from extract_thinker.utils import extract_json

VISION_MODELS = ["gpt-4o", "gpt-4-turbo", "model3", "claude-3-haiku-20240307", "claude-3-opus-20240229", "claude-3-sonnet-20240229"]


class ImageSplitter(Splitter):

def __init__(self, model: str):
if model not in VISION_MODELS:
raise ValueError(f"Model {model} is not supported for ImageSplitter. Supported models are {VISION_MODELS}")
self.model = model

def encode_image(self, image):
buffered = BytesIO()
image.save(buffered, format=image.format)
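Not part of the diff — a short sketch of the new constructor guard in use (model names taken from the VISION_MODELS list above):

from extract_thinker import ImageSplitter

splitter = ImageSplitter("gpt-4o")  # accepted: listed in VISION_MODELS

try:
    ImageSplitter("gpt-3.5-turbo")  # rejected: not in VISION_MODELS
except ValueError as error:
    print(error)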
10 changes: 5 additions & 5 deletions extract_thinker/llm.py
@@ -18,21 +18,21 @@ def load_router(self, router: Router) -> None:
self.router = router

def request(self, messages: List[Dict[str, str]], response_model: str) -> Any:
contents = map(lambda message: message['content'], messages)
all_contents = ' '.join(contents)
max_tokens = num_tokens_from_string(all_contents)
# contents = map(lambda message: message['content'], messages)
# all_contents = ' '.join(contents)
# max_tokens = num_tokens_from_string(all_contents)

if self.router:
response = self.router.completion(
model=self.model,
max_tokens=max_tokens,
#max_tokens=max_tokens,
messages=messages,
response_model=response_model,
)
else:
response = self.client.chat.completions.create(
model=self.model,
max_tokens=max_tokens,
#max_tokens=max_tokens,
messages=messages,
response_model=response_model,
api_base=self.api_base,
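Not part of the diff — a hedged caller-side sketch of request now that max_tokens is no longer derived from the prompt. It assumes the class in llm.py is named LLM and is constructed with a litellm-style model name, neither of which is shown in this diff:

from extract_thinker.llm import LLM
from extract_thinker.models.classification_response import ClassificationResponse

llm = LLM("gpt-4o")  # assumed constructor; model name illustrative

messages = [
    {"role": "system", "content": "You are a server API that returns JSON."},
    {"role": "user", "content": "##Content\n...\n##ClassificationResponse JSON Output\n"},
]

# With the token counting commented out, the provider's default completion
# limit applies instead of a computed max_tokens.
response = llm.request(messages, ClassificationResponse)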
8 changes: 8 additions & 0 deletions extract_thinker/models/classification.py
@@ -1,10 +1,18 @@
from typing import Any, Optional
from pydantic import BaseModel
from extract_thinker.models.contract import Contract
import os


class Classification(BaseModel):
name: str
description: str
contract: Optional[Contract] = None
image: Optional[str] = None # Path to the image file
extractor: Optional[Any] = None

def set_image(self, image_path: str):
if os.path.isfile(image_path):
self.image = image_path
else:
raise ValueError(f"The provided string '{image_path}' is not a valid file path.")
4 changes: 3 additions & 1 deletion extract_thinker/models/classification_response.py
@@ -1,8 +1,10 @@
from pydantic import BaseModel
from typing import Optional
from pydantic import BaseModel, Field


class ClassificationResponse(BaseModel):
name: str
confidence: Optional[int] = Field(None, description="From 1 to 10, 10 being the highest confidence. Always an integer.", ge=1, le=10)

def __hash__(self):
return hash((self.name))
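Not part of the diff — a minimal sketch of the constrained confidence field (values illustrative):

from extract_thinker.models.classification_response import ClassificationResponse

# Matches the "Output Example" the classification prompt asks the model to produce.
response = ClassificationResponse(name="DMV Form", confidence=8)

try:
    ClassificationResponse(name="DMV Form", confidence=42)  # ge=1/le=10 reject this
except ValueError as error:
    print(error)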
36 changes: 29 additions & 7 deletions extract_thinker/process.py
@@ -11,6 +11,14 @@
)
from extract_thinker.utils import get_image_type

from enum import Enum


class ClassificationStrategy(Enum):
CONSENSUS = "consensus"
HIGHER_ORDER = "higher_order"
BOTH = "both"


class Process:
def __init__(self):
@@ -39,25 +47,39 @@ def load_splitter(self, splitter: Splitter):
self.splitter = splitter
return self

def add_classifyExtractor(self, extractor_groups: List[List[Extractor]]):
def add_classify_extractor(self, extractor_groups: List[List[Extractor]]):
for extractors in extractor_groups:
self.extractor_groups.append(extractors)
return self

async def _classify_async(self, extractor, file, classifications):
async def _classify_async(self, extractor: Extractor, file: str, classifications: List[Classification], image: bool = False):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, extractor.classify, file, classifications)
return await loop.run_in_executor(None, extractor.classify, file, classifications, image)

async def classify_async(self, file: str, classifications) -> Optional[Classification]:
def classify(self, file: str, classifications, strategy: ClassificationStrategy = ClassificationStrategy.CONSENSUS, threshold: int = 9, image: bool = False) -> Optional[Classification]:
result = asyncio.run(self.classify_async(file, classifications, strategy, threshold, image))

return result

async def classify_async(self, file: str, classifications, strategy: ClassificationStrategy = ClassificationStrategy.CONSENSUS, threshold: int = 9, image: bool = False) -> Optional[Classification]:
for extractor_group in self.extractor_groups:
group_classifications = await asyncio.gather(*(self._classify_async(extractor, file, classifications) for extractor in extractor_group))
group_classifications = await asyncio.gather(*(self._classify_async(extractor, file, classifications, image) for extractor in extractor_group))

# Implement different strategies
if strategy == ClassificationStrategy.CONSENSUS:
# Check if all classifications in the group are the same
if len(set(group_classifications)) == 1:
return group_classifications[0]
elif strategy == ClassificationStrategy.HIGHER_ORDER:
# Pick the result with the highest confidence
return max(group_classifications, key=lambda c: c.confidence)
elif strategy == ClassificationStrategy.BOTH:
if len(set(group_classifications)) == 1:
max_result = max(group_classifications, key=lambda c: c.confidence)
if max_result.confidence >= threshold:
return max_result

# If no agreement was found, return None
return None
raise ValueError("No consensus could be reached on the classification of the document. Please try again with a different strategy or threshold.")

async def classify_extractor(self, session, extractor, file):
return await session.run(extractor.classify, file)
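Not part of the diff — a hedged end-to-end sketch of strategy selection. Extractor configuration (document loader plus LLM) is assumed to happen as elsewhere in the library and is abbreviated here; file names are illustrative:

from extract_thinker import Extractor, Process, ClassificationStrategy, Classification

classifications = [
    Classification(name="Invoice", description="A billing document"),
    Classification(name="DMV Form", description="A driver licensing form"),
]

# Assumption: each extractor is fully configured (loader + LLM) before use.
extractor_a = Extractor()
extractor_b = Extractor()

process = Process()
process.add_classify_extractor([[extractor_a, extractor_b]])

# CONSENSUS: all extractors in a group must agree.
# HIGHER_ORDER: take the single highest-confidence result.
# BOTH: require agreement and confidence at or above the threshold.
result = process.classify(
    "input/scanned_invoice.pdf",
    classifications,
    strategy=ClassificationStrategy.BOTH,
    threshold=9,
)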