Skip to content

Commit e4046d4

Browse files
committed
add to_parquet and push_to_hub
1 parent 50fcf69 commit e4046d4

File tree

5 files changed

+532
-21
lines changed

5 files changed

+532
-21
lines changed

src/datasets/features/audio.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray]) -> pa.Str
241241
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
242242
return array_cast(storage, self.pa_type)
243243

244-
def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
244+
def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
245245
"""Embed audio files into the Arrow array.
246246
247247
Args:
@@ -252,12 +252,24 @@ def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
252252
`pa.StructArray`: Array in the Audio arrow storage type, that is
253253
`pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
254254
"""
255+
if token_per_repo_id is None:
256+
token_per_repo_id = {}
255257

256258
@no_op_if_value_is_null
257259
def path_to_bytes(path):
258-
with xopen(path, "rb") as f:
259-
bytes_ = f.read()
260-
return bytes_
260+
source_url = path.split("::")[-1]
261+
pattern = (
262+
config.HUB_DATASETS_URL
263+
if source_url.startswith(config.HF_ENDPOINT)
264+
else config.HUB_DATASETS_HFFS_URL
265+
)
266+
source_url_fields = string_to_dict(source_url, pattern)
267+
token = (
268+
token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
269+
)
270+
download_config = DownloadConfig(token=token)
271+
with xopen(path, "rb", download_config=download_config) as f:
272+
return f.read()
261273

262274
bytes_array = pa.array(
263275
[

src/datasets/features/image.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -250,7 +250,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
250250
)
251251
return array_cast(storage, self.pa_type)
252252

253-
def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
253+
def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
254254
"""Embed image files into the Arrow array.
255255
256256
Args:
@@ -261,12 +261,24 @@ def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
261261
`pa.StructArray`: Array in the Image arrow storage type, that is
262262
`pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
263263
"""
264+
if token_per_repo_id is None:
265+
token_per_repo_id = {}
264266

265267
@no_op_if_value_is_null
266268
def path_to_bytes(path):
267-
with xopen(path, "rb") as f:
268-
bytes_ = f.read()
269-
return bytes_
269+
source_url = path.split("::")[-1]
270+
pattern = (
271+
config.HUB_DATASETS_URL
272+
if source_url.startswith(config.HF_ENDPOINT)
273+
else config.HUB_DATASETS_HFFS_URL
274+
)
275+
source_url_fields = string_to_dict(source_url, pattern)
276+
token = (
277+
token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
278+
)
279+
download_config = DownloadConfig(token=token)
280+
with xopen(path, "rb", download_config=download_config) as f:
281+
return f.read()
270282

271283
bytes_array = pa.array(
272284
[

src/datasets/features/pdf.py

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
216216
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
217217
return array_cast(storage, self.pa_type)
218218

219-
def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
219+
def embed_storage(self, storage: pa.StructArray, token_per_repo_id=None) -> pa.StructArray:
220220
"""Embed PDF files into the Arrow array.
221221
222222
Args:
@@ -227,12 +227,24 @@ def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
227227
`pa.StructArray`: Array in the PDF arrow storage type, that is
228228
`pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
229229
"""
230+
if token_per_repo_id is None:
231+
token_per_repo_id = {}
230232

231233
@no_op_if_value_is_null
232234
def path_to_bytes(path):
233-
with xopen(path, "rb") as f:
234-
bytes_ = f.read()
235-
return bytes_
235+
source_url = path.split("::")[-1]
236+
pattern = (
237+
config.HUB_DATASETS_URL
238+
if source_url.startswith(config.HF_ENDPOINT)
239+
else config.HUB_DATASETS_HFFS_URL
240+
)
241+
source_url_fields = string_to_dict(source_url, pattern)
242+
token = (
243+
token_per_repo_id.get(source_url_fields["repo_id"]) if source_url_fields is not None else None
244+
)
245+
download_config = DownloadConfig(token=token)
246+
with xopen(path, "rb", download_config=download_config) as f:
247+
return f.read()
236248

237249
bytes_array = pa.array(
238250
[

0 commit comments

Comments (0)