Commit 1afc218

lawrence-cj, dg845, HeliosZhao, and github-actions[bot] authored
SANA-Video Image to Video pipeline SanaImageToVideoPipeline support (#12634)
* move sana-video to a new dir and add `SanaImageToVideoPipeline` with no modify;
* fix bug and run text/image-to-video success;
* make style; quality; fix-copies;
* add sana image-to-video pipeline in markdown;
* add test case for sana image-to-video;
* make style;
* add a init file in sana-video test dir;
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* Update tests/pipelines/sana_video/test_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* Update tests/pipelines/sana_video/test_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* minor update;
* fix bug and skip fp16 save test; Co-authored-by: Yuyang Zhao <[email protected]>
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py Co-authored-by: dg845 <[email protected]>
* add copied from for `encode_prompt`
* Apply style fixes

---------

Co-authored-by: dg845 <[email protected]>
Co-authored-by: Yuyang Zhao <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 0c35b58 commit 1afc218

15 files changed: +1501 -34 lines changed

docs/source/en/api/pipelines/sana_video.md

Lines changed: 88 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License. -->
1414

15-
# SanaVideoPipeline
15+
# Sana-Video
1616

1717
<div class="flex flex-wrap space-x-1">
1818
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
@@ -37,6 +37,85 @@ Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-vi
3737

3838
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
3939

40+
41+
## Generation Pipelines
42+
43+
<hfoptions id="generation pipelines">
44+
<hfoption id="Text-to-Video">
45+
46+
The example below demonstrates how to use the text-to-video pipeline to generate a video from a text description.
47+
48+
```python
49+
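import torch
from diffusers import SanaVideoPipeline
from diffusers.utils import export_to_video

model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
50+
pipe = SanaVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)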
51+
pipe.text_encoder.to(torch.bfloat16)
52+
pipe.vae.to(torch.float32)
53+
pipe.to("cuda")
54+
55+
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
56+
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
57+
motion_scale = 30
58+
motion_prompt = f" motion score: {motion_scale}."
59+
prompt = prompt + motion_prompt
60+
61+
video = pipe(
62+
prompt=prompt,
63+
negative_prompt=negative_prompt,
64+
height=480,
65+
width=832,
66+
frames=81,
67+
guidance_scale=6,
68+
num_inference_steps=50,
69+
generator=torch.Generator(device="cuda").manual_seed(0),
70+
).frames[0]
71+
72+
export_to_video(video, "sana_video.mp4", fps=16)
73+
```
74+
75+
</hfoption>
76+
<hfoption id="Image-to-Video">
77+
78+
The example below demonstrates how to use the image-to-video pipeline to generate a video from a text description and a starting frame.
79+
80+
```python
81+
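import torch
from diffusers import FlowMatchEulerDiscreteScheduler, SanaImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"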
82+
pipe = SanaImageToVideoPipeline.from_pretrained(
83+
model_id,
84+
torch_dtype=torch.bfloat16,
85+
)
86+
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
87+
pipe.vae.to(torch.float32)
88+
pipe.text_encoder.to(torch.bfloat16)
89+
pipe.to("cuda")
90+
91+
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
92+
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
93+
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
94+
motion_scale = 30
95+
motion_prompt = f" motion score: {motion_scale}."
96+
prompt = prompt + motion_prompt
97+
98+
motion_scale = 30.0
99+
100+
video = pipe(
101+
image=image,
102+
prompt=prompt,
103+
negative_prompt=negative_prompt,
104+
height=480,
105+
width=832,
106+
frames=81,
107+
guidance_scale=6,
108+
num_inference_steps=50,
109+
generator=torch.Generator(device="cuda").manual_seed(0),
110+
).frames[0]
111+
112+
export_to_video(video, "sana-i2v.mp4", fps=16)
113+
```
114+
115+
</hfoption>
116+
</hfoptions>
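
Both examples above append a motion-score suffix to the prompt by hand. As a minimal sketch, that step could be wrapped in a small helper. The name `with_motion_score` is hypothetical and not part of diffusers; the suffix format is copied from the examples.

```python
def with_motion_score(prompt: str, motion_score: int = 30) -> str:
    """Append the motion-score suffix used in the examples above.

    Hypothetical helper, not a diffusers API. The suffix format
    (" motion score: <value>.") is copied from the documentation examples.
    """
    return f"{prompt} motion score: {motion_score}."


prompt = with_motion_score("A cat and a dog baking a cake together in a kitchen.")
print(prompt)
```

Keeping the suffix in one place makes it harder for the text-to-video and image-to-video prompts to drift apart when the score is changed.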
117+
118+
40119
## Quantization
41120

42121
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
@@ -97,6 +176,13 @@ export_to_video(output, "sana-video-output.mp4", fps=16)
97176
- __call__
98177

99178

179+
## SanaImageToVideoPipeline
180+
181+
[[autodoc]] SanaImageToVideoPipeline
182+
- all
183+
- __call__
184+
185+
100186
## SanaVideoPipelineOutput
101187

102-
[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
188+
[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput

scripts/convert_sana_video_to_diffusers.py

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,8 @@ def main(args):
8080

8181
# scheduler
8282
flow_shift = 8.0
83+
if args.task == "i2v":
84+
assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
8385

8486
# model config
8587
layer_num = 20
@@ -312,6 +314,7 @@ def main(args):
312314
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
313315
help="Scheduler type to use.",
314316
)
317+
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
315318
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
316319
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
317320
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
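
The hunk above adds a `--task` option and asserts that image-to-video conversion uses the `flow-euler` scheduler. Below is a rough sketch of the same check expressed through `argparse` itself, assuming a trimmed-down parser with only these two options. The `build_parser` and `parse_args` names and the `flow-dpm_solver` default are assumptions for illustration, not taken from the script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical, trimmed-down parser: only the two options relevant here.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--scheduler_type",
        default="flow-dpm_solver",  # assumed default, not shown in the diff
        choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
        help="Scheduler type to use.",
    )
    # `choices` rejects unknown tasks up front; without `required=True`
    # the "t2v" default actually takes effect.
    parser.add_argument("--task", default="t2v", choices=["t2v", "i2v"], help="Task to convert, t2v or i2v.")
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # Same rule as the assert in the diff, reported as a CLI usage error.
    if args.task == "i2v" and args.scheduler_type != "flow-euler":
        parser.error("Scheduler type must be flow-euler for i2v task.")
    return args


args = parse_args(["--task", "i2v", "--scheduler_type", "flow-euler"])
print(args.task, args.scheduler_type)
```

One design note: in the diff, `--task` combines `default="t2v"` with `required=True`, so the default can never apply. The sketch drops `required=True`. It also uses `parser.error` rather than `assert`, since assertions are skipped when Python runs with `-O`.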

src/diffusers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -545,11 +545,13 @@
545545
"QwenImagePipeline",
546546
"ReduxImageEncoder",
547547
"SanaControlNetPipeline",
548+
"SanaImageToVideoPipeline",
548549
"SanaPAGPipeline",
549550
"SanaPipeline",
550551
"SanaSprintImg2ImgPipeline",
551552
"SanaSprintPipeline",
552553
"SanaVideoPipeline",
554+
"SanaVideoPipeline",
553555
"SemanticStableDiffusionPipeline",
554556
"ShapEImg2ImgPipeline",
555557
"ShapEPipeline",
@@ -1227,6 +1229,7 @@
12271229
QwenImagePipeline,
12281230
ReduxImageEncoder,
12291231
SanaControlNetPipeline,
1232+
SanaImageToVideoPipeline,
12301233
SanaPAGPipeline,
12311234
SanaPipeline,
12321235
SanaSprintImg2ImgPipeline,

src/diffusers/models/transformers/transformer_sana_video.py

Lines changed: 8 additions & 6 deletions
@@ -237,7 +237,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
237237
return freqs_cos, freqs_sin
238238

239239

240-
# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
241240
class SanaModulatedNorm(nn.Module):
242241
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
243242
super().__init__()
@@ -247,7 +246,7 @@ def forward(
247246
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
248247
) -> torch.Tensor:
249248
hidden_states = self.norm(hidden_states)
250-
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
249+
shift, scale = (scale_shift_table[None, None] + temb[:, :, None].to(scale_shift_table.device)).unbind(dim=2)
251250
hidden_states = hidden_states * (1 + scale) + shift
252251
return hidden_states
253252

@@ -423,8 +422,8 @@ def forward(
423422

424423
# 1. Modulation
425424
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
426-
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
427-
).chunk(6, dim=1)
425+
self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1)
426+
).unbind(dim=2)
428427

429428
# 2. Self Attention
430429
norm_hidden_states = self.norm1(hidden_states)
@@ -635,13 +634,16 @@ def forward(
635634

636635
if guidance is not None:
637636
timestep, embedded_timestep = self.time_embed(
638-
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
637+
timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype
639638
)
640639
else:
641640
timestep, embedded_timestep = self.time_embed(
642-
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
641+
timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
643642
)
644643

644+
timestep = timestep.view(batch_size, -1, timestep.size(-1))
645+
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
646+
645647
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
646648
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
647649
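
The modulation hunks above replace `chunk` on a `(batch, 2, dim)` tensor with `unbind` on a `(batch, num_timesteps, 2, dim)` tensor, so each sample can carry more than one timestep embedding. The sketch below reproduces only the shape arithmetic of the `SanaModulatedNorm.forward` change, using random tensors. The function names and sizes are illustrative assumptions, and the diff does not state how the extra axis is used by the image-to-video pipeline.

```python
import torch


def modulation_old(scale_shift_table: torch.Tensor, temb: torch.Tensor):
    # temb: (batch, dim) -> two tensors of shape (batch, 1, dim)
    return (scale_shift_table[None] + temb[:, None]).chunk(2, dim=1)


def modulation_new(scale_shift_table: torch.Tensor, temb: torch.Tensor):
    # temb: (batch, num_timesteps, dim) -> two tensors of shape (batch, num_timesteps, dim)
    return (scale_shift_table[None, None] + temb[:, :, None]).unbind(dim=2)


torch.manual_seed(0)
batch, seq_len, dim = 2, 5, 4
table = torch.randn(2, dim)  # row 0 is the shift, row 1 is the scale

# One timestep embedding per sample: the old and new forms agree.
temb_global = torch.randn(batch, dim)
old_shift, old_scale = modulation_old(table, temb_global)
new_shift, new_scale = modulation_new(table, temb_global[:, None])
print(torch.equal(old_shift, new_shift), torch.equal(old_scale, new_scale))

# One timestep embedding per token: only the new form handles this shape.
temb_per_token = torch.randn(batch, seq_len, dim)
shift, scale = modulation_new(table, temb_per_token)
hidden_states = torch.randn(batch, seq_len, dim)
out = hidden_states * (1 + scale) + shift
print(shift.shape, out.shape)
```

With a single timestep per sample the two forms give identical tensors, which is consistent with the text-to-video path keeping its behavior after the change.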

src/diffusers/pipelines/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -308,7 +308,10 @@
308308
"SanaSprintPipeline",
309309
"SanaControlNetPipeline",
310310
"SanaSprintImg2ImgPipeline",
311+
]
312+
_import_structure["sana_video"] = [
311313
"SanaVideoPipeline",
314+
"SanaImageToVideoPipeline",
312315
]
313316
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
314317
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -749,8 +752,8 @@
749752
SanaPipeline,
750753
SanaSprintImg2ImgPipeline,
751754
SanaSprintPipeline,
752-
SanaVideoPipeline,
753755
)
756+
from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline
754757
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
755758
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
756759
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel

src/diffusers/pipelines/sana/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -26,7 +26,6 @@
2626
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
2727
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
2828
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
29-
_import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
3029

3130
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
3231
try:
@@ -40,7 +39,6 @@
4039
from .pipeline_sana_controlnet import SanaControlNetPipeline
4140
from .pipeline_sana_sprint import SanaSprintPipeline
4241
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
43-
from .pipeline_sana_video import SanaVideoPipeline
4442
else:
4543
import sys
4644

src/diffusers/pipelines/sana/pipeline_output.py

Lines changed: 0 additions & 16 deletions
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import PIL.Image
6-
import torch
76

87
from ...utils import BaseOutput
98

@@ -20,18 +19,3 @@ class SanaPipelineOutput(BaseOutput):
2019
"""
2120

2221
images: Union[List[PIL.Image.Image], np.ndarray]
23-
24-
25-
@dataclass
26-
class SanaVideoPipelineOutput(BaseOutput):
27-
r"""
28-
Output class for Sana-Video pipelines.
29-
30-
Args:
31-
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
32-
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
33-
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
34-
`(batch_size, num_frames, channels, height, width)`.
35-
"""
36-
37-
frames: torch.Tensor
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
1+
from typing import TYPE_CHECKING
2+
3+
from ...utils import (
4+
DIFFUSERS_SLOW_IMPORT,
5+
OptionalDependencyNotAvailable,
6+
_LazyModule,
7+
get_objects_from_module,
8+
is_torch_available,
9+
is_transformers_available,
10+
)
11+
12+
13+
_dummy_objects = {}
14+
_import_structure = {}
15+
16+
17+
try:
18+
if not (is_transformers_available() and is_torch_available()):
19+
raise OptionalDependencyNotAvailable()
20+
except OptionalDependencyNotAvailable:
21+
from ...utils import dummy_torch_and_transformers_objects # noqa F403
22+
23+
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24+
else:
25+
_import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
26+
_import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"]
27+
28+
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
29+
try:
30+
if not (is_transformers_available() and is_torch_available()):
31+
raise OptionalDependencyNotAvailable()
32+
33+
except OptionalDependencyNotAvailable:
34+
from ...utils.dummy_torch_and_transformers_objects import *
35+
else:
36+
from .pipeline_sana_video import SanaVideoPipeline
37+
from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
38+
else:
39+
import sys
40+
41+
sys.modules[__name__] = _LazyModule(
42+
__name__,
43+
globals()["__file__"],
44+
_import_structure,
45+
module_spec=__spec__,
46+
)
47+
48+
for name, value in _dummy_objects.items():
49+
setattr(sys.modules[__name__], name, value)
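
The new `__init__.py` above registers names in `_import_structure` and hands them to `_LazyModule`, so the pipeline modules are imported only when first accessed. The following is a simplified stand-in for that idea, not the diffusers implementation. The `LazyModule` class is hypothetical, and standard-library modules (`json`, `math`) stand in for the package's relative submodules.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Resolve attributes on first access by importing the module that defines them."""

    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # Invert {"module": ["attr", ...]} into {"attr": "module"} for lookup.
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, name: str):
        # Called only when normal attribute lookup fails.
        try:
            module_name = self._attr_to_module[name]
        except KeyError:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}") from None
        value = getattr(importlib.import_module(module_name), name)
        setattr(self, name, value)  # cache so __getattr__ is not hit again
        return value


# Absolute stdlib modules are used here purely as stand-ins.
lazy = LazyModule("lazy_demo", {"json": ["dumps"], "math": ["sqrt"]})
print(lazy.sqrt(9.0))
```

Nothing is imported when `lazy` is created. `math` is loaded on the first `lazy.sqrt` access, and `json` only if `lazy.dumps` is ever used.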
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from dataclasses import dataclass
2+
3+
import torch
4+
5+
from ...utils import BaseOutput
6+
7+
8+
@dataclass
9+
class SanaVideoPipelineOutput(BaseOutput):
10+
r"""
11+
Output class for Sana-Video pipelines.
12+
13+
Args:
14+
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
15+
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
16+
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
17+
`(batch_size, num_frames, channels, height, width)`.
18+
"""
19+
20+
frames: torch.Tensor

src/diffusers/pipelines/sana/pipeline_sana_video.py renamed to src/diffusers/pipelines/sana_video/pipeline_sana_video.py

Lines changed: 7 additions & 7 deletions
@@ -95,17 +95,16 @@
9595
>>> from diffusers import SanaVideoPipeline
9696
>>> from diffusers.utils import export_to_video
9797
98-
>>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
99-
>>> pipe = SanaVideoPipeline.from_pretrained(model_id)
98+
>>> pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
10099
>>> pipe.transformer.to(torch.bfloat16)
101100
>>> pipe.text_encoder.to(torch.bfloat16)
102101
>>> pipe.vae.to(torch.float32)
103102
>>> pipe.to("cuda")
104-
>>> model_score = 30
103+
>>> motion_score = 30
105104
106105
>>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
107106
>>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
108-
>>> motion_prompt = f" motion score: {model_score}."
107+
>>> motion_prompt = f" motion score: {motion_score}."
109108
>>> prompt = prompt + motion_prompt
110109
111110
>>> output = pipe(
@@ -231,6 +230,7 @@ def __init__(
231230

232231
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
233232

233+
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
234234
def _get_gemma_prompt_embeds(
235235
self,
236236
prompt: Union[str, List[str]],
@@ -827,9 +827,9 @@ def __call__(
827827
Examples:
828828
829829
Returns:
830-
[`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
831-
If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] is returned,
832-
otherwise a `tuple` is returned where the first element is a list with the generated videos
830+
[`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
831+
If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is
832+
returned, otherwise a `tuple` is returned where the first element is a list with the generated videos
833833
"""
834834

835835
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
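
The `Returns:` docstring above describes two return shapes: an output object when `return_dict` is `True`, otherwise a tuple whose first element holds the generated videos. Here is a small sketch of how a caller reads each shape, assuming a stub in place of the real pipeline. `VideoOutput` and `fake_pipeline` are hypothetical, and strings stand in for decoded frames.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class VideoOutput:
    # Stand-in for SanaVideoPipelineOutput; frames holds one entry per video.
    frames: List[List[str]]


def fake_pipeline(return_dict: bool = True) -> Union[VideoOutput, Tuple[List[List[str]]]]:
    # Hypothetical pipeline stub that "generates" a single three-frame video.
    videos = [["frame0", "frame1", "frame2"]]
    if not return_dict:
        return (videos,)
    return VideoOutput(frames=videos)


# Attribute access on the output object, then the first video in the batch.
video = fake_pipeline().frames[0]
# First tuple element, then the first video in the batch.
same_video = fake_pipeline(return_dict=False)[0][0]
print(video == same_video)
```

This mirrors the `.frames[0]` indexing used in the documentation examples, where index `0` selects the first video in the batch.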
