Skip to content

Commit 7fd1212

Browse files
committed
feat: ByteStream auto mime_type detection and base64 (de)encoding
1 parent 7ade6a2 commit 7fd1212

File tree

4 files changed

+247
-2
lines changed

4 files changed

+247
-2
lines changed

haystack_experimental/dataclasses/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from haystack_experimental.dataclasses.byte_stream import ByteStream
56
from haystack_experimental.dataclasses.chat_message import (
67
ChatMessage,
78
ChatMessageContentT,
@@ -18,12 +19,13 @@
1819

1920
__all__ = [
2021
"AsyncStreamingCallbackT",
22+
"ByteStream",
2123
"ChatMessage",
24+
"ChatMessageContentT",
2225
"ChatRole",
2326
"StreamingCallbackT",
27+
"TextContent",
2428
"ToolCall",
2529
"ToolCallResult",
26-
"TextContent",
27-
"ChatMessageContentT",
2830
"Tool",
2931
]
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
"""
5+
Data classes for representing binary data in the Haystack API. The ByteStream class can be used to represent binary data
6+
in the API, and can be converted to and from base64 encoded strings, dictionaries, and files. This is particularly
7+
useful for representing media files in chat messages.
8+
"""
9+
10+
import logging
11+
import mimetypes
12+
from base64 import b64encode, b64decode
13+
from dataclasses import dataclass, field
14+
from pathlib import Path
15+
from typing import Any, Dict, Optional
16+
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
@dataclass
class ByteStream:
    """
    Base data class representing a binary object in the Haystack API.

    :param data: The raw binary payload.
    :param meta: Arbitrary metadata stored alongside the payload.
    :param mime_type: The mime type of the payload (for example "image/png"), or None if unknown.
    """

    data: bytes
    meta: Dict[str, Any] = field(default_factory=dict, hash=False)
    mime_type: Optional[str] = field(default=None)

    @property
    def type(self) -> Optional[str]:
        """
        Return the type of the ByteStream. This is the first part of the mime type, or None if the mime type is not set.

        :return: The type of the ByteStream.
        """
        if self.mime_type:
            return self.mime_type.split("/", maxsplit=1)[0]
        return None

    @property
    def subtype(self) -> Optional[str]:
        """
        Return the subtype of the ByteStream. This is the second part of the mime type.

        :return: The subtype of the ByteStream, or None if the mime type is not set or has no
            "/"-separated second part (for example a bare "text").
        """
        if not self.mime_type:
            return None
        # partition() reports whether a "/" was actually present, so a mime type without a
        # subtype is not mistaken for one whose subtype equals its type.
        _, separator, minor = self.mime_type.partition("/")
        return minor if separator else None

    def to_file(self, destination_path: Path) -> None:
        """
        Write the ByteStream to a file. Note: the metadata will be lost.

        :param destination_path: The path to write the ByteStream to.
        """
        with open(destination_path, "wb") as fd:
            fd.write(self.data)

    @classmethod
    def from_file_path(
        cls, filepath: Path, mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None
    ) -> "ByteStream":
        """
        Create a ByteStream from the contents read from a file.

        If no mime type is given, it is guessed from the file name; a warning is logged when the
        guess fails and the resulting ByteStream has `mime_type=None`.

        :param filepath: A valid path to a file.
        :param mime_type: The mime type of the file.
        :param meta: Additional metadata to be stored with the ByteStream.
        :returns: A new ByteStream instance.
        """
        if mime_type is None:
            mime_type = mimetypes.guess_type(filepath)[0]
            if mime_type is None:
                logger.warning("Could not determine mime type for file %s", filepath)

        with open(filepath, "rb") as fd:
            return cls(data=fd.read(), mime_type=mime_type, meta=meta or {})

    @classmethod
    def from_string(
        cls, text: str, encoding: str = "utf-8", mime_type: Optional[str] = None, meta: Optional[Dict[str, Any]] = None
    ) -> "ByteStream":
        """
        Create a ByteStream encoding a string.

        :param text: The string to encode
        :param encoding: The encoding used to convert the string into bytes
        :param mime_type: The mime type of the file.
        :param meta: Additional metadata to be stored with the ByteStream.
        :returns: A new ByteStream instance.
        """
        return cls(data=text.encode(encoding), mime_type=mime_type, meta=meta or {})

    def to_string(self, encoding: str = "utf-8") -> str:
        """
        Convert the ByteStream to a string, metadata will not be included.

        :param encoding: The encoding used to convert the bytes to a string. Defaults to "utf-8".
        :returns: The string representation of the ByteStream.
        :raises: UnicodeDecodeError: If the ByteStream data cannot be decoded with the specified encoding.
        """
        return self.data.decode(encoding)

    @classmethod
    def from_base64(
        cls,
        base64_string: str,
        encoding: str = "utf-8",
        meta: Optional[Dict[str, Any]] = None,
        mime_type: Optional[str] = None,
    ) -> "ByteStream":
        """
        Create a ByteStream from a base64 encoded string.

        :param base64_string: The base64 encoded string representation of the ByteStream data.
        :param encoding: The encoding used to convert the base64 string into bytes.
        :param meta: Additional metadata to be stored with the ByteStream.
        :param mime_type: The mime type of the file.
        :returns: A new ByteStream instance.
        """
        return cls(data=b64decode(base64_string.encode(encoding)), meta=meta or {}, mime_type=mime_type)

    def to_base64(self, encoding: str = "utf-8") -> str:
        """
        Convert the ByteStream data to a base64 encoded string.

        :param encoding: The encoding used to turn the base64 bytes into a string. Defaults to "utf-8".
        :returns: The base64 encoded string representation of the ByteStream data.
        """
        return b64encode(self.data).decode(encoding)

    @classmethod
    def from_dict(cls, data: Dict[str, Any], encoding: str = "utf-8") -> "ByteStream":
        """
        Create a ByteStream from a dictionary.

        :param data: The dictionary representation of the ByteStream. The "data" key (base64 string) is
            required; "meta" and "mime_type" are optional.
        :param encoding: The encoding used to convert the base64 string into bytes.
        :returns: A new ByteStream instance.
        """
        return cls.from_base64(data["data"], encoding=encoding, meta=data.get("meta"), mime_type=data.get("mime_type"))

    def to_dict(self, encoding: str = "utf-8") -> Dict[str, Any]:
        """
        Convert the ByteStream to a dictionary.

        :param encoding: The encoding used to turn the base64 bytes into a string. Defaults to "utf-8".
        :returns: The dictionary representation of the ByteStream, with the payload base64 encoded under
            "data" and a shallow copy of the metadata under "meta".
        """
        # Hand out a copy so edits to the serialized form do not leak back into this instance.
        meta_copy = dict(self.meta) if self.meta else {}
        return {"data": self.to_base64(encoding=encoding), "meta": meta_copy, "mime_type": self.mime_type}

test/dataclasses/__init__.py

Whitespace-only changes.
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import pytest
2+
from base64 import b64encode
3+
from pathlib import Path
4+
from unittest.mock import mock_open, patch
5+
6+
from haystack_experimental.dataclasses.byte_stream import ByteStream
7+
8+
@pytest.fixture
def byte_stream():
    """Provide a small plain-text ByteStream carrying a single metadata entry."""
    return ByteStream(data=b"test data", meta={"key": "value"}, mime_type="text/plain")
14+
15+
def test_init(byte_stream):
    """A ByteStream exposes exactly the values it was constructed with."""
    observed = (byte_stream.data, byte_stream.meta, byte_stream.mime_type)
    assert observed == (b"test data", {"key": "value"}, "text/plain")
19+
20+
def test_type_property(byte_stream):
    """`type` is the part of the mime type before the slash, or None without a mime type."""
    untyped = ByteStream(data=b"test data")
    assert byte_stream.type == "text"
    assert untyped.type is None
24+
25+
def test_subtype_property(byte_stream):
    """`subtype` is the part of the mime type after the slash, or None without a mime type."""
    untyped = ByteStream(data=b"test data")
    assert byte_stream.subtype == "plain"
    assert untyped.subtype is None
29+
30+
@patch("builtins.open", new_callable=mock_open)
def test_to_file(mock_file, byte_stream):
    """`to_file` opens the target in binary-write mode and writes the raw payload once."""
    target = Path("test.txt")
    byte_stream.to_file(target)

    mock_file.assert_called_once_with(target, "wb")
    # mock_open hands back the same handle for every call, so inspect it directly.
    handle = mock_file.return_value
    handle.write.assert_called_once_with(b"test data")
36+
37+
@patch("builtins.open", new_callable=mock_open, read_data=b"test data")
def test_from_file_path(mock_file):
    """`from_file_path` reads the file content and adopts the guessed mime type."""
    guessed = ("text/plain", None)
    with patch("mimetypes.guess_type", return_value=guessed):
        loaded = ByteStream.from_file_path(Path("test.txt"))

    assert (loaded.data, loaded.mime_type) == (b"test data", "text/plain")
44+
45+
@patch("mimetypes.guess_type", return_value=(None, None))
@patch("haystack_experimental.dataclasses.byte_stream.logger.warning")
def test_from_file_path_unknown_mime(mock_warning, _mock_guess_type):
    """
    When the mime type cannot be guessed, `from_file_path` leaves it as None and logs one warning.

    Decorators are applied bottom-up, so the first mock argument is the patched `logger.warning`
    and the second is the patched `mimetypes.guess_type`.
    """
    path = Path("test.txt")
    with patch("builtins.open", new_callable=mock_open, read_data=b"test data"):
        loaded = ByteStream.from_file_path(path)

    assert loaded.mime_type is None
    mock_warning.assert_called_once()
53+
54+
def test_from_string():
    """`from_string` stores the UTF-8 encoded text and keeps the supplied mime type."""
    greeting = "Hello, World!"
    built = ByteStream.from_string(greeting, mime_type="text/plain")
    assert (built.data, built.mime_type) == (greeting.encode("utf-8"), "text/plain")
59+
60+
def test_to_string():
    """`to_string` decodes the payload back into text."""
    decoded = ByteStream(data=b"Hello, World!").to_string()
    assert decoded == "Hello, World!"
63+
64+
def test_from_base64():
    """`from_base64` decodes the payload and keeps the supplied mime type."""
    payload = b"test data"
    encoded = b64encode(payload).decode("utf-8")
    restored = ByteStream.from_base64(encoded, mime_type="text/plain")
    assert (restored.data, restored.mime_type) == (payload, "text/plain")
69+
70+
def test_to_base64(byte_stream):
    """`to_base64` yields the standard base64 text of the payload."""
    assert byte_stream.to_base64() == b64encode(b"test data").decode("utf-8")
73+
74+
def test_from_dict():
    """`from_dict` rebuilds payload, metadata and mime type from the serialized form."""
    payload = b"test data"
    serialized = dict(
        data=b64encode(payload).decode("utf-8"),
        meta={"key": "value"},
        mime_type="text/plain",
    )
    restored = ByteStream.from_dict(serialized)
    assert restored.data == payload
    assert restored.meta == {"key": "value"}
    assert restored.mime_type == "text/plain"
84+
85+
def test_to_dict(byte_stream):
    """`to_dict` emits the base64 payload together with metadata and mime type."""
    encoded_payload = b64encode(b"test data").decode("utf-8")
    serialized = byte_stream.to_dict()
    assert serialized == dict(data=encoded_payload, meta={"key": "value"}, mime_type="text/plain")

0 commit comments

Comments
 (0)