Skip to content

Commit 463e3aa

Browse files
Add MaskGIT sampling, ST-attention, LAM codebook resets (#5)
* Sampling WIP * Add notebook * Add sample script * autoregressive sampling and full sequence prompting * Maintain sampled tokens in maskgit generation * autoregressive generation * Clean up sampling code, add arguments * Fix new frame shapes in sampling * Add image logging and rng split * MaskGitScanning * adding psnr and interweaved video logging * Add sweep * Add ST positional embedding * Add temp docker runner * Refactor MaskGIT, fix cosine schedule + token sampling * Reset inactive latent actions in LAM training * Sampling code updates, log dynamics model statistics * Refactor maskgit step * Remove dev file paths * Remove sample notebook * Log gifs and refactor sampling * Black formatting * Remove sweep config --------- Co-authored-by: Aidandos <[email protected]>
1 parent 0e349d7 commit 463e3aa

9 files changed

+337
-44
lines changed

.gitignore

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
*.pyc
2+
*.npy
3+
*.png
4+
*.gif
5+
6+
wandb_key
7+
checkpoints/
8+
wandb/
9+
__pycache__/

genie.py

+116-10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Any, Optional
1+
from typing import Dict, Any
22

33
from orbax.checkpoint import PyTreeCheckpointer
44
import jax
@@ -32,8 +32,8 @@ class Genie(nn.Module):
3232
dyna_dim: int
3333
dyna_num_blocks: int
3434
dyna_num_heads: int
35-
dropout: float
36-
mask_limit: float
35+
dropout: float = 0.0
36+
mask_limit: float = 0.0
3737

3838
def setup(self):
3939
self.tokenizer = TokenizerVQVAE(
@@ -83,19 +83,125 @@ def __call__(self, batch: Dict[str, Any], training: bool = True) -> Dict[str, An
8383
)
8484
return outputs
8585

86+
@nn.compact
87+
def sample(
88+
self,
89+
batch: Dict[str, Any],
90+
steps: int = 25,
91+
temperature: int = 1,
92+
sample_argmax: bool = False,
93+
) -> Any:
94+
# --- Encode videos and actions ---
95+
tokenizer_out = self.tokenizer.vq_encode(batch["videos"], training=False)
96+
token_idxs = tokenizer_out["indices"]
97+
new_frame_idxs = jnp.zeros_like(token_idxs)[:, 0]
98+
action_tokens = self.lam.vq.get_codes(batch["latent_actions"])
8699

87-
def restore_genie_checkpoint(
88-
params: Dict[str, Any], tokenizer: str, lam: str, dyna: Optional[str] = None
89-
):
100+
# --- Initialize MaskGIT ---
101+
init_mask = jnp.ones_like(token_idxs, dtype=bool)[:, 0]
102+
init_carry = (
103+
batch["rng"],
104+
new_frame_idxs,
105+
init_mask,
106+
token_idxs,
107+
action_tokens,
108+
)
109+
MaskGITLoop = nn.scan(
110+
MaskGITStep,
111+
variable_broadcast="params",
112+
split_rngs={"params": False},
113+
in_axes=0,
114+
out_axes=0,
115+
length=steps,
116+
)
117+
118+
# --- Run MaskGIT loop ---
119+
loop_fn = MaskGITLoop(
120+
dynamics=self.dynamics,
121+
tokenizer=self.tokenizer,
122+
temperature=temperature,
123+
sample_argmax=sample_argmax,
124+
steps=steps,
125+
)
126+
final_carry, _ = loop_fn(init_carry, jnp.arange(steps))
127+
new_frame_idxs = final_carry[1]
128+
new_frame_pixels = self.tokenizer.decode(
129+
jnp.expand_dims(new_frame_idxs, 1),
130+
video_hw=batch["videos"].shape[2:4],
131+
)
132+
return new_frame_pixels
133+
134+
def vq_encode(self, batch, training) -> Dict[str, Any]:
135+
# --- Preprocess videos ---
136+
lam_output = self.lam.vq_encode(batch["videos"], training=training)
137+
return lam_output["indices"]
138+
139+
140+
class MaskGITStep(nn.Module):
141+
dynamics: nn.Module
142+
tokenizer: nn.Module
143+
temperature: float
144+
sample_argmax: bool
145+
steps: int
146+
147+
@nn.compact
148+
def __call__(self, carry, x):
149+
rng, final_token_idxs, mask, token_idxs, action_tokens = carry
150+
step = x
151+
B, T, N = token_idxs.shape[:3]
152+
153+
# --- Construct + encode video ---
154+
vid_token_idxs = jnp.concatenate(
155+
(token_idxs, jnp.expand_dims(final_token_idxs, 1)), axis=1
156+
)
157+
vid_embed = self.dynamics.patch_embed(vid_token_idxs)
158+
curr_masked_frame = jnp.where(
159+
jnp.expand_dims(mask, -1),
160+
self.dynamics.mask_token[0],
161+
vid_embed[:, -1],
162+
)
163+
vid_embed = vid_embed.at[:, -1].set(curr_masked_frame)
164+
165+
# --- Predict transition ---
166+
act_embed = self.dynamics.action_up(action_tokens)
167+
vid_embed += jnp.pad(act_embed, ((0, 0), (1, 0), (0, 0), (0, 0)))
168+
unmasked_ratio = jnp.cos(jnp.pi * (step + 1) / (self.steps * 2))
169+
step_temp = self.temperature * (1.0 - unmasked_ratio)
170+
final_logits = self.dynamics.dynamics(vid_embed)[:, -1] / step_temp
171+
172+
# --- Sample new tokens for final frame ---
173+
if self.sample_argmax:
174+
sampled_token_idxs = jnp.argmax(final_logits, axis=-1)
175+
else:
176+
rng, _rng = jax.random.split(rng)
177+
sampled_token_idxs = jnp.where(
178+
step == self.steps - 1,
179+
jnp.argmax(final_logits, axis=-1),
180+
jax.random.categorical(_rng, final_logits),
181+
)
182+
gather_fn = jax.vmap(jax.vmap(lambda x, y: x[y]))
183+
final_token_probs = gather_fn(jax.nn.softmax(final_logits), sampled_token_idxs)
184+
final_token_probs += ~mask
185+
# Update masked tokens only
186+
new_token_idxs = jnp.where(mask, sampled_token_idxs, final_token_idxs)
187+
188+
# --- Update mask ---
189+
num_unmasked_tokens = jnp.round(N * (1.0 - unmasked_ratio)).astype(int)
190+
idx_mask = jnp.arange(final_token_probs.shape[-1]) > num_unmasked_tokens
191+
sorted_idxs = jnp.argsort(final_token_probs, axis=-1, descending=True)
192+
mask_update_fn = jax.vmap(lambda msk, ids: msk.at[ids].set(idx_mask))
193+
new_mask = mask_update_fn(mask, sorted_idxs)
194+
195+
new_carry = (rng, new_token_idxs, new_mask, token_idxs, action_tokens)
196+
return new_carry, None
197+
198+
199+
def restore_genie_components(params: Dict[str, Any], tokenizer: str, lam: str):
90200
"""Restore pre-trained Genie components"""
91201
params["params"]["tokenizer"].update(
92202
PyTreeCheckpointer().restore(tokenizer)["model"]["params"]["params"]
93203
)
94204
params["params"]["lam"].update(
95205
PyTreeCheckpointer().restore(lam)["model"]["params"]["params"]
96206
)
97-
if dyna:
98-
params["params"]["dyna"].update(
99-
PyTreeCheckpointer().restore(dyna)["model"]["params"]["params"]
100-
)
101207
return params

models/lam.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def vq_encode(self, videos: Any, training: bool = True) -> Dict[str, Any]:
7474
# --- Encode ---
7575
z = self.encoder(padded_patches) # (B, T, N, E)
7676
# Get latent action for all future frames
77-
z = z[:, 1:, 0] # (B, T-1, 1, E)
77+
z = z[:, 1:, 0] # (B, T-1, E)
7878

7979
# --- Vector quantize ---
8080
z = z.reshape(B * (T - 1), self.latent_dim)

run_docker.sh

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/bin/bash
2+
gpu=$1
3+
script_and_args="${@:2}"
4+
WANDB_API_KEY=$(cat ./docker/wandb_key)
5+
git pull
6+
7+
echo "Launching container jafar_$gpu on GPU $gpu"
8+
docker run \
9+
--env CUDA_VISIBLE_DEVICES=$gpu \
10+
--gpus all \
11+
-e WANDB_API_KEY=$WANDB_API_KEY \
12+
-v $(pwd):/home/duser/jafar \
13+
--name jafar\_$gpu \
14+
--user $(id -u) \
15+
--rm \
16+
-d \
17+
jafar \
18+
/bin/bash -c "$script_and_args"

sample.py

+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from dataclasses import dataclass
2+
import time
3+
4+
import dm_pix as pix
5+
import einops
6+
import jax
7+
import jax.numpy as jnp
8+
import numpy as np
9+
from orbax.checkpoint import PyTreeCheckpointer
10+
from PIL import Image, ImageDraw
11+
import tyro
12+
13+
from data.dataloader import get_dataloader
14+
from genie import Genie
15+
16+
17+
@dataclass
18+
class Args:
19+
# Experiment
20+
seed: int = 0
21+
seq_len: int = 16
22+
image_channels: int = 3
23+
image_resolution: int = 64
24+
file_path: str = "data/coinrun.npy"
25+
checkpoint: str = ""
26+
# Sampling
27+
batch_size: int = 1
28+
maskgit_steps: int = 25
29+
temperature: float = 1.0
30+
sample_argmax: bool = False
31+
start_frame: int = 0
32+
# Tokenizer checkpoint
33+
tokenizer_dim: int = 512
34+
latent_patch_dim: int = 32
35+
num_patch_latents: int = 1024
36+
patch_size: int = 4
37+
tokenizer_num_blocks: int = 8
38+
tokenizer_num_heads: int = 8
39+
# LAM checkpoint
40+
lam_dim: int = 512
41+
latent_action_dim: int = 32
42+
num_latent_actions: int = 6
43+
lam_patch_size: int = 8
44+
lam_num_blocks: int = 8
45+
lam_num_heads: int = 8
46+
# Dynamics checkpoint
47+
dyna_dim: int = 512
48+
dyna_num_blocks: int = 12
49+
dyna_num_heads: int = 8
50+
51+
52+
args = tyro.cli(Args)
53+
rng = jax.random.PRNGKey(args.seed)
54+
55+
# --- Load Genie checkpoint ---
56+
genie = Genie(
57+
# Tokenizer
58+
in_dim=args.image_channels,
59+
tokenizer_dim=args.tokenizer_dim,
60+
latent_patch_dim=args.latent_patch_dim,
61+
num_patch_latents=args.num_patch_latents,
62+
patch_size=args.patch_size,
63+
tokenizer_num_blocks=args.tokenizer_num_blocks,
64+
tokenizer_num_heads=args.tokenizer_num_heads,
65+
# LAM
66+
lam_dim=args.lam_dim,
67+
latent_action_dim=args.latent_action_dim,
68+
num_latent_actions=args.num_latent_actions,
69+
lam_patch_size=args.lam_patch_size,
70+
lam_num_blocks=args.lam_num_blocks,
71+
lam_num_heads=args.lam_num_heads,
72+
# Dynamics
73+
dyna_dim=args.dyna_dim,
74+
dyna_num_blocks=args.dyna_num_blocks,
75+
dyna_num_heads=args.dyna_num_heads,
76+
)
77+
rng, _rng = jax.random.split(rng)
78+
image_shape = (args.image_resolution, args.image_resolution, args.image_channels)
79+
dummy_inputs = dict(
80+
videos=jnp.zeros((args.batch_size, args.seq_len, *image_shape), dtype=jnp.float32),
81+
mask_rng=_rng,
82+
)
83+
rng, _rng = jax.random.split(rng)
84+
params = genie.init(_rng, dummy_inputs)
85+
ckpt = PyTreeCheckpointer().restore(args.checkpoint)["model"]["params"]["params"]
86+
params["params"].update(ckpt)
87+
88+
# --- Get video + latent actions ---
89+
dataloader = get_dataloader(args.file_path, args.seq_len, args.batch_size)
90+
for vids in dataloader:
91+
video_batch = jnp.array(vids, dtype=jnp.float32) / 255.0
92+
break
93+
batch = dict(videos=video_batch)
94+
lam_output = genie.apply(params, batch, False, method=Genie.vq_encode)
95+
lam_output = lam_output.reshape(args.batch_size, args.seq_len - 1, 1)
96+
97+
98+
# --- Define autoregressive sampling loop ---
99+
def _autoreg_sample(rng, video_batch):
100+
vid = video_batch[:, : args.start_frame + 1]
101+
for frame_idx in range(args.start_frame + 1, args.seq_len):
102+
# --- Sample next frame ---
103+
print("Frame", frame_idx)
104+
rng, _rng = jax.random.split(rng)
105+
batch = dict(videos=vid, latent_actions=lam_output[:, :frame_idx], rng=_rng)
106+
new_frame = genie.apply(
107+
params,
108+
batch,
109+
args.maskgit_steps,
110+
args.temperature,
111+
args.sample_argmax,
112+
method=Genie.sample,
113+
)
114+
vid = jnp.concatenate([vid, new_frame], axis=1)
115+
return vid
116+
117+
118+
# --- Sample + evaluate video ---
119+
vid = _autoreg_sample(rng, video_batch)
120+
gt = video_batch[:, : vid.shape[1]].clip(0, 1).reshape(-1, *video_batch.shape[2:])
121+
recon = vid.clip(0, 1).reshape(-1, *vid.shape[2:])
122+
ssim = pix.ssim(gt[:, args.start_frame + 1 :], recon[:, args.start_frame + 1 :]).mean()
123+
print(f"SSIM: {ssim}")
124+
125+
# --- Save generated video ---
126+
original_frames = (video_batch * 255).astype(np.uint8)
127+
interweaved_frames = np.zeros((vid.shape[0] * 2, *vid.shape[1:5]), dtype=np.uint8)
128+
interweaved_frames[0::2] = original_frames[:, : vid.shape[1]]
129+
interweaved_frames[1::2] = (vid * 255).astype(np.uint8)
130+
flat_vid = einops.rearrange(interweaved_frames, "n t h w c -> t h (n w) c")
131+
imgs = [Image.fromarray(img) for img in flat_vid]
132+
for img, action in zip(imgs[1:], lam_output[0, :, 0]):
133+
d = ImageDraw.Draw(img)
134+
d.text((2, 2), f"{action}", fill=255)
135+
imgs[0].save(
136+
f"generation_{time.time()}.gif",
137+
save_all=True,
138+
append_images=imgs[1:],
139+
duration=250,
140+
loop=0,
141+
)

train_dynamics.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
import numpy as np
1111
import jax
1212
import jax.numpy as jnp
13-
import wandb
1413
import tyro
14+
import wandb
1515

16-
from genie import Genie, restore_genie_checkpoint
1716
from data.dataloader import get_dataloader
17+
from genie import Genie, restore_genie_components
1818

1919
ts = int(time.time())
2020

@@ -103,7 +103,7 @@ class Args:
103103
)
104104
rng, _rng = jax.random.split(rng)
105105
init_params = genie.init(_rng, dummy_inputs)
106-
init_params = restore_genie_checkpoint(
106+
init_params = restore_genie_components(
107107
init_params, args.tokenizer_checkpoint, args.lam_checkpoint
108108
)
109109
lr_schedule = optax.warmup_cosine_decay_schedule(
@@ -113,8 +113,9 @@ class Args:
113113
train_state = TrainState.create(apply_fn=genie.apply, params=init_params, tx=tx)
114114

115115

116+
# --- Define dynamics loss + train step ---
116117
def dynamics_loss_fn(params, state, inputs):
117-
# --- Compute masked loss ---
118+
"""Compute masked dynamics loss"""
118119
outputs = state.apply_fn(
119120
params, inputs, training=True, rngs={"dropout": inputs["dropout_rng"]}
120121
)
@@ -125,13 +126,20 @@ def dynamics_loss_fn(params, state, inputs):
125126
ce_loss = (mask * ce_loss).sum() / mask.sum()
126127
acc = outputs["token_logits"].argmax(-1) == outputs["video_tokens"]
127128
acc = (mask * acc).sum() / mask.sum()
128-
metrics = dict(cross_entropy_loss=ce_loss, masked_token_accuracy=acc)
129+
select_probs = jax.nn.softmax(outputs["token_logits"])
130+
metrics = dict(
131+
cross_entropy_loss=ce_loss,
132+
masked_token_accuracy=acc,
133+
select_logit=outputs["token_logits"].max(-1).mean(),
134+
select_p=select_probs.max(-1).mean(),
135+
entropy=jax.scipy.special.entr(select_probs).sum(-1).mean(),
136+
)
129137
return ce_loss, (outputs["recon"], metrics)
130138

131139

132-
# --- Define train step ---
133140
@jax.jit
134141
def train_step(state, inputs):
142+
"""Update state and compute metrics"""
135143
grad_fn = jax.value_and_grad(dynamics_loss_fn, has_aux=True, allow_int=True)
136144
(loss, (recon, metrics)), grads = grad_fn(state.params, state, inputs)
137145
state = state.apply_gradients(grads=grads)

0 commit comments

Comments (0)