-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathupload.py
48 lines (39 loc) · 1.79 KB
/
upload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import argparse
import os
import shutil
import tempfile
from huggingface_hub import HfApi, Repository
from huggingface_hub.utils._errors import HfHubHTTPError
def upload_model_to_hf(username, repo_name, ckpt_dir):
    """Upload the top-level files of a local checkpoint directory to the Hub.

    Creates the public model repository ``{username}/{repo_name}`` if it does
    not exist yet, then uploads every regular file found directly inside
    ``ckpt_dir`` in a single commit. Subdirectories of ``ckpt_dir`` are
    skipped (only ``os.path.isfile`` entries are selected).

    Files are sent through the Hub HTTP API (``HfApi.create_commit``) rather
    than the deprecated git-based ``Repository`` helper, so no local clone is
    made and existing repository content is not downloaded first.

    Args:
        username: Hugging Face user or organisation that owns the repository.
        repo_name: Repository name under ``username``.
        ckpt_dir: Local directory holding the checkpoint files.

    Raises:
        FileNotFoundError, NotADirectoryError: ``ckpt_dir`` cannot be listed;
            raised before anything is created on the Hub.
        HfHubHTTPError: repository creation failed for a reason other than
            the repository already existing (HTTP 409).
    """
    # Local import keeps the module's top-level import block unchanged.
    from huggingface_hub import CommitOperationAdd

    repo_id = f"{username}/{repo_name}"

    # Resolve the file list first so an invalid path fails here, before an
    # empty repository is created remotely.
    file_names = sorted(
        name
        for name in os.listdir(ckpt_dir)
        if os.path.isfile(os.path.join(ckpt_dir, name))
    )

    api = HfApi()
    try:
        api.create_repo(repo_id=repo_id, private=False)
        print(f"Repository {repo_id} created successfully.")
    except HfHubHTTPError as e:
        # 409 Conflict means the repository already exists, which is fine;
        # any other HTTP failure is propagated unchanged.
        if e.response is not None and e.response.status_code == 409:
            print(f"Repository {repo_id} already exists.")
        else:
            raise

    if not file_names:
        print(f"No files found in {ckpt_dir}; nothing to upload.")
        return

    # One add-operation per top-level file; all land in a single commit.
    operations = [
        CommitOperationAdd(
            path_in_repo=name,
            path_or_fileobj=os.path.join(ckpt_dir, name),
        )
        for name in file_names
    ]
    api.create_commit(
        repo_id=repo_id,
        repo_type="model",
        operations=operations,
        commit_message="commit files to HF hub",
    )
    print(f"Model {repo_name} has been successfully uploaded to Hugging Face Hub.")
if __name__ == "__main__":
    # Command-line entry point: collect the target repository and the local
    # checkpoint directory, then delegate to upload_model_to_hf.
    cli = argparse.ArgumentParser(
        description="Upload model checkpoint to Hugging Face Hub.",
    )
    cli.add_argument(
        '--repo_name',
        type=str,
        required=True,
        help="The name of the Hugging Face repository.",
    )
    cli.add_argument(
        '--ckpt_dir',
        type=str,
        required=True,
        help="The directory of the model checkpoint.",
    )
    cli.add_argument(
        '--username',
        type=str,
        default="AI4Library",
        help="Your Hugging Face username.",
    )
    opts = cli.parse_args()
    upload_model_to_hf(
        username=opts.username,
        repo_name=opts.repo_name,
        ckpt_dir=opts.ckpt_dir,
    )