-
Notifications
You must be signed in to change notification settings - Fork 39
LTXVid Transformer Pytorch-Jax Conversion script #193
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@@ -213,7 +213,10 @@ def load_state_if_possible( | |||
max_logging.log(f"restoring from this run's directory latest step {latest_step}") | |||
try: | |||
if not enable_single_replica_ckpt_restoring: | |||
item = {checkpoint_item: orbax.checkpoint.args.PyTreeRestore(item=abstract_unboxed_pre_state)} | |||
if checkpoint_item == " ": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar comment as Juan from previous PR, why is checkpoint == " "
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if checkpoint set to None, cannot pass the check "if checkpoint_manager and checkpoint_item:" in max_utils.py. So I set it to empty string to get around this
axis = _normalize_axes(axis, inputs.ndim) | ||
|
||
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features | ||
# kernel_in_axis = np.arange(len(axis)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove commented lines
return t_emb | ||
|
||
|
||
class AlphaCombinedTimestepSizeEmbeddings(nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
complete the docstring?
Converts ltxv-13b-0.9.7-dev.safetensors from lightricks huggingface into JAX weight checkpoint.
See running instruction at https://github.com/AI-Hypercomputer/maxdiffusion/blob/conversion-script/src/maxdiffusion/models/ltx_video/utils/conversion_script_instruction.md