@@ -499,11 +499,16 @@ def create_model(
499499 # Instantiate the model
500500 logging .info (f"Instantiating model architecture: { model_class .__name__ } " )
501501 model = model_class (** final_model_cfg , cast_dtype = cast_dtype )
502- _set_model_device_and_precision (model , device , precision , is_timm_model )
502+
503+ # The model could be in the meta device if
504+ model_is_in_meta_device = next (model .parameters ()).device .type != "meta"
505+
506+ if not model_is_in_meta_device :
507+ _set_model_device_and_precision (model , device , precision , is_timm_model )
503508
504509 # Load Full Pretrained CLIP Weights (if path exists)
505510 pretrained_loaded = False
506- if checkpoint_path :
511+ if checkpoint_path and not model_is_in_meta_device :
507512 logging .info (f'Loading full pretrained weights from: { checkpoint_path } ' )
508513 # Use the load_checkpoint helper which handles state dict loading, conversions, etc.
509514 # Use strict=True by default for full model loading to catch mismatches.
@@ -518,7 +523,7 @@ def create_model(
518523
519524 # Load tower-specific weights (image and text), after the full CLIP checkpoint, potentially overwriting parts.
520525 pretrained_image_loaded = False # Track if specific image weights loaded
521- if pretrained_image_path :
526+ if pretrained_image_path and not model_is_in_meta_device :
522527 if os .path .isfile (pretrained_image_path ):
523528 logging .info (f"Attempting to load image tower weights from: { pretrained_image_path } " )
524529 try :
@@ -547,7 +552,7 @@ def create_model(
547552 logging .warning (f"Invalid file path specified for pretrained_image_path: { pretrained_image_path } " )
548553
549554 pretrained_text_loaded = False # Track if specific text weights loaded
550- if pretrained_text_path :
555+ if pretrained_text_path and not model_is_in_meta_device :
551556 if os .path .isfile (pretrained_text_path ):
552557 logging .info (f"Attempting to load text tower weights from: { pretrained_text_path } " )
553558 try :
@@ -585,6 +590,8 @@ def create_model(
585590 elif not pretrained_loaded and partially_loaded :
586591 # Some tower weights loaded
587592 logging .warning (f"Model { model_name } initialized partially." )
593+ elif model_is_in_meta_device :
594+ logging .info ("The model is in the 'meta' device and thus it was not initialized." )
588595 elif not pretrained_loaded and not partially_loaded :
589596 # Absolutely no weights were loaded from any source
590597 logging .warning (f"No pretrained weights loaded for model '{ model_name } '. Model initialized randomly." )
0 commit comments