Skip to content

Commit 3f46b91

Browse files
committed
refact: use bucket path
1 parent 19a45ac commit 3f46b91

File tree

4 files changed

+84
-54
lines changed

4 files changed

+84
-54
lines changed

python/aibrix/aibrix/downloader/base.py

+19 -8
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,10 @@
2020
from pathlib import Path
2121
from typing import List, Optional
2222

23-
from aibrix import envs
2423
from aibrix.logger import init_logger
2524

25+
from aibrix import envs
26+
2627
logger = init_logger(__name__)
2728

2829

@@ -32,6 +33,8 @@ class BaseDownloader(ABC):
3233

3334
model_uri: str
3435
model_name: str
36+
bucket_path: str
37+
bucket_name: Optional[str]
3538
allow_file_suffix: Optional[List[str]] = field(
3639
default_factory=lambda: envs.DOWNLOADER_ALLOW_FILE_SUFFIX
3740
)
@@ -59,7 +62,13 @@ def _support_range_download(self) -> bool:
5962
pass
6063

6164
@abstractmethod
62-
def download(self, filename: str, local_path: Path, enable_range: bool = True):
65+
def download(
66+
self,
67+
local_path: Path,
68+
bucket_path: str,
69+
bucket_name: str = None,
70+
enable_range: bool = True,
71+
):
6372
pass
6473

6574
def download_directory(self, local_path: Path):
@@ -68,9 +77,9 @@ def download_directory(self, local_path: Path):
6877
directory method for ``Downloader``. Otherwise, the following logic will be
6978
used to download the directory.
7079
"""
71-
directory_list = self._directory_list(self.model_uri)
80+
directory_list = self._directory_list(self.bucket_path)
7281
if self.allow_file_suffix is None:
73-
logger.info(f"All files from {self.model_uri} will be downloaded.")
82+
logger.info(f"All files from {self.bucket_path} will be downloaded.")
7483
filtered_files = directory_list
7584
else:
7685
filtered_files = [
@@ -92,8 +101,9 @@ def download_directory(self, local_path: Path):
92101
futures = [
93102
executor.submit(
94103
self.download,
95-
filename=file,
96104
local_path=local_path,
105+
bucket_path=file,
106+
bucket_name=self.bucket_name,
97107
enable_range=False,
98108
)
99109
for file in filtered_files
@@ -111,7 +121,7 @@ def download_directory(self, local_path: Path):
111121
st = time.perf_counter()
112122
for file in filtered_files:
113123
# use range download to speedup download
114-
self.download(file, local_path, True)
124+
self.download(local_path, file, self.bucket_name, True)
115125
duration = time.perf_counter() - st
116126
logger.info(
117127
f"Downloader {self.__class__.__name__} download "
@@ -132,11 +142,12 @@ def download_model(self, local_path: Optional[str] = None):
132142
# TODO check local file exists
133143

134144
if self._is_directory():
135-
self.download_directory(model_path)
145+
self.download_directory(local_path=model_path)
136146
else:
137147
self.download(
138-
filename=self.model_uri,
139148
local_path=model_path,
149+
bucket_path=self.bucket_path,
150+
bucket_name=self.bucket_name,
140151
enable_range=self._support_range_download(),
141152
)
142153

python/aibrix/aibrix/downloader/huggingface.py

+20 -7
Original file line number | Diff line number | Diff line change
@@ -15,11 +15,11 @@
1515
from pathlib import Path
1616
from typing import List, Optional
1717

18+
from aibrix.downloader.base import BaseDownloader
19+
from aibrix.logger import init_logger
1820
from huggingface_hub import HfApi, hf_hub_download, snapshot_download
1921

2022
from aibrix import envs
21-
from aibrix.downloader.base import BaseDownloader
22-
from aibrix.logger import init_logger
2323

2424
logger = init_logger(__name__)
2525

@@ -36,12 +36,19 @@ def __init__(self, model_uri: str, model_name: Optional[str] = None):
3636
else:
3737
model_name = _parse_model_name_from_uri(model_uri)
3838

39-
super().__init__(model_uri=model_uri, model_name=model_name) # type: ignore
40-
4139
self.hf_token = envs.DOWNLOADER_HF_TOKEN
4240
self.hf_endpoint = envs.DOWNLOADER_HF_ENDPOINT
4341
self.hf_revision = envs.DOWNLOADER_HF_REVISION
4442

43+
super().__init__(
44+
model_uri=model_uri,
45+
model_name=model_name,
46+
bucket_path=model_uri,
47+
bucket_name=None,
48+
) # type: ignore
49+
50+
# Dependent on the attributes generated in the base class,
51+
# so place it after the super().__init__() call.
4552
self.allow_patterns = (
4653
None
4754
if self.allow_file_suffix is None
@@ -58,7 +65,7 @@ def _valid_config(self):
5865
assert (
5966
len(self.model_uri.split("/")) == 2
6067
), "Model uri must be in `repo/name` format."
61-
68+
assert self.bucket_name is None, "Bucket name is empty in HuggingFace."
6269
assert self.model_name is not None, "Model name is not set."
6370

6471
def _is_directory(self) -> bool:
@@ -74,10 +81,16 @@ def _directory_list(self, path: str) -> List[str]:
7481
def _support_range_download(self) -> bool:
7582
return False
7683

77-
def download(self, filename: str, local_path: Path, enable_range: bool = True):
84+
def download(
85+
self,
86+
local_path: Path,
87+
bucket_path: str,
88+
bucket_name: str = None,
89+
enable_range: bool = True,
90+
):
7891
hf_hub_download(
7992
repo_id=self.model_uri,
80-
filename=filename,
93+
filename=bucket_path,
8194
local_dir=local_path,
8295
revision=self.hf_revision,
8396
token=self.hf_token,

python/aibrix/aibrix/downloader/s3.py

+20 -15
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,11 @@
1717
from urllib.parse import urlparse
1818

1919
import boto3
20+
from aibrix.downloader.base import BaseDownloader
2021
from boto3.s3.transfer import TransferConfig
2122
from tqdm import tqdm
2223

2324
from aibrix import envs
24-
from aibrix.downloader.base import BaseDownloader
2525

2626

2727
def _parse_bucket_info_from_uri(uri: str) -> Tuple[str, str]:
@@ -39,8 +39,6 @@ def __init__(self, model_uri):
3939
endpoint = envs.DOWNLOADER_AWS_ENDPOINT
4040
region = envs.DOWNLOADER_AWS_REGION
4141
bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri)
42-
self.bucket_name = bucket_name
43-
self.bucket_path = bucket_path
4442

4543
self.client = boto3.client(
4644
service_name="s3",
@@ -50,7 +48,12 @@ def __init__(self, model_uri):
5048
aws_secret_access_key=sk,
5149
)
5250

53-
super().__init__(model_uri=model_uri, model_name=model_name) # type: ignore
51+
super().__init__(
52+
model_uri=model_uri,
53+
model_name=model_name,
54+
bucket_path=bucket_path,
55+
bucket_name=bucket_name,
56+
) # type: ignore
5457

5558
def _valid_config(self):
5659
assert (
@@ -79,26 +82,28 @@ def _is_directory(self) -> bool:
7982

8083
def _directory_list(self, path: str) -> List[str]:
8184
objects_out = self.client.list_objects_v2(
82-
Bucket=self.bucket_name, Delimiter="/", Prefix=self.bucket_path
85+
Bucket=self.bucket_name, Delimiter="/", Prefix=path
8386
)
8487
contents = objects_out.get("Contents", [])
8588
return [content.get("Key") for content in contents]
8689

8790
def _support_range_download(self) -> bool:
8891
return True
8992

90-
def download(self, filename: str, local_path: Path, enable_range: bool = True):
91-
# filename should extract from model_uri when it is not a directory
92-
if not self._is_directory():
93-
filename = self.bucket_path
94-
93+
def download(
94+
self,
95+
local_path: Path,
96+
bucket_path: str,
97+
bucket_name: str = None,
98+
enable_range: bool = True,
99+
):
95100
# check if file exist
96101
try:
97-
meta_data = self.client.head_object(Bucket=self.bucket_name, Key=filename)
102+
meta_data = self.client.head_object(Bucket=bucket_name, Key=bucket_path)
98103
except Exception as e:
99-
raise ValueError(f"TOS file {filename} not exist for {e}.")
104+
raise ValueError(f"S3 bucket path {bucket_path} not exist for {e}.")
100105

101-
_file_name = filename.split("/")[-1]
106+
_file_name = bucket_path.split("/")[-1]
102107
# S3 client does not support Path, convert it to str
103108
local_file = str(local_path.joinpath(_file_name).absolute())
104109

@@ -122,8 +127,8 @@ def download_progress(bytes_transferred):
122127
pbar.update(bytes_transferred)
123128

124129
self.client.download_file(
125-
Bucket=self.bucket_name,
126-
Key=filename,
130+
Bucket=bucket_name,
131+
Key=bucket_path,
127132
Filename=local_file,
128133
Config=config,
129134
Callback=download_progress,

python/aibrix/aibrix/downloader/tos.py

+25 -24
Original file line number | Diff line number | Diff line change
@@ -18,11 +18,11 @@
1818
from urllib.parse import urlparse
1919

2020
import tos
21+
from aibrix.downloader.base import BaseDownloader
2122
from tos import DataTransferType
2223
from tqdm import tqdm
2324

2425
from aibrix import envs
25-
from aibrix.downloader.base import BaseDownloader
2626

2727
tos_logger = logging.getLogger("tos")
2828
tos_logger.setLevel(logging.WARNING)
@@ -40,18 +40,18 @@ def __init__(self, model_uri):
4040
model_name = envs.DOWNLOADER_MODEL_NAME
4141
ak = envs.DOWNLOADER_TOS_ACCESS_KEY
4242
sk = envs.DOWNLOADER_TOS_SECRET_KEY
43-
endpoint = envs.DOWNLOADER_TOS_ENDPOINT or ''
44-
region = envs.DOWNLOADER_TOS_REGION or ''
43+
endpoint = envs.DOWNLOADER_TOS_ENDPOINT or ""
44+
region = envs.DOWNLOADER_TOS_REGION or ""
4545
bucket_name, bucket_path = _parse_bucket_info_from_uri(model_uri)
46-
self.bucket_name = bucket_name
47-
self.bucket_path = bucket_path
48-
49-
self.client = tos.TosClientV2(ak=ak,
50-
sk=sk,
51-
endpoint=endpoint,
52-
region=region)
53-
54-
super().__init__(model_uri=model_uri, model_name=model_name) # type: ignore
46+
47+
self.client = tos.TosClientV2(ak=ak, sk=sk, endpoint=endpoint, region=region)
48+
49+
super().__init__(
50+
model_uri=model_uri,
51+
model_name=model_name,
52+
bucket_path=bucket_path,
53+
bucket_name=bucket_name,
54+
) # type: ignore
5555

5656
def _valid_config(self):
5757
assert (
@@ -83,26 +83,27 @@ def _is_directory(self) -> bool:
8383
def _directory_list(self, path: str) -> List[str]:
8484
# TODO cache list_objects_type2 result to avoid too many requests
8585
objects_out = self.client.list_objects_type2(
86-
self.bucket_name, prefix=self.bucket_path, delimiter="/"
86+
self.bucket_name, prefix=path, delimiter="/"
8787
)
88-
8988
return [obj.key for obj in objects_out.contents]
9089

9190
def _support_range_download(self) -> bool:
9291
return True
9392

94-
def download(self, filename: str, local_path: Path, enable_range: bool = True):
95-
# filename should extract from model_uri when it is not a directory
96-
if not self._is_directory():
97-
filename = self.bucket_path
98-
93+
def download(
94+
self,
95+
local_path: Path,
96+
bucket_path: str,
97+
bucket_name: str = None,
98+
enable_range: bool = True,
99+
):
99100
# check if file exist
100101
try:
101-
meta_data = self.client.head_object(bucket=self.bucket_name, key=filename)
102+
meta_data = self.client.head_object(bucket=bucket_name, key=bucket_path)
102103
except Exception as e:
103-
raise ValueError(f"TOS file {filename} not exist for {e}.")
104+
raise ValueError(f"TOS bucket path {bucket_path} not exist for {e}.")
104105

105-
_file_name = filename.split("/")[-1]
106+
_file_name = bucket_path.split("/")[-1]
106107
# TOS client does not support Path, convert it to str
107108
local_file = str(local_path.joinpath(_file_name).absolute())
108109
task_num = envs.DOWNLOADER_NUM_THREADS if enable_range else 1
@@ -121,8 +122,8 @@ def download_progress(
121122
pbar.update(rw_once_bytes)
122123

123124
self.client.download_file(
124-
bucket=self.bucket_name,
125-
key=filename,
125+
bucket=bucket_name,
126+
key=bucket_path,
126127
file_path=local_file,
127128
task_num=task_num,
128129
data_transfer_listener=download_progress,

0 commit comments

Comments
 (0)