Skip to content

Commit 287331d

Browse files
committed
update
1 parent 67dc65e commit 287331d

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/diffusers/models/model_loading_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,8 +354,9 @@ def _load_shard_file(
354354
state_dict_folder=None,
355355
ignore_mismatched_sizes=False,
356356
low_cpu_mem_usage=False,
357+
disable_mmap=False,
357358
):
358-
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
359+
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
359360
mismatched_keys = _find_mismatched_keys(
360361
state_dict,
361362
model_state_dict,
@@ -401,6 +402,7 @@ def _load_shard_files_with_threadpool(
401402
state_dict_folder=None,
402403
ignore_mismatched_sizes=False,
403404
low_cpu_mem_usage=False,
405+
disable_mmap=False,
404406
):
405407
# Do not spawn anymore workers than you need
406408
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
@@ -427,6 +429,7 @@ def _load_shard_files_with_threadpool(
427429
state_dict_folder=state_dict_folder,
428430
ignore_mismatched_sizes=ignore_mismatched_sizes,
429431
low_cpu_mem_usage=low_cpu_mem_usage,
432+
disable_mmap=disable_mmap,
430433
)
431434

432435
with ThreadPoolExecutor(max_workers=num_workers) as executor:

src/diffusers/models/modeling_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,6 +1298,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
12981298
keep_in_fp32_modules=keep_in_fp32_modules,
12991299
dduf_entries=dduf_entries,
13001300
is_parallel_loading_enabled=is_parallel_loading_enabled,
1301+
disable_mmap=disable_mmap,
13011302
)
13021303
loading_info = {
13031304
"missing_keys": missing_keys,
@@ -1584,6 +1585,7 @@ def _load_pretrained_model(
15841585
offload_folder: Optional[Union[str, os.PathLike]] = None,
15851586
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
15861587
is_parallel_loading_enabled: Optional[bool] = False,
1588+
disable_mmap: bool = False,
15871589
):
15881590
model_state_dict = model.state_dict()
15891591
expected_keys = list(model_state_dict.keys())
@@ -1652,6 +1654,7 @@ def _load_pretrained_model(
16521654
state_dict_folder=state_dict_folder,
16531655
ignore_mismatched_sizes=ignore_mismatched_sizes,
16541656
low_cpu_mem_usage=low_cpu_mem_usage,
1657+
disable_mmap=disable_mmap,
16551658
)
16561659

16571660
if is_parallel_loading_enabled:

0 commit comments

Comments
 (0)