|
5 | 5 | from functools import partial
|
6 | 6 | from pathlib import Path
|
7 | 7 | from tempfile import TemporaryDirectory
|
8 |
| -from typing import Iterable, List, Optional, Tuple, Union |
| 8 | +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union |
9 | 9 |
|
10 | 10 | import torch
|
11 | 11 | from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
@@ -157,42 +157,60 @@ def download_from_hf(
|
157 | 157 | )
|
158 | 158 |
|
159 | 159 |
|
| 160 | +def _parse_model_cfg( |
| 161 | + cfg: Dict[str, Any], |
| 162 | + extra_fields: Dict[str, Any], |
| 163 | +) -> Tuple[Dict[str, Any], str, Dict[str, Any]]: |
| 164 | + """""" |
| 165 | + # legacy "single‑dict" → split |
| 166 | + if "pretrained_cfg" not in cfg: |
| 167 | + pretrained_cfg = cfg |
| 168 | + cfg = { |
| 169 | + "architecture": pretrained_cfg.pop("architecture"), |
| 170 | + "num_features": pretrained_cfg.pop("num_features", None), |
| 171 | + "pretrained_cfg": pretrained_cfg, |
| 172 | + } |
| 173 | + if "labels" in pretrained_cfg: # rename ‑‑> label_names |
| 174 | + pretrained_cfg["label_names"] = pretrained_cfg.pop("labels") |
| 175 | + |
| 176 | + pretrained_cfg = cfg["pretrained_cfg"] |
| 177 | + pretrained_cfg.update(extra_fields) |
| 178 | + |
| 179 | + # top‑level overrides |
| 180 | + if "num_classes" in cfg: |
| 181 | + pretrained_cfg["num_classes"] = cfg["num_classes"] |
| 182 | + if "label_names" in cfg: |
| 183 | + pretrained_cfg["label_names"] = cfg.pop("label_names") |
| 184 | + if "label_descriptions" in cfg: |
| 185 | + pretrained_cfg["label_descriptions"] = cfg.pop("label_descriptions") |
| 186 | + |
| 187 | + model_args = cfg.get("model_args", {}) |
| 188 | + model_name = cfg["architecture"] |
| 189 | + return pretrained_cfg, model_name, model_args |
| 190 | + |
| 191 | + |
def load_model_config_from_hf(
        model_id: str,
        cache_dir: Optional[Union[str, Path]] = None,
):
    """Fetch and parse a model's ``config.json`` from the HF Hub.

    Args:
        model_id: Hugging Face Hub repo id of the model.
        cache_dir: Optional override for the download cache location.

    Returns:
        Tuple of ``(pretrained_cfg, model_name, model_args)``.
    """
    assert has_hf_hub(True)
    config_path = download_from_hf(model_id, "config.json", cache_dir=cache_dir)
    hf_cfg = load_cfg_from_json(config_path)
    # Tag the pretrained cfg so weight loading knows to pull from the hub.
    source_fields = {"hf_hub_id": model_id, "source": "hf-hub"}
    return _parse_model_cfg(hf_cfg, source_fields)
| 201 | + |
| 202 | + |
def load_model_config_from_path(
        model_path: Union[str, Path],
):
    """Parse a model config from ``<model_path>/config.json`` on local disk.

    Args:
        model_path: Directory expected to contain a ``config.json`` file.

    Returns:
        Tuple of ``(pretrained_cfg, model_name, model_args)``.

    Raises:
        FileNotFoundError: If ``config.json`` does not exist under ``model_path``.
    """
    root = Path(model_path)
    config_file = root / "config.json"
    if not config_file.is_file():
        raise FileNotFoundError(f"Config file not found: {config_file}")
    parsed = load_cfg_from_json(config_file)
    # Tag the pretrained cfg so weight loading resolves from the local dir.
    return _parse_model_cfg(parsed, extra_fields={"file": str(root), "source": "local-dir"})
196 | 214 |
|
197 | 215 |
|
198 | 216 | def load_state_dict_from_hf(
|
@@ -236,6 +254,51 @@ def load_state_dict_from_hf(
|
236 | 254 | return state_dict
|
237 | 255 |
|
238 | 256 |
|
# Exact checkpoint filenames tried first, in order of preference.
_PREFERRED_FILES = (
    "model.safetensors",
    "pytorch_model.bin",
    "pytorch_model.pth",
    "model.pth",
    "open_clip_model.safetensors",
    "open_clip_pytorch_model.safetensors",
    "open_clip_pytorch_model.bin",
    "open_clip_pytorch_model.pth",
)
# Fallback glob extensions, highest priority first.
_EXT_PRIORITY = ('.safetensors', '.pth', '.pth.tar', '.bin')


def load_state_dict_from_path(
        path: Union[str, Path],
        weights_only: bool = False,
):
    """Load a checkpoint state dict from a local directory.

    Searches ``path`` for a checkpoint, trying the well-known filenames in
    ``_PREFERRED_FILES`` first and only then falling back to the first file
    matching each extension in ``_EXT_PRIORITY`` order.

    Args:
        path: Directory containing the checkpoint file.
        weights_only: Passed through to ``torch.load`` when supported.

    Returns:
        The loaded state dict.

    Raises:
        RuntimeError: If no suitable checkpoint file is found in ``path``.
    """
    path = Path(path)  # accept plain strings as well as Path objects
    found_file = None
    for fname in _PREFERRED_FILES:
        p = path / fname
        if p.exists():
            logging.info(f"Found preferred checkpoint: {p.name}")
            found_file = p
            break

    if found_file is None:
        # Fallback: first match per extension class, stopping at the
        # highest-priority extension that has any matches.
        for ext in _EXT_PRIORITY:
            files = sorted(path.glob(f"*{ext}"))
            if files:
                if len(files) > 1:
                    names = ", ".join(f.name for f in files)
                    logging.warning(
                        f"Multiple {ext} checkpoints in {path}: {names}. "
                        f"Using '{files[0].name}'."
                    )
                found_file = files[0]
                break

    if found_file is None:
        raise RuntimeError(f"No suitable checkpoints found in {path}.")

    try:
        state_dict = torch.load(found_file, map_location='cpu', weights_only=weights_only)
    except TypeError:
        # Older torch versions lack the `weights_only` kwarg.
        state_dict = torch.load(found_file, map_location='cpu')
    return state_dict
| 300 | + |
| 301 | + |
239 | 302 | def load_custom_from_hf(
|
240 | 303 | model_id: str,
|
241 | 304 | filename: str,
|
|
0 commit comments