
Commit 3ff3899

Add local-dir: scheme support for model loading (config + weights) from a folder
1 parent ceca5ef

3 files changed: +128 -46 lines
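
In practical terms, the commit lets create_model consume a plain folder holding a config.json and a weights file via the new local-dir: scheme. A minimal usage sketch (the folder path is purely illustrative):

    import timm

    # The folder is expected to hold config.json plus a checkpoint such as
    # model.safetensors; the path below is illustrative.
    model = timm.create_model('local-dir:/path/to/model_folder', pretrained=True)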

timm/models/_builder.py (+12 -2)
@@ -11,8 +11,8 @@
 from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
-    load_custom_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf, \
+    load_state_dict_from_path, load_custom_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -45,6 +45,9 @@ def _resolve_pretrained_source(pretrained_cfg):
         load_from = 'hf-hub'
         assert hf_hub_id
         pretrained_loc = hf_hub_id
+    elif cfg_source == 'local-dir':
+        load_from = 'local-dir'
+        pretrained_loc = pretrained_file
     else:
         # default source == timm or unspecified
         if pretrained_sd:
@@ -211,6 +214,13 @@ def load_pretrained(
                 state_dict = load_state_dict_from_hf(*pretrained_loc, cache_dir=cache_dir)
         else:
            state_dict = load_state_dict_from_hf(pretrained_loc, weights_only=True, cache_dir=cache_dir)
+    elif load_from == 'local-dir':
+        _logger.info(f'Loading pretrained weights from local directory ({pretrained_loc})')
+        pretrained_path = Path(pretrained_loc)
+        if pretrained_path.is_dir():
+            state_dict = load_state_dict_from_path(pretrained_path)
+        else:
+            raise RuntimeError(f"Specified path is not a directory: {pretrained_loc}")
    else:
        model_name = pretrained_cfg.get('architecture', 'this model')
        raise RuntimeError(f"No pretrained weights exist for {model_name}. Use `pretrained=False` for random init.")

timm/models/_factory.py (+21 -12)
@@ -5,7 +5,7 @@

 from timm.layers import set_layer_config
 from ._helpers import load_checkpoint
-from ._hub import load_model_config_from_hf
+from ._hub import load_model_config_from_hf, load_model_config_from_path
 from ._pretrained import PretrainedCfg
 from ._registry import is_model, model_entrypoint, split_model_name_tag

@@ -18,13 +18,15 @@ def parse_model_name(model_name: str):
     # NOTE for backwards compat, deprecate hf_hub use
     model_name = model_name.replace('hf_hub', 'hf-hub')
     parsed = urlsplit(model_name)
-    assert parsed.scheme in ('', 'timm', 'hf-hub')
+    assert parsed.scheme in ('', 'hf-hub', 'local-dir')
     if parsed.scheme == 'hf-hub':
         # FIXME may use fragment as revision, currently `@` in URI path
         return parsed.scheme, parsed.path
+    elif parsed.scheme == 'local-dir':
+        return parsed.scheme, parsed.path
     else:
         model_name = os.path.split(parsed.path)[-1]
-        return 'timm', model_name
+        return None, model_name


 def safe_model_name(model_name: str, remove_source: bool = True):
@@ -100,20 +102,27 @@ def create_model(
     # non-supporting models don't break and default args remain in effect.
     kwargs = {k: v for k, v in kwargs.items() if v is not None}

-    model_source, model_name = parse_model_name(model_name)
-    if model_source == 'hf-hub':
+    model_source, model_id = parse_model_name(model_name)
+    if model_source:
         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
-        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
-        # load model weights + pretrained_cfg from Hugging Face hub.
-        pretrained_cfg, model_name, model_args = load_model_config_from_hf(
-            model_name,
-            cache_dir=cache_dir,
-        )
+        if model_source == 'hf-hub':
+            # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+            # load model weights + pretrained_cfg from Hugging Face hub.
+            pretrained_cfg, model_name, model_args = load_model_config_from_hf(
+                model_id,
+                cache_dir=cache_dir,
+            )
+        elif model_source == 'local-dir':
+            pretrained_cfg, model_name, model_args = load_model_config_from_path(
+                model_id,
+            )
+        else:
+            assert False, f'Unknown model_source {model_source}'
         if model_args:
             for k, v in model_args.items():
                 kwargs.setdefault(k, v)
     else:
-        model_name, pretrained_tag = split_model_name_tag(model_name)
+        model_name, pretrained_tag = split_model_name_tag(model_id)
         if pretrained_tag and not pretrained_cfg:
             # a valid pretrained_cfg argument takes priority over tag in model name
             pretrained_cfg = pretrained_tag
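
To make the parse_model_name change concrete, here is how the three cases now split (a sketch; the model names are illustrative):

    # scheme-prefixed names return (scheme, path); bare names now return
    # (None, name) where they previously returned ('timm', name)
    parse_model_name('hf-hub:timm/resnet50.a1_in1k')  # ('hf-hub', 'timm/resnet50.a1_in1k')
    parse_model_name('local-dir:/models/my_model')    # ('local-dir', '/models/my_model')
    parse_model_name('resnet50.a1_in1k')              # (None, 'resnet50.a1_in1k')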

timm/models/_hub.py (+95 -32)
@@ -5,7 +5,7 @@
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@@ -157,42 +157,60 @@ def download_from_hf(
     )


+def _parse_model_cfg(
+        cfg: Dict[str, Any],
+        extra_fields: Dict[str, Any],
+) -> Tuple[Dict[str, Any], str, Dict[str, Any]]:
+    """Split a loaded model config into (pretrained_cfg, model_name, model_args)."""
+    # legacy "single-dict" config -> split into architecture + pretrained_cfg
+    if "pretrained_cfg" not in cfg:
+        pretrained_cfg = cfg
+        cfg = {
+            "architecture": pretrained_cfg.pop("architecture"),
+            "num_features": pretrained_cfg.pop("num_features", None),
+            "pretrained_cfg": pretrained_cfg,
+        }
+        if "labels" in pretrained_cfg:  # deprecated name, rename -> label_names
+            pretrained_cfg["label_names"] = pretrained_cfg.pop("labels")
+
+    pretrained_cfg = cfg["pretrained_cfg"]
+    pretrained_cfg.update(extra_fields)
+
+    # top-level overrides take precedence over the saved pretrained_cfg
+    if "num_classes" in cfg:
+        pretrained_cfg["num_classes"] = cfg["num_classes"]
+    if "label_names" in cfg:
+        pretrained_cfg["label_names"] = cfg.pop("label_names")
+    if "label_descriptions" in cfg:
+        pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions")
+
+    model_args = cfg.get("model_args", {})
+    model_name = cfg["architecture"]
+    return pretrained_cfg, model_name, model_args
+
+
 def load_model_config_from_hf(
         model_id: str,
         cache_dir: Optional[Union[str, Path]] = None,
 ):
+    """Original HF-Hub loader (unchanged download, shared parsing)."""
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, 'config.json', cache_dir=cache_dir)
-
-    hf_config = load_cfg_from_json(cached_file)
-    if 'pretrained_cfg' not in hf_config:
-        # old form, pull pretrain_cfg out of the base dict
-        pretrained_cfg = hf_config
-        hf_config = {}
-        hf_config['architecture'] = pretrained_cfg.pop('architecture')
-        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
-        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
-            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
-        hf_config['pretrained_cfg'] = pretrained_cfg
-
-    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
-    pretrained_cfg = hf_config['pretrained_cfg']
-    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
-    pretrained_cfg['source'] = 'hf-hub'
-
-    # model should be created with base config num_classes if its exist
-    if 'num_classes' in hf_config:
-        pretrained_cfg['num_classes'] = hf_config['num_classes']
-
-    # label meta-data in base config overrides saved pretrained_cfg on load
-    if 'label_names' in hf_config:
-        pretrained_cfg['label_names'] = hf_config.pop('label_names')
-    if 'label_descriptions' in hf_config:
-        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
-
-    model_args = hf_config.get('model_args', {})
-    model_name = hf_config['architecture']
-    return pretrained_cfg, model_name, model_args
+    cfg_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
+    cfg = load_cfg_from_json(cfg_path)
+    return _parse_model_cfg(cfg, {"hf_hub_id": model_id, "source": "hf-hub"})
+
+
+def load_model_config_from_path(
+        model_path: Union[str, Path],
+):
+    """Load from ``<model_path>/config.json`` on the local filesystem."""
+    model_path = Path(model_path)
+    cfg_file = model_path / "config.json"
+    if not cfg_file.is_file():
+        raise FileNotFoundError(f"Config file not found: {cfg_file}")
+    cfg = load_cfg_from_json(cfg_file)
+    extra_fields = {"file": str(model_path), "source": "local-dir"}
+    return _parse_model_cfg(cfg, extra_fields=extra_fields)


 def load_state_dict_from_hf(
@@ -236,6 +254,56 @@ def load_state_dict_from_hf(
     return state_dict


+_PREFERRED_FILES = (
+    "model.safetensors",
+    "pytorch_model.bin",
+    "pytorch_model.pth",
+    "model.pth",
+    "open_clip_model.safetensors",
+    "open_clip_pytorch_model.safetensors",
+    "open_clip_pytorch_model.bin",
+    "open_clip_pytorch_model.pth",
+)
+_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')
+
+def load_state_dict_from_path(
+        path: Union[str, Path],
+        weights_only: bool = False,
+):
+    path = Path(path)
+    # exact-name matches take precedence
+    found_file = None
+    for fname in _PREFERRED_FILES:
+        p = path / fname
+        if p.exists():
+            _logger.info(f"Found preferred checkpoint: {p.name}")
+            found_file = p
+            break
+
+    # fallback: first match per extension, in priority order
+    if found_file is None:
+        for ext in _EXT_PRIORITY:
+            files = sorted(path.glob(f"*{ext}"))
+            if files:
+                if len(files) > 1:
+                    _logger.warning(
+                        f"Multiple {ext} checkpoints in {path}: {[f.name for f in files]}. "
+                        f"Using '{files[0].name}'."
+                    )
+                found_file = files[0]
+                break
+
+    if not found_file:
+        raise RuntimeError(f"No suitable checkpoints found in {path}.")
+
+    # older torch versions lack the weights_only kwarg
+    try:
+        state_dict = torch.load(found_file, map_location='cpu', weights_only=weights_only)
+    except TypeError:
+        state_dict = torch.load(found_file, map_location='cpu')
+    return state_dict
+
+
 def load_custom_from_hf(
         model_id: str,
         filename: str,
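
Putting the _hub.py pieces together: a local-dir folder needs a config.json plus one checkpoint named per _PREFERRED_FILES (or matching an extension in _EXT_PRIORITY). A sketch of loading one directly (the folder path is illustrative):

    from pathlib import Path
    from timm.models._hub import load_model_config_from_path, load_state_dict_from_path

    model_dir = Path('/models/my_model')    # illustrative; expected layout:
    #   config.json        -> pretrained_cfg (+ optional top-level overrides)
    #   model.safetensors  -> found first per _PREFERRED_FILES

    pretrained_cfg, model_name, model_args = load_model_config_from_path(model_dir)
    state_dict = load_state_dict_from_path(model_dir)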
