@@ -26,9 +26,10 @@ def fetch_model(
26
26
path : Optional [str ] = None ,
27
27
access_token : Optional [str ] = None ,
28
28
source : str = "modelscope" ,
29
+ fetch_safetensors : bool = True ,
29
30
) -> str :
30
31
if source == "modelscope" :
31
- return fetch_modelscope_model (model_uri , revision , path , access_token )
32
+ return fetch_modelscope_model (model_uri , revision , path , access_token , fetch_safetensors )
32
33
if source == "civitai" :
33
34
return fetch_civitai_model (model_uri )
34
35
raise ValueError (f'source should be one of { MODEL_SOURCES } but got "{ source } "' )
@@ -39,6 +40,7 @@ def fetch_modelscope_model(
39
40
revision : Optional [str ] = None ,
40
41
path : Optional [str ] = None ,
41
42
access_token : Optional [str ] = None ,
43
+ fetch_safetensors : bool = True ,
42
44
) -> str :
43
45
lock_file_name = f"modelscope.{ model_id .replace ('/' , '--' )} .{ revision if revision else '__version' } .lock"
44
46
lock_file_path = os .path .join (DIFFSYNTH_FILELOCK_DIR , lock_file_name )
@@ -55,7 +57,7 @@ def fetch_modelscope_model(
55
57
else :
56
58
path = dirpath
57
59
58
- if os .path .isdir (path ):
60
+ if os .path .isdir (path ) and fetch_safetensors :
59
61
return _fetch_safetensors (path )
60
62
return path
61
63
0 commit comments