Skip to content

Commit

Permalink
✏️ Add repo_type to list_repo_files
Browse files Browse the repository at this point in the history
  • Loading branch information
arxyzan committed Jun 11, 2024
1 parent a4d47ab commit edb20e6
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion hezar/preprocessors/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def load(
"""
subfolder = subfolder or cls.preprocessor_subfolder
cache_dir = cache_dir or HEZAR_CACHE_DIR
preprocessor_files = list_repo_files(hub_or_local_path, subfolder=subfolder)
preprocessor_files = list_repo_files(hub_or_local_path, subfolder=subfolder, repo_type="model")
preprocessors = PreprocessorsContainer()
for f in preprocessor_files:
if f.endswith(".yaml"):
Expand Down
5 changes: 3 additions & 2 deletions hezar/utils/hub_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,14 @@ def clone_repo(repo_id: str, save_path: str, **kwargs):
return repo.local_dir


def list_repo_files(hub_or_local_path: str, subfolder: str = None):
def list_repo_files(hub_or_local_path: str, subfolder: str = None, repo_type: str | RepoType = RepoType.MODEL):
"""
List all files in a Hub or local model repo
Args:
hub_or_local_path: Path to hub or local repo
subfolder: Optional subfolder path
repo_type: Repo type of either dataset or model
Returns:
A list of all file names
Expand All @@ -98,7 +99,7 @@ def list_repo_files(hub_or_local_path: str, subfolder: str = None):
for file in files_:
files.append(os.path.relpath(os.path.join(root, file), hub_or_local_path))
else:
files = HfApi().list_repo_files(hub_or_local_path, repo_type=str(RepoType.MODEL))
files = HfApi().list_repo_files(hub_or_local_path, repo_type=str(repo_type))

if subfolder:
files = [os.path.relpath(f, subfolder) for f in files if subfolder in f]
Expand Down

0 comments on commit edb20e6

Please sign in to comment.