
model setup [WIP: do not merge] #181


Status: Open. Wants to merge 1 commit into main.
2 changes: 1 addition & 1 deletion setup.sh
@@ -110,4 +110,4 @@ else
fi

# Install maxdiffusion
-pip3 install -U . || echo "Failed to install maxdiffusion" >&2
+pip3 install -e . || echo "Failed to install maxdiffusion" >&2
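Note: switching from "pip3 install -U ." to "pip3 install -e ." installs maxdiffusion in editable mode, so local source changes (presumably including the new ltx_video package added below) are picked up without reinstalling.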
2 changes: 2 additions & 0 deletions src/maxdiffusion/__init__.py
@@ -373,6 +373,7 @@
_import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"]
_import_structure["models.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"]
_import_structure["models.flux.transformers.transformer_flux_flax"] = ["FluxTransformer2DModel"]
_import_structure["models.ltx_video.transformers.transformer3d"] = ["Transformer3DModel"]
_import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"]
_import_structure["pipelines"].extend(["FlaxDiffusionPipeline"])
_import_structure["schedulers"].extend(
@@ -453,6 +454,7 @@
from .models.modeling_flax_utils import FlaxModelMixin
from .models.unet_2d_condition_flax import FlaxUNet2DConditionModel
from .models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
+from .models.ltx_video.transformers.transformer3d import Transformer3DModel
from .models.vae_flax import FlaxAutoencoderKL
from .pipelines import FlaxDiffusionPipeline
from .schedulers import (
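With both the lazy-import table and the TYPE_CHECKING branch updated, the new model should be importable from the package root, the same way FluxTransformer2DModel is exposed. A minimal sketch, assuming the existing _LazyModule wiring resolves the new entry:

# Sketch: the two additions above expose the LTX-Video transformer at the package root.
from maxdiffusion import Transformer3DModel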
50 changes: 50 additions & 0 deletions src/maxdiffusion/configs/ltx_video.yml
@@ -0,0 +1,50 @@
#hardware
hardware: 'tpu'
skip_jax_distributed_system: False

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'


run_name: ''
output_dir: 'ltx-video-output'
save_config_to_gcs: False

#parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_batch', ['data','fsdp']],
['activation_heads', 'tensor'],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1  # -1 auto-shards this DCN axis across the remaining devices
dcn_tensor_parallelism: 1
ici_data_parallelism: -1  # -1 auto-shards this ICI axis across the remaining devices
ici_fsdp_parallelism: 1
ici_tensor_parallelism: 1




learning_rate_schedule_steps: -1
max_train_steps: 500 #TODO: change this
pretrained_model_name_or_path: ''
unet_checkpoint: ''
dataset_name: 'diffusers/pokemon-gpt4-captions'
train_split: 'train'
dataset_type: 'tf'
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
jit_initializers: True
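For context, a minimal sketch of how these parallelism settings are consumed, mirroring generate_ltx_video.py further down; the pyconfig argv convention and the nn.logical_axis_rules context manager are assumptions that this PR does not show.

# Sketch: build the device mesh from the config above (mirrors generate_ltx_video.py).
# The pyconfig argv convention and the logical_axis_rules usage are assumptions.
from flax import linen as nn
from jax.sharding import Mesh
from maxdiffusion import pyconfig
from maxdiffusion.max_utils import create_device_mesh

pyconfig.initialize(["", "src/maxdiffusion/configs/ltx_video.yml"])
config = pyconfig.config

devices_array = create_device_mesh(config)    # shaped by the dcn_* / ici_* values
mesh = Mesh(devices_array, config.mesh_axes)  # axes: data, fsdp, tensor

# logical_axis_rules map logical names ('embed', 'heads', ...) onto mesh axes
# for parameters annotated with nn.with_logical_partitioning (see linear.py below).
with mesh, nn.logical_axis_rules(config.logical_axis_rules):
  pass  # build and shard the model here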
73 changes: 73 additions & 0 deletions src/maxdiffusion/generate_ltx_video.py
@@ -0,0 +1,73 @@
from absl import app
from typing import Sequence
import jax
import json
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
import os
import functools
import jax.numpy as jnp
from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
from maxdiffusion.max_utils import (
create_device_mesh,
setup_initial_state,
)
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P


def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond):
print("prompts_embeds.shape: ", prompt_embeds.shape, prompt_embeds.dtype)
print("fractional_coords.shape: ", fractional_coords.shape, fractional_coords.dtype)
print("latents.shape: ", latents.shape, latents.dtype)
print("noise_cond.shape: ", noise_cond.shape, noise_cond.dtype)

def run(config):
key = jax.random.PRNGKey(0)

devices_array = create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)

batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
base_dir = os.path.dirname(__file__)

# Load the model config
config_path = os.path.join(base_dir, "models/ltx_video/xora_v1.2-13B-balanced-128.json")
with open(config_path, "r") as f:
model_config = json.load(f)


transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False)

key, split_key = jax.random.split(key)
weights_init_fn = functools.partial(
transformer.init_weights,
split_key,
batch_size,
text_tokens,
num_tokens,
features,
eval_only=False
)

transformer_state, transformer_state_shardings = setup_initial_state(
model=transformer,
tx=None,
config=config,
mesh=mesh,
weights_init_fn=weights_init_fn,
model_params=None,
training=False,
)



def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
run(pyconfig.config)


if __name__ == "__main__":
app.run(main)
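run() currently stops once the sharded state is created. A hypothetical continuation (not part of this WIP commit) could build dummy inputs with the same shapes as models/ltx_video/main.py below and run one forward pass; treating the returned state as a Flax TrainState with a .params field is an assumption:

# Hypothetical continuation, appended to the end of run(): one forward pass
# through the sharded transformer. Shapes and argument names mirror main.py;
# transformer_state.params assumes a TrainState-like object from setup_initial_state.
prompt_embeds = jax.random.normal(key, (batch_size, text_tokens, features), dtype=jnp.bfloat16)
fractional_coords = jax.random.normal(key, (batch_size, 3, num_tokens), dtype=jnp.bfloat16)
latents = jax.random.normal(key, (batch_size, num_tokens, features), dtype=jnp.bfloat16)
noise_cond = jax.random.normal(key, (batch_size, 1), dtype=jnp.bfloat16)
validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond)

output = transformer.apply(
    {"params": transformer_state.params},
    hidden_states=latents,
    indices_grid=fractional_coords,
    encoder_hidden_states=prompt_embeds,
    timestep=noise_cond,
)
max_logging.log(f"transformer output shape: {output.shape}")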



4 changes: 2 additions & 2 deletions src/maxdiffusion/models/__init__.py
@@ -14,7 +14,7 @@

from typing import TYPE_CHECKING

-from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
+from maxdiffusion.utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available


_import_structure = {}
@@ -32,7 +32,7 @@
from .vae_flax import FlaxAutoencoderKL
from .lora import *
from .flux.transformers.transformer_flux_flax import FluxTransformer2DModel

+from .ltx_video.transformers.transformer3d import Transformer3DModel
else:
import sys

Empty file.
70 changes: 70 additions & 0 deletions src/maxdiffusion/models/ltx_video/gradient_checkpoint.py
@@ -0,0 +1,70 @@
from enum import Enum, auto
from typing import Optional

import jax
from flax import linen as nn

SKIP_GRADIENT_CHECKPOINT_KEY = "skip"


class GradientCheckpointType(Enum):
"""
Defines which gradient checkpointing (rematerialization) policy to apply.

NONE - no gradient checkpointing
FULL - full gradient checkpointing wherever possible (minimum memory usage)
MATMUL_WITHOUT_BATCH - uses jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims:
the outputs of matmul/dot operations that have no batch dimension (attention and projection
weight matmuls) are saved, while all other intermediates are recomputed in the backward pass
"""

NONE = auto()
FULL = auto()
MATMUL_WITHOUT_BATCH = auto()

@classmethod
def from_str(cls, s: Optional[str] = None) -> "GradientCheckpointType":
"""
Constructs the gradient checkpoint type from a string

Args:
s (Optional[str], optional): The name of the gradient checkpointing policy. Defaults to None.

Returns:
GradientCheckpointType: The policy that corresponds to the string
"""
if s is None:
s = "none"
return GradientCheckpointType[s.upper()]

def to_jax_policy(self):
"""
Converts the gradient checkpoint type to a jax policy
"""
match self:
case GradientCheckpointType.NONE:
return SKIP_GRADIENT_CHECKPOINT_KEY
case GradientCheckpointType.FULL:
return None
case GradientCheckpointType.MATMUL_WITHOUT_BATCH:
return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims

def apply(self, module: nn.Module) -> nn.Module:
"""
Applies the gradient checkpointing policy to a module;
if no checkpointing is needed, the module is returned unchanged.

Args:
module (nn.Module): the module to apply the policy to

Returns:
nn.Module: the module with the policy applied
"""
policy = self.to_jax_policy()
if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
return module
return nn.remat( # pylint: disable=invalid-name
module,
prevent_cse=False,
policy=policy,
)
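A minimal usage sketch for this helper: the policy string matches the one passed to Transformer3DModel in generate_ltx_video.py, while MyBlock and the idea of wrapping a block class (nn.remat's usual convention) are illustrative assumptions.

# Sketch: wrap a block class with the configured rematerialization policy.
# "matmul_without_batch" is the string used in generate_ltx_video.py;
# MyBlock is a hypothetical stand-in for a real transformer block.
import jax
import jax.numpy as jnp
from flax import linen as nn
from maxdiffusion.models.ltx_video.gradient_checkpoint import GradientCheckpointType


class MyBlock(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=x.shape[-1])(x)


policy = GradientCheckpointType.from_str("matmul_without_batch")
RematBlock = policy.apply(MyBlock)  # nn.remat-wrapped class; returned as-is for "none"
params = RematBlock().init(jax.random.PRNGKey(0), jnp.ones((2, 16, 128)))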
109 changes: 109 additions & 0 deletions src/maxdiffusion/models/ltx_video/linear.py
@@ -0,0 +1,109 @@
from typing import Union, Iterable, Tuple, Optional, Callable

import numpy as np
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen.initializers import lecun_normal


Shape = Tuple[int, ...]
Initializer = Callable[[jax.random.PRNGKey, Shape, jax.numpy.dtype], jax.Array]
InitializerAxis = Union[int, Shape]


def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int, ...]:
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
return tuple(ax if ax >= 0 else ndim + ax for ax in axes)


def _canonicalize_tuple(x):
if isinstance(x, Iterable):
return tuple(x)
else:
return (x,)


NdInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]
KernelInitializer = Callable[[jax.random.PRNGKey, Shape, jnp.dtype, InitializerAxis, InitializerAxis], jax.Array]


class DenseGeneral(nn.Module):
"""A linear transformation with flexible axes.

Adapted from https://github.com/AI-Hypercomputer/maxtext/blob/4bf3beaa5e721745427bfed09938427e369c2aaf/MaxText/layers/linears.py#L86

Attributes:
features: tuple with numbers of output features.
axis: tuple with axes to apply the transformation on.
weight_dtype: the dtype of the weights (default: float32).
dtype: the dtype of the computation (default: float32).
kernel_init: initializer function for the weight matrix.
use_bias: whether to add bias in linear transformation.
kernel_axes: logical axis names used to partition the kernel (and bias) parameters.
matmul_precision: precision passed to jax.lax.dot_general (default: "default").
bias_init: initializer function for the bias (default: zeros).
"""

features: Union[Iterable[int], int]
axis: Union[Iterable[int], int] = -1
weight_dtype: jnp.dtype = jnp.float32
dtype: jnp.dtype = jnp.float32
kernel_init: KernelInitializer = lecun_normal()
kernel_axes: Tuple[Optional[str], ...] = ()
use_bias: bool = False
matmul_precision: str = "default"

bias_init: Initializer = jax.nn.initializers.constant(0.0)

@nn.compact
def __call__(self, inputs: jax.Array) -> jax.Array:
"""Applies a linear transformation to the inputs along multiple dimensions.

Args:
inputs: The nd-array to be transformed.

Returns:
The transformed input.
"""

def compute_dot_general(inputs, kernel, axis, contract_ind):
"""Computes a dot_general operation that may be quantized."""
dot_general = jax.lax.dot_general
matmul_precision = jax.lax.Precision(self.matmul_precision)
return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=matmul_precision)

features = _canonicalize_tuple(self.features)
axis = _canonicalize_tuple(self.axis)

inputs = jnp.asarray(inputs, self.dtype)
axis = _normalize_axes(axis, inputs.ndim)

kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_in_axis = np.arange(len(axis))
kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
kernel = self.param(
"kernel",
nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
kernel_shape,
self.weight_dtype,
)
kernel = jnp.asarray(kernel, self.dtype)

contract_ind = tuple(range(0, len(axis)))
output = compute_dot_general(inputs, kernel, axis, contract_ind)

if self.use_bias:
bias_axes, bias_shape = (
self.kernel_axes[-len(features) :],
kernel_shape[-len(features) :],
)
bias = self.param(
"bias",
nn.with_logical_partitioning(self.bias_init, bias_axes),
bias_shape,
self.weight_dtype,
)
bias = jnp.asarray(bias, self.dtype)

output += bias
return output
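A minimal usage sketch for DenseGeneral; the logical axis names reuse entries from logical_axis_rules in ltx_video.yml, and the exact names used by the transformer itself are an assumption.

# Sketch: a DenseGeneral projection annotated with logical axis names.
# ('embed', 'mlp') reuses names from logical_axis_rules in ltx_video.yml;
# the surrounding model may use different names.
import jax
import jax.numpy as jnp
from maxdiffusion.models.ltx_video.linear import DenseGeneral

layer = DenseGeneral(
    features=512,                  # output features (int or tuple)
    axis=-1,                       # contract over the last input axis
    dtype=jnp.bfloat16,            # computation dtype
    weight_dtype=jnp.float32,      # parameter dtype
    kernel_axes=("embed", "mlp"),  # consumed by nn.with_logical_partitioning
    use_bias=True,
)

x = jnp.ones((4, 128, 256), dtype=jnp.bfloat16)
params = layer.init(jax.random.PRNGKey(0), x)
y = layer.apply(params, x)  # shape (4, 128, 512)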
40 changes: 40 additions & 0 deletions src/maxdiffusion/models/ltx_video/main.py
@@ -0,0 +1,40 @@
import os
import jax
import jax.numpy as jnp
import json


from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel

# Load JSON config
base_dir = os.path.dirname(__file__)
config_path = os.path.join(base_dir, "xora_v1.2-13B-balanced-128.json")
with open(config_path, "r") as f:
model_config = json.load(f)

key = jax.random.PRNGKey(0)
model = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")

batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
prompt_embeds = jax.random.normal(key, shape=(batch_size, text_tokens, features), dtype=jnp.bfloat16)
fractional_coords = jax.random.normal(key, shape=(batch_size, 3, num_tokens), dtype=jnp.bfloat16)
latents = jax.random.normal(key, shape=(batch_size, num_tokens, features), dtype=jnp.bfloat16)
noise_cond = jax.random.normal(key, shape=(batch_size, 1), dtype=jnp.bfloat16)

model_params = model.init(
hidden_states=latents,
indices_grid=fractional_coords,
encoder_hidden_states=prompt_embeds,
timestep=noise_cond,
rngs={"params": key}
)

output = model.apply(
model_params,
hidden_states=latents,
indices_grid=fractional_coords,
encoder_hidden_states=prompt_embeds,
timestep=noise_cond,
)

print("done!")