Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2) #911
Conversation
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@            Coverage Diff             @@
##             main     #911      +/-   ##
==========================================
+ Coverage   73.10%   73.11%   +0.01%
==========================================
  Files         205      205
  Lines       22294    22294
==========================================
+ Hits        16297    16300       +3
+ Misses       5997     5994       -3
```
jingyu-ml
left a comment
Left some comments, overall it looks good to me.
Force-pushed from 69107c0 to 7e74e91
Important: Review skipped. Auto incremental reviews are disabled on this repository; please check the settings in the CodeRabbit UI.
📝 Walkthrough

Introduces support for merging a base safetensors checkpoint into exported diffusion transformer models. Adds utility functions to detect the model type and merge transformer state dicts with base checkpoint data, then integrates this merge workflow into the export pipeline with quantization metadata support.
Sequence Diagram

```mermaid
sequenceDiagram
    actor User
    participant Exporter as Export Pipeline
    participant TypeDetector as Model Type Detector
    participant MergeRegistry as Merge Registry
    participant Merger as Merge Function
    participant BaseCheckpoint as Base Safetensors
    participant Transformer as Transformer Dict
    participant MetadataHandler as Metadata Handler
    participant Output as Safetensors Output
    User->>Exporter: export_hf_checkpoint(model, merged_base_safetensor_path)
    Exporter->>TypeDetector: get_diffusion_model_type(pipe)
    TypeDetector-->>Exporter: model_type (e.g., "ltx2")
    Exporter->>MergeRegistry: DIFFUSION_MERGE_FUNCTIONS[model_type]
    MergeRegistry-->>Exporter: _merge_ltx2 function
    Exporter->>Merger: _merge_ltx2(transformer_state_dict, base_path)
    Merger->>BaseCheckpoint: read VAE, vocoder, embeddings
    BaseCheckpoint-->>Merger: base components
    Merger->>Transformer: merge base components with transformer keys
    Transformer-->>Merger: merged_state_dict
    Exporter->>MetadataHandler: attach quantization_config & metadata
    MetadataHandler-->>Exporter: enriched state_dict
    Exporter->>Output: save model.safetensors with metadata
    Output-->>User: exported checkpoint
```
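The dispatch in the diagram can be sketched in plain Python. Everything below is illustrative: the dict-based state dicts stand in for tensors, and the function bodies are simplified guesses at the PR's behavior, not the actual implementation.

```python
import json

def _merge_ltx2(transformer_state_dict, base_state_dict):
    """Toy merge: keep non-transformer base weights, re-prefix transformer keys."""
    merged = {
        k: v for k, v in base_state_dict.items()
        if not k.startswith("model.diffusion_model.")
    }
    # ComfyUI expects transformer weights under the "model.diffusion_model." prefix.
    merged.update(
        {f"model.diffusion_model.{k}": v for k, v in transformer_state_dict.items()}
    )
    base_metadata = {"format": "pt"}  # placeholder for the base file's metadata
    return merged, base_metadata

# Registry mapping model type -> merge function, as in the walkthrough.
DIFFUSION_MERGE_FUNCTIONS = {"ltx2": _merge_ltx2}

def export_merged(model_type, transformer_sd, base_sd, hf_quant_config=None):
    merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]  # KeyError for unsupported types
    state_dict, metadata = merge_fn(transformer_sd, base_sd)
    if hf_quant_config is not None:
        metadata["quantization_config"] = json.dumps(hf_quant_config)
    return state_dict, metadata
```

In a real export the merged state dict and metadata would then be written with safetensors' `save_file`.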
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes. Pre-merge checks: ✅ 3 passed.
Actionable comments posted: 3
🧹 Nitpick comments (1)
modelopt/torch/export/diffusers_utils.py (1)
681-723: Double file read: load_file + safe_open on the same checkpoint.

The base safetensors file is read twice: once via load_file (line 681) to get tensors, and again via safe_open (line 719) to get metadata. This means parsing a potentially multi-GB file twice. You can read metadata in the same safe_open call and also load tensors from it:

Proposed: single-pass using safe_open

```diff
-base_state = load_file(merged_base_safetensor_path)
+with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
+    base_metadata = f.metadata() or {}
+    base_state = {key: f.get_tensor(key) for key in f.keys()}
 non_transformer_prefixes = [
     ...
 ]
 ...
 merged = dict(base_non_transformer)
 merged.update(base_connectors)
 merged.update(prefixed)
-with safe_open(merged_base_safetensor_path, framework="pt", device="cpu") as f:
-    base_metadata = f.metadata() or {}
 del base_state
 return merged, base_metadata
```
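The single-pass pattern the reviewer suggests can be demonstrated without safetensors installed. The FakeSafeOpen class below is a stand-in for safetensors' safe_open that counts how often the file is opened, purely to show that one context manager can serve both tensors and metadata.

```python
class FakeSafeOpen:
    """Stand-in for safetensors.safe_open; counts opens for demonstration."""
    open_count = 0

    def __init__(self, path):
        FakeSafeOpen.open_count += 1
        # Toy contents; a real file would hold tensors and header metadata.
        self._tensors = {"transformer_blocks.0.weight": [1.0]}
        self._metadata = {"format": "pt"}

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def keys(self):
        return self._tensors.keys()

    def get_tensor(self, key):
        return self._tensors[key]

    def metadata(self):
        return self._metadata


def load_tensors_and_metadata(path):
    # Single pass: tensors and metadata from one open, instead of
    # load_file(path) followed by a second safe_open(path) for metadata.
    with FakeSafeOpen(path) as f:
        metadata = f.metadata() or {}
        state = {k: f.get_tensor(k) for k in f.keys()}
    return state, metadata
```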
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
modelopt/torch/export/diffusers_utils.py
modelopt/torch/export/unified_export_hf.py
```python
metadata: dict[str, str] = {}
metadata_full: dict[str, str] = {}
if merged_base_safetensor_path is not None and model_type is not None:
    merge_fn = DIFFUSION_MERGE_FUNCTIONS[model_type]
    cpu_state_dict, metadata_full = merge_fn(cpu_state_dict, merged_base_safetensor_path)
    if hf_quant_config is not None:
        metadata_full["quantization_config"] = json.dumps(hf_quant_config)

        # Build per-layer _quantization_metadata for ComfyUI
        quant_algo = hf_quant_config.get("quant_algo", "unknown").lower()
        layer_metadata = {}
        for k in cpu_state_dict:
            if k.endswith((".weight_scale", ".weight_scale_2")):
                layer_name = k.rsplit(".", 1)[0]
                if layer_name.endswith(".weight"):
                    layer_name = layer_name.rsplit(".", 1)[0]
                if layer_name not in layer_metadata:
                    layer_metadata[layer_name] = {"format": quant_algo}
        metadata_full["_quantization_metadata"] = json.dumps(
            {
                "format_version": "1.0",
                "layers": layer_metadata,
            }
        )

metadata["_export_format"] = "safetensors_state_dict"
metadata["_class_name"] = type(component).__name__
metadata_full.update(metadata)

save_file(
    cpu_state_dict,
    str(component_export_dir / "model.safetensors"),
    metadata=metadata_full if merged_base_safetensor_path is not None else None,
)
```
Metadata is discarded for non-merge exports.
On line 168, metadata is only passed to save_file when merged_base_safetensor_path is not None. For non-merge exports through this function, the _export_format and _class_name metadata (lines 161-162) are computed but thrown away — save_file is called with metadata=None.
If metadata should always be attached (even for non-merge exports), pass it unconditionally:

Proposed fix

```diff
 save_file(
     cpu_state_dict,
     str(component_export_dir / "model.safetensors"),
-    metadata=metadata_full if merged_base_safetensor_path is not None else None,
+    metadata=metadata_full,
 )
```
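The per-layer _quantization_metadata construction in this function is easy to exercise standalone. The helper name and the sample keys below are mine, for illustration only; the suffix-matching logic mirrors the quoted code.

```python
import json

def build_quantization_metadata(state_dict_keys, quant_algo):
    """Map each quantized layer (found via its *_scale tensors) to its format."""
    layer_metadata = {}
    for k in state_dict_keys:
        if k.endswith((".weight_scale", ".weight_scale_2")):
            # Strip the scale suffix, then a trailing ".weight" if present,
            # so both "layer.weight_scale" and "layer.weight.weight_scale"
            # resolve to the same layer name.
            layer_name = k.rsplit(".", 1)[0]
            if layer_name.endswith(".weight"):
                layer_name = layer_name.rsplit(".", 1)[0]
            layer_metadata.setdefault(layer_name, {"format": quant_algo.lower()})
    return json.dumps({"format_version": "1.0", "layers": layer_metadata})
```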
```python
# Resolve model type once (only needed when merging with a base checkpoint)
model_type = get_diffusion_model_type(pipe) if merged_base_safetensor_path else None
```
get_diffusion_model_type is called unconditionally when merged_base_safetensor_path is truthy, so it will raise a ValueError for non-LTX-2 diffusers pipelines.
If a user passes merged_base_safetensor_path for a standard diffusers pipeline (e.g., StableDiffusion), get_diffusion_model_type(pipe) will raise ValueError with a somewhat opaque message. Consider validating earlier or documenting this limitation more prominently in export_hf_checkpoint's docstring.
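One way to implement the early validation suggested here is a small resolver that fails with a descriptive message before any export work happens. The class-name mapping below is a hypothetical illustration, not the library's actual detection logic.

```python
def resolve_model_type(pipe, merged_base_safetensor_path):
    """Return a model type for merging, or None when no merge was requested."""
    if not merged_base_safetensor_path:
        return None
    # Hypothetical class-name -> model-type map; extend as new types gain support.
    supported = {"LTX2Pipeline": "ltx2"}
    model_type = supported.get(type(pipe).__name__)
    if model_type is None:
        raise ValueError(
            "merged_base_safetensor_path is only supported for "
            f"{sorted(supported)} pipelines, got {type(pipe).__name__}; "
            "drop the argument or use a supported pipeline."
        )
    return model_type
```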
```diff
 elif hasattr(component, "save_pretrained"):
     component.save_pretrained(component_export_dir, max_shard_size=max_shard_size)
 else:
-    _save_component_state_dict_safetensors(component, component_export_dir)
+    _save_component_state_dict_safetensors(
+        component,
+        component_export_dir,
+        merged_base_safetensor_path,
+        model_type=model_type,
+    )
```
Non-quantized components also receive merged_base_safetensor_path — unintentional merge?
When a non-quantized component falls through to _save_component_state_dict_safetensors (lines 963-968), it receives merged_base_safetensor_path and model_type. This means the merge function will run on the non-quantized component's state dict too, adding all non-transformer base weights (VAE, vocoder, etc.) into it.
In the current LTX-2 flow there's only one component, so this is harmless. But for future model types with multiple components, this would produce incorrect merged checkpoints for non-quantized components.
Consider guarding by only passing merged_base_safetensor_path for quantized components, or adding a comment clarifying the assumption:
Proposed safeguard

```diff
 else:
     _save_component_state_dict_safetensors(
         component,
         component_export_dir,
-        merged_base_safetensor_path,
-        model_type=model_type,
     )
```
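An alternative to dropping the arguments entirely is to gate them on whether the component actually looks quantized. The heuristic below, scanning state-dict keys for *_scale suffixes, is an assumption for illustration, not the library's detection mechanism.

```python
def merge_path_for_component(state_dict_keys, merged_base_safetensor_path):
    """Pass the base-checkpoint path through only for quantized components."""
    quantized = any(
        k.endswith((".weight_scale", ".weight_scale_2")) for k in state_dict_keys
    )
    # Non-quantized components (VAE, text encoders, ...) get None, so the
    # merge step is skipped for them.
    return merged_base_safetensor_path if quantized else None
```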
)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| elif hasattr(component, "save_pretrained"): | |
| component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) | |
| else: | |
| _save_component_state_dict_safetensors(component, component_export_dir) | |
| _save_component_state_dict_safetensors( | |
| component, | |
| component_export_dir, | |
| merged_base_safetensor_path, | |
| model_type=model_type, | |
| ) | |
| elif hasattr(component, "save_pretrained"): | |
| component.save_pretrained(component_export_dir, max_shard_size=max_shard_size) | |
| else: | |
| _save_component_state_dict_safetensors( | |
| component, | |
| component_export_dir, | |
| ) |
jingyu-ml
left a comment
LGTM from my side. Let’s wait for @Edwardf0t1’s review as well, and then we can merge.
…del(e.g., LTX-2) Signed-off-by: ynankani <[email protected]>
…del(e.g., LTX-2) Signed-off-by: ynankani <[email protected]>
Signed-off-by: ynankani <[email protected]>
Signed-off-by: ynankani <[email protected]>
Signed-off-by: ynankani <[email protected]>
Force-pushed from 7e74e91 to 429a793
```python
        export_dir: The directory to save the exported checkpoint.
        components: Optional list of component names to export. Only used for pipelines.
            If None, all components are exported.
        merged_base_safetensor_path: If provided, merge the exported transformer with
```
Could you further document what is this, e.g. giving an example?
```python
        components: Only used for diffusers pipelines. Optional list of component names
            to export. If None, all quantized components are exported.
        extra_state_dict: Extra state dictionary to add to the exported model.
        **kwargs: Internal-only keyword arguments. Supported keys:
```
why do we want to hide it instead of making it explicit?
Edwardf0t1
left a comment
Thanks @ynankani for adding the modelopt ckpt support for ComfyUI. Please make sure we add it to our ckpt specs doc. @jingyu-ml
```python
    Non-transformer components (VAE, vocoder, text encoders) and embeddings
    connectors are taken from the base checkpoint. Transformer keys are
    re-prefixed with ``model.diffusion_model.`` for ComfyUI compatibility.
```
Does ComfyUI require this merge in general, or it's ltx-2 specific?
```python
    Args:
        component: The nn.Module to save.
        component_export_dir: Directory to save model.safetensors and config.json.
```
will config.json include per layer quantization config as well?
```diff
@@ -112,19 +114,62 @@ def _is_enabled_quantizer(quantizer):


 def _save_component_state_dict_safetensors(
```
@jingyu-ml @ynankani Do you think it's better to move this function to diffusers_utils.py?
What does this PR do
Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2)
Type of change:
Overview:
Add support for exporting a ComfyUI-compatible checkpoint for diffusion models (e.g., LTX-2)
Usage
Testing
a) initializing a twoStagePipeline object
b) calling mtq.quantize on transformer with NVFP4_DEFAULT_CFG
c) then exporting with export_hf_checkpoint, passing the param merged_base_safetensor_path to generate the merged checkpoint
Before your PR is "Ready for review"
Additional Information