Skip to content

Commit 2b71bb4

Browse files
daiyiplangfun authors
authored and
langfun authors
committed
lf.Mime.from_uri to support data URLs.
E.g.:

```
lf.Mime.from_uri('data:text/plain;base64,...')
```

PiperOrigin-RevId: 739602781
1 parent 77d42c5 commit 2b71bb4

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

langfun/core/modalities/mime.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,33 @@ def embeddable_uri(self) -> str:
152152

153153
@classmethod
def from_uri(cls, uri: str, **kwargs) -> 'Mime':
  """Creates a MIME object from a URI.

  Three cases are handled, in order:

  1. `data:` URIs are parsed with `cls._parse_data_uri`; the object is then
     built from the decoded bytes by the class that
     `cls.class_from_mime_type` returns for the parsed MIME type.
  2. When called on a class other than `Mime`, the object is constructed
     with the URI and `content=None`, without fetching anything here.
  3. When called on `Mime` itself, the content is fetched with
     `cls.download`, its MIME type is detected with `from_buffer`, and the
     class returned by `cls.class_from_mime_type` is instantiated with both
     the URI and the fetched content.

  Args:
    uri: The URI to create the object from.
    **kwargs: Extra keyword arguments forwarded to the constructor (or to
      `from_bytes` for data URIs).

  Returns:
    The created MIME object.
  """
  if uri.startswith('data:'):
    parsed_type, payload = cls._parse_data_uri(uri)
    target_cls = cls.class_from_mime_type(parsed_type)
    return target_cls.from_bytes(payload, **kwargs)

  # A concrete subclass already fixes the type, so no sniffing is needed.
  if cls is not Mime:
    return cls(uri=uri, content=None, **kwargs)

  # Called on the generic base: fetch the bytes and detect the type.
  downloaded = cls.download(uri)
  detected_type = from_buffer(downloaded, mime=True).lower()
  concrete_cls = cls.class_from_mime_type(detected_type)
  return concrete_cls(uri=uri, content=downloaded, **kwargs)
160164

165+
@classmethod
def _parse_data_uri(cls, uri: str) -> tuple[str, bytes]:
  """Returns the MIME type and decoded content from a base64 data URI.

  The URI is expected to look like
  `data:[<media type>][;<param>=<value>...];base64,<payload>` (RFC 2397).
  The header (everything before the first comma) is parsed separately from
  the payload, so a `;` inside the payload is never mistaken for a header
  delimiter. Media-type parameters such as `charset=utf-8` are accepted and
  ignored.

  Args:
    uri: The data URI to parse.

  Returns:
    A tuple of (lower-cased MIME type without parameters, decoded bytes).
    An omitted media type is reported as `text/plain`, the RFC 2397 default.

  Raises:
    ValueError: If `uri` does not start with `data:`, has no comma
      separating header from payload, or has no `;`-separated encoding
      marker in the header ('Invalid data URI'); or if the last header
      token is not `base64` ('Unsupported encoding'). Malformed base64
      padding surfaces as `binascii.Error`, a `ValueError` subclass.
  """
  scheme = 'data:'
  # Raise instead of assert: asserts are stripped under `python -O`.
  if not uri.startswith(scheme):
    raise ValueError(f'Invalid data URI: {uri!r}.')

  header, comma, payload = uri[len(scheme):].partition(',')
  if not comma:
    raise ValueError(f'Invalid data URI: {uri!r}.')

  media_type, *params = (token.strip().lower() for token in header.split(';'))
  if not params:
    # No encoding marker at all. Percent-encoded payloads are not supported,
    # and such URIs were already rejected with this message.
    raise ValueError(f'Invalid data URI: {uri!r}.')

  # The encoding marker, when present, is the last token of the header.
  encoding = params[-1]
  if encoding != 'base64':
    raise ValueError(f'Unsupported encoding: {encoding!r}.')

  base64_content = payload.strip().encode()
  return media_type or 'text/plain', base64.b64decode(base64_content)
181+
161182
@classmethod
162183
def from_bytes(cls, content: bytes | str, **kwargs) -> 'Mime':
163184
if cls is Mime:

langfun/core/modalities/mime_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,19 @@ def test_from_uri(self):
9393
self.assertEqual(content.to_bytes(), b'bar')
9494
self.assertEqual(content.mime_type, 'text/plain')
9595

96+
content = mime.Mime.from_uri('data:text/plain;base64,Zm9v')
97+
self.assertIsNone(content.uri)
98+
self.assertEqual(content.mime_type, 'text/plain')
99+
self.assertEqual(content.content, b'foo')
100+
self.assertEqual(content.content_uri, 'data:text/plain;base64,Zm9v')
101+
self.assertEqual(content.embeddable_uri, 'data:text/plain;base64,Zm9v')
102+
103+
with self.assertRaisesRegex(ValueError, 'Invalid data URI'):
104+
mime.Mime.from_uri('data:text/plain')
105+
106+
with self.assertRaisesRegex(ValueError, 'Invalid data URI'):
107+
mime.Mime.from_uri('data:text/plain;base64,')
108+
96109
def assert_html_content(self, html, expected):
97110
expected = inspect.cleandoc(expected).strip()
98111
actual = html.content.strip()

0 commit comments

Comments
 (0)