Skip to content

Commit eb287cf

Browse files
authored
Support loading a model in a non-init context
1 parent 8480ac9 commit eb287cf

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/open_clip/factory.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -499,11 +499,16 @@ def create_model(
499499
# Instantiate the model
500500
logging.info(f"Instantiating model architecture: {model_class.__name__}")
501501
model = model_class(**final_model_cfg, cast_dtype=cast_dtype)
502-
_set_model_device_and_precision(model, device, precision, is_timm_model)
502+
503+
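# The model could be on the meta device if it was instantiated in a non-init context; in that case it has no real weights, so device/precision setup and checkpoint loading are skipped.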
504+
model_is_in_meta_device = next(model.parameters()).device.type == "meta"
505+
506+
if not model_is_in_meta_device:
507+
_set_model_device_and_precision(model, device, precision, is_timm_model)
503508

504509
# Load Full Pretrained CLIP Weights (if path exists)
505510
pretrained_loaded = False
506-
if checkpoint_path:
511+
if checkpoint_path and not model_is_in_meta_device:
507512
logging.info(f'Loading full pretrained weights from: {checkpoint_path}')
508513
# Use the load_checkpoint helper which handles state dict loading, conversions, etc.
509514
# Use strict=True by default for full model loading to catch mismatches.
@@ -518,7 +523,7 @@ def create_model(
518523

519524
# Load tower-specific weights (image and text), after the full CLIP checkpoint, potentially overwriting parts.
520525
pretrained_image_loaded = False # Track if specific image weights loaded
521-
if pretrained_image_path:
526+
if pretrained_image_path and not model_is_in_meta_device:
522527
if os.path.isfile(pretrained_image_path):
523528
logging.info(f"Attempting to load image tower weights from: {pretrained_image_path}")
524529
try:
@@ -547,7 +552,7 @@ def create_model(
547552
logging.warning(f"Invalid file path specified for pretrained_image_path: {pretrained_image_path}")
548553

549554
pretrained_text_loaded = False # Track if specific text weights loaded
550-
if pretrained_text_path:
555+
if pretrained_text_path and not model_is_in_meta_device:
551556
if os.path.isfile(pretrained_text_path):
552557
logging.info(f"Attempting to load text tower weights from: {pretrained_text_path}")
553558
try:
@@ -585,6 +590,8 @@ def create_model(
585590
elif not pretrained_loaded and partially_loaded:
586591
# Some tower weights loaded
587592
logging.warning(f"Model {model_name} initialized partially.")
593+
elif model_is_in_meta_device:
594+
logging.info("The model is in the 'meta' device and thus it was not initialized.")
588595
elif not pretrained_loaded and not partially_loaded:
589596
# Absolutely no weights were loaded from any source
590597
logging.warning(f"No pretrained weights loaded for model '{model_name}'. Model initialized randomly.")

0 commit comments

Comments
 (0)