Commit b237386

Merge pull request #56 from enoch3712/55-vision-without-documentloader-is-not-working
55 vision without documentloader is not working
2 parents b317c8c + d2abd28 · commit b237386

File tree: 14 files changed (+1811 / -1459 lines)

extract_thinker/document_loader/document_loader_llm_image.py

Lines changed: 24 additions & 44 deletions
@@ -1,54 +1,34 @@
 from abc import ABC
 from io import BytesIO
+from typing import Any, List, Union
 from PIL import Image
 from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
-from extract_thinker.utils import extract_json
-
 
 class DocumentLoaderLLMImage(CachedDocumentLoader, ABC):
+    SUPPORTED_FORMATS = ['pdf', 'jpg', 'jpeg', 'png']
+
     def __init__(self, content=None, cache_ttl=300, llm=None):
         super().__init__(content, cache_ttl)
         self.llm = llm
 
-    def extract_image_content(self, image_stream: BytesIO) -> str:
-        """
-        Extracts text or data from an image using an LLM.
-        The actual implementation uses an LLM to process the image content.
-        """
-        # Load the image from the stream
-        image = Image.open(image_stream)
-
-        # Encode the image to base64
-        base64_image = self.encode_image(image)
-
-        # Use the LLM to extract the content from the image
-        resp = self.llm.completion(
-            model="claude-3-sonnet-20240229",
-            messages=[
-                {
-                    "role": "system",
-                    "content": 'You are a worldclass Image data extractor. You receive an image and extract useful information from it. You output a JSON with the extracted information.',
-                },
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": "data:image/jpeg;base64," + base64_image
-                            },
-                        },
-                        {"type": "text", "text": "###JSON Output\n"},
-                    ],
-                },
-            ],
-        )
-
-        # Extract the JSON text from the response
-        jsonText = resp.choices[0].message.content
-
-        # Extract the JSON from the text
-        jsonText = extract_json(jsonText)
-
-        # Return the extracted content
-        return jsonText
+    def load_content_from_file(self, file_path: str) -> Union[str, object]:
+        images = self.convert_to_images(file_path)
+        results = []
+        for _, image_bytes in images.items():
+            image_stream = BytesIO(image_bytes)
+            results.append({"image": image_stream})
+        return results
+
+    def load_content_from_stream(self, stream: BytesIO) -> Union[str, object]:
+        images = self.convert_to_images(stream)
+        results = []
+        for _, image_bytes in images.items():
+            image_stream = BytesIO(image_bytes)
+            results.append({"image": image_stream})
+        return results
+
+    def load_content_from_stream_list(self, stream: BytesIO) -> List[Any]:
+        return self.load_content_from_stream(stream)
+
+    def load_content_from_file_list(self, file_path: str) -> List[Any]:
+        return self.load_content_from_file(file_path)
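
The net effect of this rewrite is that the loader no longer calls an LLM itself: it only renders pages and hands them to the extractor as a list of {"image": BytesIO} dicts. A hedged sketch of that hand-off (the sample byte strings are placeholders, and encode_image accepting a BytesIO relies on the utils.py change later in this commit):

from io import BytesIO
from extract_thinker.utils import encode_image

# One {"image": BytesIO} dict per rendered page, as returned by
# load_content_from_file / load_content_from_stream above.
pages = [
    {"image": BytesIO(b"raw bytes of page 1 rendered as an image")},
    {"image": BytesIO(b"raw bytes of page 2 rendered as an image")},
]

# This mirrors how the vision branch in extractor.py (further down) consumes it.
message_content = []
for page in pages:
    if "image" in page:
        base64_image = encode_image(page["image"])
        message_content.append({
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
        })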

extract_thinker/document_loader/document_loader_pypdf.py

Lines changed: 2 additions & 3 deletions
@@ -1,12 +1,11 @@
 import io
 from typing import Any, Dict, List, Union
 from PyPDF2 import PdfReader
-from extract_thinker.document_loader.document_loader_llm_image import DocumentLoaderLLMImage
 from extract_thinker.utils import get_file_extension
-
+from extract_thinker.document_loader.cached_document_loader import CachedDocumentLoader
 SUPPORTED_FORMATS = ['pdf']
 
-class DocumentLoaderPyPdf(DocumentLoaderLLMImage):
+class DocumentLoaderPyPdf(CachedDocumentLoader):
     def __init__(self, content: Any = None, cache_ttl: int = 300):
         super().__init__(content, cache_ttl)

extract_thinker/document_loader/document_loader_tesseract.py

Lines changed: 22 additions & 12 deletions
@@ -84,13 +84,15 @@ def process_pdf(self, stream: BytesIO) -> str:
 
         extracted_text = []
         for page_number, image_bytes in images.items():
-            # Check if image_bytes is not empty and has the expected structure
-            # if not image_bytes or not isinstance(image_bytes, (list, tuple)):
-            #     print(f"Skipping page {page_number}: Invalid image data")
-            #     continue
-
-            # image = BytesIO(image_bytes[0])
-            text = self.process_image(image_bytes)
+            # Convert image data to proper BytesIO stream
+            if isinstance(image_bytes, bytes):
+                image_stream = BytesIO(image_bytes)
+            elif isinstance(image_bytes, BytesIO):
+                image_stream = image_bytes
+            else:
+                raise ValueError(f"Unexpected image data type for page {page_number}: {type(image_bytes)}")
+
+            text = self.process_image(image_stream)
             extracted_text.append(text)
 
         if not extracted_text:
@@ -117,14 +119,22 @@ def process_image(self, image: BytesIO) -> str:
 
     def worker(self, input_queue: Queue, output_queue: Queue):
         while True:
-            image = input_queue.get()
-            if image is None: # Sentinel to indicate shutdown
+            image_data = input_queue.get()
+            if image_data is None: # Sentinel to indicate shutdown
                 break
             try:
-                text = self.process_image(image)
-                output_queue.put((image, text))
+                # Convert bytes to BytesIO if needed
+                if isinstance(image_data, bytes):
+                    image_stream = BytesIO(image_data)
+                elif isinstance(image_data, BytesIO):
+                    image_stream = image_data
+                else:
+                    raise ValueError(f"Unexpected image data type: {type(image_data)}")
+
+                text = self.process_image(image_stream)
+                output_queue.put((image_data, text))
             except Exception as e:
-                output_queue.put((image, str(e)))
+                output_queue.put((image_data, str(e)))
             input_queue.task_done()
 
     @cachedmethod(cache=attrgetter('cache'), key=lambda self, stream: hashkey(id(stream)))
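
Both hunks apply the same defensive conversion before handing data to process_image. As a standalone sketch of the pattern (the helper name _to_stream is illustrative, not part of this commit):

from io import BytesIO
from typing import Union

def _to_stream(image_data: Union[bytes, BytesIO]) -> BytesIO:
    # Accept raw bytes or an existing stream; fail fast on anything else.
    if isinstance(image_data, bytes):
        return BytesIO(image_data)
    if isinstance(image_data, BytesIO):
        return image_data
    raise ValueError(f"Unexpected image data type: {type(image_data)}")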

extract_thinker/extractor.py

Lines changed: 72 additions & 36 deletions
@@ -6,6 +6,7 @@
 import litellm
 from pydantic import BaseModel
 from extract_thinker.document_loader.document_loader import DocumentLoader
+from extract_thinker.document_loader.document_loader_llm_image import DocumentLoaderLLMImage
 from extract_thinker.models.classification import Classification
 from extract_thinker.models.classification_response import ClassificationResponse
 from extract_thinker.llm import LLM
@@ -84,6 +85,12 @@ def extract(
         if not issubclass(response_model, BaseModel):
             raise ValueError("response_model must be a subclass of Pydantic's BaseModel.")
 
+        if vision and not self.get_document_loader_for_file(source):
+            if not litellm.supports_vision(self.llm.model):
+                raise ValueError(f"Model {self.llm.model} does not support vision. Please provide a document loader or a model that supports vision.")
+            else:
+                self.document_loader = DocumentLoaderLLMImage(llm=self.llm)
+
         if isinstance(source, str):
             if os.path.exists(source):
                 return self.extract_from_file(source, response_model, vision)
@@ -449,55 +456,84 @@ def _extract(
         for interceptor in self.llm_interceptors:
             interceptor.intercept(self.llm)
 
-        messages = [
-            {
-                "role": "system",
-                "content": "You are a server API that receives document information "
-                "and returns specific fields in JSON format.",
-            },
-        ]
-
-        if self.extra_content is not None:
-            if isinstance(self.extra_content, dict):
-                self.extra_content = yaml.dump(self.extra_content)
-            messages.append(
-                {
-                    "role": "user",
-                    "content": "##Extra Content\n\n" + self.extra_content,
-                }
-            )
-
-        if content is not None:
-            if isinstance(content, dict):
-                if content.get("is_spreadsheet", False):
-                    content = json_to_formatted_string(content.get("data", {}))
-                content = yaml.dump(content, default_flow_style=True)
-            messages.append(
-                {"role": "user", "content": "##Content\n\n" + content}
-            )
-
         if vision:
             if not litellm.supports_vision(model=self.llm.model):
                 raise ValueError(
                     f"Model {self.llm.model} is not supported for vision, since it's not a vision model."
                 )
 
-            base64_encoded_image = encode_image(file_or_stream, is_stream)
+            # Initialize the content list for the message
+            message_content = []
+
+            # Add text content if it exists
+            if isinstance(content, str):
+                message_content.append({
+                    "type": "text",
+                    "text": content
+                })
+
+            # Add images
+            if isinstance(content, list): # Assuming content is a list of dicts with 'image' key
+                for page in content:
+                    if 'image' in page:
+                        base64_image = encode_image(page['image'])
+                        message_content.append({
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{base64_image}"
+                            }
+                        })
 
+            # Create the messages array with the correct structure
             messages = [
+                {
+                    "role": "system",
+                    "content": "You are a server API that receives document information and returns specific fields in JSON format.",
+                },
                 {
                     "role": "user",
-                    "content": [
-                        {
-                            "type": "image_url",
-                            "image_url": {
-                                "url": "data:image/jpeg;base64," + base64_encoded_image
-                            },
-                        },
-                    ],
+                    "content": message_content
                 }
             ]
 
+            # Add extra content if it exists
+            if self.extra_content is not None:
+                if isinstance(self.extra_content, dict):
+                    self.extra_content = yaml.dump(self.extra_content)
+                messages.insert(1, {
+                    "role": "user",
+                    "content": [{"type": "text", "text": "##Extra Content\n\n" + self.extra_content}]
+                })
+
+        else:
+            # Non-vision logic remains the same
+            messages = [
+                {
+                    "role": "system",
+                    "content": "You are a server API that receives document information and returns specific fields in JSON format.",
+                },
+            ]
+
+            if self.extra_content is not None:
+                if isinstance(self.extra_content, dict):
+                    self.extra_content = yaml.dump(self.extra_content)
+                messages.append(
+                    {
+                        "role": "user",
+                        "content": "##Extra Content\n\n" + self.extra_content,
+                    }
+                )
+
+            if content is not None:
+                if isinstance(content, dict):
+                    if content.get("is_spreadsheet", False):
+                        content = json_to_formatted_string(content.get("data", {}))
+                    content = yaml.dump(content, default_flow_style=True)
+                if isinstance(content, str):
+                    messages.append(
+                        {"role": "user", "content": "##Content\n\n" + content}
+                    )
+
         if self.llm.token_limit:
             max_tokens_per_request = self.llm.token_limit - 1000
             content_tokens = num_tokens_from_string(str(content))
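
Taken together with the loader changes, this is the behaviour PR #56 enables: calling extract() with vision=True and no document loader attached. A hedged usage sketch, assuming extract_thinker's usual Extractor API (load_llm, extract); the model name, file name, and contract fields are placeholders:

from pydantic import BaseModel
from extract_thinker import Extractor

class InvoiceContract(BaseModel):
    invoice_number: str
    total_amount: float

extractor = Extractor()
extractor.load_llm("gpt-4o")  # assumed vision-capable, so litellm.supports_vision() passes

# No load_document_loader() call: extract() now falls back to
# DocumentLoaderLLMImage(llm=self.llm) instead of failing.
result = extractor.extract("invoice.png", InvoiceContract, vision=True)
print(result.invoice_number, result.total_amount)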

extract_thinker/utils.py

Lines changed: 28 additions & 3 deletions
@@ -10,9 +10,34 @@
 from io import BytesIO
 from typing import Union
 
-def encode_image(image_path):
-    with open(image_path, "rb") as image_file:
-        return base64.b64encode(image_file.read()).decode("utf-8")
+def encode_image(image_source: Union[str, BytesIO]) -> str:
+    """
+    Encode an image to base64 string from either a file path or BytesIO stream.
+
+    Args:
+        image_source (Union[str, BytesIO]): The image source, either a file path or BytesIO stream
+
+    Returns:
+        str: Base64 encoded string of the image
+    """
+    try:
+        if isinstance(image_source, str):
+            with open(image_source, "rb") as image_file:
+                return base64.b64encode(image_file.read()).decode("utf-8")
+        elif isinstance(image_source, BytesIO):
+            # Save current position
+            current_position = image_source.tell()
+            # Move to start of stream
+            image_source.seek(0)
+            # Encode stream content
+            encoded = base64.b64encode(image_source.read()).decode("utf-8")
+            # Restore original position
+            image_source.seek(current_position)
+            return encoded
+        else:
+            raise ValueError("Image source must be either a file path (str) or BytesIO stream")
+    except Exception as e:
+        raise Exception(f"Failed to encode image: {str(e)}")
 
 def is_pdf_stream(stream: Union[BytesIO, str]) -> bool:
     """

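A short usage sketch of the reworked encode_image (the file name is a placeholder): both call styles return the same base64 string, and a stream's read position is saved and restored around the encode.

from io import BytesIO
from extract_thinker.utils import encode_image

# From a file path (unchanged behaviour)
b64_from_path = encode_image("invoice.jpg")

# From an in-memory stream (new): the function seeks to 0, reads,
# then restores the previous position.
with open("invoice.jpg", "rb") as f:
    stream = BytesIO(f.read())
stream.seek(5)  # pretend a consumer is mid-read
b64_from_stream = encode_image(stream)
assert stream.tell() == 5
assert b64_from_path == b64_from_stream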