Skip to content

Commit

Permalink
Black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Dockhorn committed Feb 29, 2024
1 parent 1e30a2d commit c51e4e3
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 55 deletions.
107 changes: 66 additions & 41 deletions scripts/demo/gradio_app.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,56 @@
# Adding this at the very top of app.py to make 'generative-models' directory discoverable
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'generative-models'))
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "generative-models"))

import math
import random
import uuid
from glob import glob
from pathlib import Path
from typing import Optional

import cv2
import gradio as gr
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision.transforms import ToTensor

from scripts.sampling.simple_video_sample import (
get_batch, get_unique_embedder_keys_from_conditioner, load_model)
from scripts.util.detection.nsfw_and_watermark_dectection import \
DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
from scripts.sampling.simple_video_sample import load_model, get_unique_embedder_keys_from_conditioner, get_batch

import gradio as gr
import uuid
import random
from huggingface_hub import hf_hub_download

# To download all svd models
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints")
#hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints")
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt", filename="svd_xt.safetensors", local_dir="checkpoints")
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid", filename="svd.safetensors", local_dir="checkpoints")
# hf_hub_download(repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1", filename="svd_xt_1_1.safetensors", local_dir="checkpoints")


# Define the repo, local directory and filename
repo_id="stabilityai/stable-video-diffusion-img2vid-xt-1-1" # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models
filename = "svd_xt_1_1.safetensors" # replace with "svd_xt.safetensors" or "svd.safetensors" for other models
repo_id = "stabilityai/stable-video-diffusion-img2vid-xt-1-1" # replace with "stabilityai/stable-video-diffusion-img2vid-xt" or "stabilityai/stable-video-diffusion-img2vid" for other models
filename = "svd_xt_1_1.safetensors" # replace with "svd_xt.safetensors" or "svd.safetensors" for other models
local_dir = "checkpoints"
local_file_path = os.path.join(local_dir, filename)

# Check if the file already exists
if not os.path.exists(local_file_path):
# If the file doesn't exist, download it
hf_hub_download(
repo_id=repo_id,
filename=filename,
local_dir=local_dir
)
hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir)
print("File downloaded.")
else:
print("File already exists. No need to download.")


version = "svd_xt_1_1" # replace with 'svd_xt' or 'svd' for other models
version = "svd_xt_1_1" # replace with 'svd_xt' or 'svd' for other models
device = "cuda"
max_64_bit_int = 2**63 - 1

Expand All @@ -71,6 +68,7 @@
num_steps,
)


def sample(
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
seed: Optional[int] = None,
Expand All @@ -82,18 +80,18 @@ def sample(
decoding_t: int = 7, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
output_folder: str = "outputs",
progress=gr.Progress(track_tqdm=True)
progress=gr.Progress(track_tqdm=True),
):
"""
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
fps_id = int(fps_id ) #casting float slider values to int)
if(randomize_seed):
fps_id = int(fps_id) # casting float slider values to int)
if randomize_seed:
seed = random.randint(0, max_64_bit_int)

torch.manual_seed(seed)

path = Path(input_path)
all_img_paths = []
if path.is_file():
Expand Down Expand Up @@ -223,7 +221,7 @@ def denoiser(input, sigma, c):
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
writer.write(frame)
writer.release()

return video_path, seed


Expand Down Expand Up @@ -260,24 +258,51 @@ def resize_image(image_path, output_size=(1024, 576)):

return cropped_image


with gr.Blocks() as demo:
gr.Markdown('''# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))
gr.Markdown(
"""# Community demo for Stable Video Diffusion - Img2Vid - XT ([model](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt), [paper](https://stability.ai/research/stable-video-diffusion-scaling-latent-video-diffusion-models-to-large-datasets))
#### Research release ([_non-commercial_](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/LICENSE)): generate `4s` vid from a single image at (`25 frames` at `6 fps`). Generation takes ~60s in an A100. [Join the waitlist for Stability's upcoming web experience](https://stability.ai/contact).
''')
with gr.Row():
with gr.Column():
image = gr.Image(label="Upload your image", type="filepath")
generate_btn = gr.Button("Generate")
video = gr.Video()
with gr.Accordion("Advanced options", open=False):
seed = gr.Slider(label="Seed", value=42, randomize=True, minimum=0, maximum=max_64_bit_int, step=1)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
motion_bucket_id = gr.Slider(label="Motion bucket id", info="Controls how much motion to add/remove from the image", value=127, minimum=1, maximum=255)
fps_id = gr.Slider(label="Frames per second", info="The length of your video in seconds will be 25/fps", value=6, minimum=5, maximum=30)

image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
generate_btn.click(fn=sample, inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id], outputs=[video, seed], api_name="video")

"""
)
with gr.Row():
with gr.Column():
image = gr.Image(label="Upload your image", type="filepath")
generate_btn = gr.Button("Generate")
video = gr.Video()
with gr.Accordion("Advanced options", open=False):
seed = gr.Slider(
label="Seed",
value=42,
randomize=True,
minimum=0,
maximum=max_64_bit_int,
step=1,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
motion_bucket_id = gr.Slider(
label="Motion bucket id",
info="Controls how much motion to add/remove from the image",
value=127,
minimum=1,
maximum=255,
)
fps_id = gr.Slider(
label="Frames per second",
info="The length of your video in seconds will be 25/fps",
value=6,
minimum=5,
maximum=30,
)

image.upload(fn=resize_image, inputs=image, outputs=image, queue=False)
generate_btn.click(
fn=sample,
inputs=[image, seed, randomize_seed, motion_bucket_id, fps_id],
outputs=[video, seed],
api_name="video",
)

if __name__ == "__main__":
demo.queue(max_size=20)
demo.launch(share=True)
21 changes: 16 additions & 5 deletions scripts/demo/turbo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from streamlit_helpers import *
from st_keyup import st_keyup
from streamlit_helpers import *

from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler

VERSION2SPECS = {
Expand Down Expand Up @@ -193,7 +194,7 @@ def decrement_counter():

with head_cols[2]:
n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)

sampler = SubstepSampler(
n_sample_steps=1,
num_steps=1000,
Expand All @@ -203,8 +204,12 @@ def decrement_counter():
),
)
sampler.n_sample_steps = n_steps
default_prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
prompt = st_keyup("Enter a value", value=default_prompt, debounce=300, key="interactive_text")
default_prompt = (
"A cinematic shot of a baby racoon wearing an intricate italian priest robe."
)
prompt = st_keyup(
"Enter a value", value=default_prompt, debounce=300, key="interactive_text"
)

cols = st.columns([1, 5, 1])
if mode != "skip":
Expand All @@ -217,7 +222,13 @@ def decrement_counter():

sampler.noise_sampler = SeededNoise(seed=st.session_state.seed)
out = sample(
model, sampler, H=512, W=512, seed=st.session_state.seed, prompt=prompt, filter=state.get("filter")
model,
sampler,
H=512,
W=512,
seed=st.session_state.seed,
prompt=prompt,
filter=state.get("filter"),
)
with cols[1]:
st.image(out[0])
14 changes: 6 additions & 8 deletions sgm/modules/autoencoding/temporal_ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
import torch
from einops import rearrange, repeat

from sgm.modules.diffusionmodules.model import (
XFORMERS_IS_AVAILABLE,
AttnBlock,
Decoder,
MemoryEfficientAttnBlock,
ResnetBlock,
)
from sgm.modules.diffusionmodules.openaimodel import ResBlock, timestep_embedding
from sgm.modules.diffusionmodules.model import (XFORMERS_IS_AVAILABLE,
AttnBlock, Decoder,
MemoryEfficientAttnBlock,
ResnetBlock)
from sgm.modules.diffusionmodules.openaimodel import (ResBlock,
timestep_embedding)
from sgm.modules.video_attention import VideoTransformerBlock
from sgm.util import partialclass

Expand Down
3 changes: 2 additions & 1 deletion sgm/modules/video_attention.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch

from ..modules.attention import *
from ..modules.diffusionmodules.util import AlphaBlender, linear, timestep_embedding
from ..modules.diffusionmodules.util import (AlphaBlender, linear,
timestep_embedding)


class TimeMixSequential(nn.Sequential):
Expand Down

0 comments on commit c51e4e3

Please sign in to comment.