Skip to content

Commit ed13a97

Browse files
authored
fix: uploading large files saving to disk instead of memory (#4935)
* fix: uploading large files saving to disk instead of memory Signed-off-by: Frost Ming <[email protected]> * fix: context managers Signed-off-by: Frost Ming <[email protected]>
1 parent 737d402 commit ed13a97

File tree

3 files changed

+88
-72
lines changed

3 files changed

+88
-72
lines changed

src/bentoml/_internal/cloud/base.py

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

3-
import io
43
import typing as t
54
from abc import ABC
65
from abc import abstractmethod
76
from contextlib import contextmanager
87

8+
import attrs
99
from rich.console import Group
1010
from rich.live import Live
1111
from rich.panel import Panel
@@ -33,26 +33,40 @@
3333
FILE_CHUNK_SIZE = 100 * 1024 * 1024 # 100Mb
3434

3535

36-
class CallbackIOWrapper(io.BytesIO):
37-
read_cb: t.Callable[[int], None] | None
38-
write_cb: t.Callable[[int], None] | None
36+
@attrs.define
37+
class CallbackIOWrapper(t.IO[bytes]):
38+
file: t.IO[bytes]
39+
read_cb: t.Callable[[int], None] | None = None
40+
write_cb: t.Callable[[int], None] | None = None
41+
start: int | None = None
42+
end: int | None = None
3943

40-
def __init__(
41-
self,
42-
buffer: t.Any = None,
43-
*,
44-
read_cb: t.Callable[[int], None] | None = None,
45-
write_cb: t.Callable[[int], None] | None = None,
46-
):
47-
self.read_cb = read_cb
48-
self.write_cb = write_cb
49-
super().__init__(buffer)
44+
def __attrs_post_init__(self) -> None:
45+
self.file.seek(self.start or 0, 0)
5046

51-
def read(self, size: int | None = None) -> bytes:
52-
if size is not None:
53-
res = super().read(size)
47+
def seek(self, offset: int, whence: int = 0) -> int:
48+
if whence == 2 and self.end is not None:
49+
length = self.file.seek(self.end, 0)
5450
else:
55-
res = super().read()
51+
length = self.file.seek(offset, whence)
52+
return length - (self.start or 0)
53+
54+
def tell(self) -> int:
55+
return self.file.tell()
56+
57+
def fileno(self) -> int:
58+
# Raise OSError to prevent access to the underlying file descriptor
59+
raise OSError("fileno")
60+
61+
def __getattr__(self, name: str) -> t.Any:
62+
return getattr(self.file, name)
63+
64+
def read(self, size: int = -1) -> bytes:
65+
pos = self.tell()
66+
if self.end is not None:
67+
if size < 0 or size > self.end - pos:
68+
size = self.end - pos
69+
res = self.file.read(size)
5670
if self.read_cb is not None:
5771
self.read_cb(len(res))
5872
return res
@@ -64,6 +78,9 @@ def write(self, data: bytes) -> t.Any: # type: ignore # python buffer types ar
6478
self.write_cb(len(data))
6579
return res
6680

81+
def __iter__(self) -> t.Iterator[bytes]:
82+
return iter(self.file)
83+
6784

6885
class Spinner:
6986
"""A UI component that renders as follows:
@@ -109,20 +126,23 @@ def console(self) -> "Console":
109126
def spin(self, text: str) -> t.Generator[TaskID, None, None]:
110127
"""Create a spinner as a context manager."""
111128
try:
112-
task_id = self.update(text)
129+
task_id = self.update(text, new=True)
113130
yield task_id
114131
finally:
115132
self._spinner_task_id = None
116133
self._spinner_progress.stop_task(task_id)
117134
self._spinner_progress.update(task_id, visible=False)
118135

119-
def update(self, text: str) -> TaskID:
136+
def update(self, text: str, new: bool = False) -> TaskID:
120137
"""Update the spin text."""
121-
if self._spinner_task_id is None:
122-
self._spinner_task_id = self._spinner_progress.add_task(text)
138+
if self._spinner_task_id is None or new:
139+
task_id = self._spinner_progress.add_task(text)
140+
if self._spinner_task_id is None:
141+
self._spinner_task_id = task_id
123142
else:
124-
self._spinner_progress.update(self._spinner_task_id, description=text)
125-
return self._spinner_task_id
143+
task_id = self._spinner_task_id
144+
self._spinner_progress.update(task_id, description=text)
145+
return task_id
126146

127147
def __rich_console__(
128148
self, console: "Console", options: "ConsoleOptions"

src/bentoml/_internal/cloud/bentocloud.py

Lines changed: 42 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

3+
import math
34
import tarfile
45
import tempfile
5-
import threading
66
import typing as t
77
import warnings
88
from concurrent.futures import ThreadPoolExecutor
@@ -84,6 +84,7 @@ def _do_push_bento(
8484
threads: int = 10,
8585
rest_client: RestApiClient = Provide[BentoMLContainer.rest_api_client],
8686
model_store: ModelStore = Provide[BentoMLContainer.model_store],
87+
bentoml_tmp_dir: str = Provide[BentoMLContainer.tmp_bento_store_dir],
8788
):
8889
name = bento.tag.name
8990
version = bento.tag.version
@@ -213,10 +214,11 @@ def push_model(model: Model) -> None:
213214
presigned_upload_url = remote_bento.presigned_upload_url
214215

215216
def io_cb(x: int):
216-
with io_mutex:
217-
self.spinner.transmission_progress.update(upload_task_id, advance=x)
217+
self.spinner.transmission_progress.update(upload_task_id, advance=x)
218218

219-
with CallbackIOWrapper(read_cb=io_cb) as tar_io:
219+
with NamedTemporaryFile(
220+
prefix="bentoml-bento-", suffix=".tar", dir=bentoml_tmp_dir
221+
) as tar_io:
220222
with self.spinner.spin(
221223
text=f'Creating tar archive for bento "{bento.tag}"..'
222224
):
@@ -232,42 +234,38 @@ def filter_(
232234
return tar_info
233235

234236
tar.add(bento.path, arcname="./", filter=filter_)
235-
tar_io.seek(0, 0)
236237

237238
with self.spinner.spin(text=f'Start uploading bento "{bento.tag}"..'):
238239
rest_client.v1.start_upload_bento(
239240
bento_repository_name=bento_repository.name, version=version
240241
)
241-
242-
file_size = tar_io.getbuffer().nbytes
242+
file_size = tar_io.tell()
243+
io_with_cb = CallbackIOWrapper(tar_io, read_cb=io_cb)
243244

244245
self.spinner.transmission_progress.update(
245246
upload_task_id, completed=0, total=file_size, visible=True
246247
)
247248
self.spinner.transmission_progress.start_task(upload_task_id)
248249

249-
io_mutex = threading.Lock()
250-
251250
if transmission_strategy == "proxy":
252251
try:
253252
rest_client.v1.upload_bento(
254253
bento_repository_name=bento_repository.name,
255254
version=version,
256-
data=tar_io,
255+
data=io_with_cb,
257256
)
258257
except Exception as e: # pylint: disable=broad-except
259258
self.spinner.log(f'[bold red]Failed to upload bento "{bento.tag}"')
260259
raise e
261260
self.spinner.log(f'[bold green]Successfully pushed bento "{bento.tag}"')
262261
return
263262
finish_req = FinishUploadBentoSchema(
264-
status=BentoUploadStatus.SUCCESS.value,
265-
reason="",
263+
status=BentoUploadStatus.SUCCESS.value, reason=""
266264
)
267265
try:
268266
if presigned_upload_url is not None:
269267
resp = httpx.put(
270-
presigned_upload_url, content=tar_io, timeout=36000
268+
presigned_upload_url, content=io_with_cb, timeout=36000
271269
)
272270
if resp.status_code != 200:
273271
finish_req = FinishUploadBentoSchema(
@@ -289,7 +287,8 @@ def filter_(
289287

290288
upload_id: str = remote_bento.upload_id
291289

292-
chunks_count = file_size // FILE_CHUNK_SIZE + 1
290+
chunks_count = math.ceil(file_size / FILE_CHUNK_SIZE)
291+
tar_io.file.close()
293292

294293
def chunk_upload(
295294
upload_id: str, chunk_number: int
@@ -310,18 +309,16 @@ def chunk_upload(
310309
with self.spinner.spin(
311310
text=f'({chunk_number}/{chunks_count}) Uploading chunk of Bento "{bento.tag}"...'
312311
):
313-
chunk = (
314-
tar_io.getbuffer()[
315-
(chunk_number - 1) * FILE_CHUNK_SIZE : chunk_number
316-
* FILE_CHUNK_SIZE
317-
]
318-
if chunk_number < chunks_count
319-
else tar_io.getbuffer()[
320-
(chunk_number - 1) * FILE_CHUNK_SIZE :
321-
]
322-
)
312+
with open(tar_io.name, "rb") as f:
313+
chunk_io = CallbackIOWrapper(
314+
f,
315+
read_cb=io_cb,
316+
start=(chunk_number - 1) * FILE_CHUNK_SIZE,
317+
end=chunk_number * FILE_CHUNK_SIZE
318+
if chunk_number < chunks_count
319+
else None,
320+
)
323321

324-
with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
325322
resp = httpx.put(
326323
remote_bento.presigned_upload_url,
327324
content=chunk_io,
@@ -588,6 +585,7 @@ def _do_push_model(
588585
force: bool = False,
589586
threads: int = 10,
590587
rest_client: RestApiClient = Provide[BentoMLContainer.rest_api_client],
588+
bentoml_tmp_dir: str = Provide[BentoMLContainer.tmp_bento_store_dir],
591589
):
592590
name = model.tag.name
593591
version = model.tag.version
@@ -663,38 +661,37 @@ def _do_push_model(
663661
transmission_strategy = "presigned_url"
664662
presigned_upload_url = remote_model.presigned_upload_url
665663

666-
io_mutex = threading.Lock()
667-
668664
def io_cb(x: int):
669-
with io_mutex:
670-
self.spinner.transmission_progress.update(upload_task_id, advance=x)
665+
self.spinner.transmission_progress.update(upload_task_id, advance=x)
671666

672-
with CallbackIOWrapper(read_cb=io_cb) as tar_io:
667+
with NamedTemporaryFile(
668+
prefix="bentoml-model-", suffix=".tar", dir=bentoml_tmp_dir
669+
) as tar_io:
673670
with self.spinner.spin(
674671
text=f'Creating tar archive for model "{model.tag}"..'
675672
):
676673
with tarfile.open(fileobj=tar_io, mode="w:") as tar:
677674
tar.add(model.path, arcname="./")
678-
tar_io.seek(0, 0)
679675
with self.spinner.spin(text=f'Start uploading model "{model.tag}"..'):
680676
rest_client.v1.start_upload_model(
681677
model_repository_name=model_repository.name, version=version
682678
)
683-
file_size = tar_io.getbuffer().nbytes
679+
file_size = tar_io.tell()
684680
self.spinner.transmission_progress.update(
685681
upload_task_id,
686682
description=f'Uploading model "{model.tag}"',
687683
total=file_size,
688684
visible=True,
689685
)
690686
self.spinner.transmission_progress.start_task(upload_task_id)
687+
io_with_cb = CallbackIOWrapper(tar_io, read_cb=io_cb)
691688

692689
if transmission_strategy == "proxy":
693690
try:
694691
rest_client.v1.upload_model(
695692
model_repository_name=model_repository.name,
696693
version=version,
697-
data=tar_io,
694+
data=io_with_cb,
698695
)
699696
except Exception as e: # pylint: disable=broad-except
700697
self.spinner.log(f'[bold red]Failed to upload model "{model.tag}"')
@@ -708,7 +705,7 @@ def io_cb(x: int):
708705
try:
709706
if presigned_upload_url is not None:
710707
resp = httpx.put(
711-
presigned_upload_url, content=tar_io, timeout=36000
708+
presigned_upload_url, content=io_with_cb, timeout=36000
712709
)
713710
if resp.status_code != 200:
714711
finish_req = FinishUploadModelSchema(
@@ -730,7 +727,8 @@ def io_cb(x: int):
730727

731728
upload_id: str = remote_model.upload_id
732729

733-
chunks_count = file_size // FILE_CHUNK_SIZE + 1
730+
chunks_count = math.ceil(file_size / FILE_CHUNK_SIZE)
731+
tar_io.file.close()
734732

735733
def chunk_upload(
736734
upload_id: str, chunk_number: int
@@ -752,18 +750,16 @@ def chunk_upload(
752750
with self.spinner.spin(
753751
text=f'({chunk_number}/{chunks_count}) Uploading chunk of model "{model.tag}"...'
754752
):
755-
chunk = (
756-
tar_io.getbuffer()[
757-
(chunk_number - 1) * FILE_CHUNK_SIZE : chunk_number
758-
* FILE_CHUNK_SIZE
759-
]
760-
if chunk_number < chunks_count
761-
else tar_io.getbuffer()[
762-
(chunk_number - 1) * FILE_CHUNK_SIZE :
763-
]
764-
)
753+
with open(tar_io.name, "rb") as f:
754+
chunk_io = CallbackIOWrapper(
755+
f,
756+
read_cb=io_cb,
757+
start=(chunk_number - 1) * FILE_CHUNK_SIZE,
758+
end=chunk_number * FILE_CHUNK_SIZE
759+
if chunk_number < chunks_count
760+
else None,
761+
)
765762

766-
with CallbackIOWrapper(chunk, read_cb=io_cb) as chunk_io:
767763
resp = httpx.put(
768764
remote_model.presigned_upload_url,
769765
content=chunk_io,

src/bentoml/_internal/cloud/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def finish_upload_bento(
263263
return schema_from_json(resp.text, BentoSchema)
264264

265265
def upload_bento(
266-
self, bento_repository_name: str, version: str, data: t.BinaryIO
266+
self, bento_repository_name: str, version: str, data: t.IO[bytes]
267267
) -> None:
268268
url = urljoin(
269269
self.endpoint,
@@ -416,7 +416,7 @@ def finish_upload_model(
416416
return schema_from_json(resp.text, ModelSchema)
417417

418418
def upload_model(
419-
self, model_repository_name: str, version: str, data: t.BinaryIO
419+
self, model_repository_name: str, version: str, data: t.IO[bytes]
420420
) -> None:
421421
url = urljoin(
422422
self.endpoint,

0 commit comments

Comments
 (0)