Skip to content

Commit a73bbde

Browse files
committed
Add back vjp flops
1 parent fb5bc55 commit a73bbde

File tree

2 files changed

+22
-26
lines changed

2 files changed

+22
-26
lines changed

flax/nnx/summary.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ class CallInfo:
170170
path: statelib.PathParts
171171
inputs_repr: str
172172
outputs: tp.Any
173-
flops: int | None = None
174-
vjp_flops: int | None = None
173+
flops: int | None
174+
vjp_flops: int | None
175175

176176
class SimpleObjectRepr:
177177
def __init__(self, obj: tp.Any):
@@ -201,7 +201,7 @@ def inner(state, *args, **kwargs):
201201
return f(model, *args, **kwargs)
202202
return jax.vjp(inner, state, *args, **kwargs)
203203

204-
def _get_call_info(lowered, method_name, node_stats, obj, compute_flops, inputs_repr):
204+
def _get_call_info(lowered, method_name, node_stats, obj, compute_flops, inputs_repr, vjp_flops):
205205
flops = _get_flops(lowered) if compute_flops else None
206206
outputs = lowered.lowered.out_info[2]
207207
output_repr = jax.tree.map(_to_dummy_array, outputs)
@@ -219,6 +219,7 @@ def _get_call_info(lowered, method_name, node_stats, obj, compute_flops, inputs_
219219
inputs_repr=inputs_repr,
220220
outputs=output_repr,
221221
flops=flops,
222+
vjp_flops=vjp_flops
222223
)
223224

224225

@@ -234,8 +235,11 @@ def _create_obj_env(object_types):
234235
result[(obj_type, name)] = top_method
235236
return result
236237

237-
def _argsave(counter, tracer_args, f):
238+
def _argsave(counter, tracer_args, f, compute_vjp_flops):
238239
"Wrap a function to save its arguments"
240+
def do_vjp(*args, **kwargs):
241+
primals, f_vjp = jax.vjp(f, *args, **kwargs)
242+
return f_vjp(primals)
239243
n = f.__name__
240244
@wraps(f)
241245
def wrapper(obj, *args, **kwargs):
@@ -258,7 +262,11 @@ def wrapper(obj, *args, **kwargs):
258262
counter_val = counter[0]
259263
counter[0] += 1
260264
lowered = f.lower(obj, *args, **kwargs)
261-
tracer_args.append((counter_val, obj, n, lowered, inputs_repr))
265+
if compute_vjp_flops:
266+
vjp_flops = _get_flops(jax.jit(do_vjp).lower(obj, *args, **kwargs))
267+
else:
268+
vjp_flops = None
269+
tracer_args.append((counter_val, obj, n, lowered, inputs_repr, vjp_flops))
262270
f.seen.add(identifier)
263271
return f(obj, *args, **kwargs)
264272
return wrapper
@@ -415,14 +423,13 @@ def tabulate(
415423
# iteration over methods easier.
416424
env = _create_obj_env(object_types)
417425

418-
# Modify all the object's methods to save their lowered JIT representations.
419-
tracer_args = []
420-
421426
# Information is recorded in post-order, but should be presented as a pre-order traversal.
422427
# This counter is incremented in pre-order traversal to keep track of the order of calls.
423428
counter = [0]
424429

425-
jits = {k: _argsave(counter, tracer_args, MaybeJit(v)) for k,v in env.items()}
430+
# Modify all the object's methods to save their lowered JIT representations.
431+
tracer_args = []
432+
jits = {k: _argsave(counter, tracer_args, MaybeJit(v), compute_vjp_flops) for k,v in env.items()}
426433
_overwrite_methods(jits)
427434

428435
# Trace the top function (which indirectly traces all the others)
@@ -431,19 +438,8 @@ def tabulate(
431438
# Get call_info
432439
rows : list[CallInfo] = [_get_call_info(
433440
lowered, name, node_stats, object,
434-
compute_flops, inputs_repr)
435-
for (_, object, name, lowered, inputs_repr) in sorted(tracer_args, key=lambda x: x[0])]
436-
437-
# Add VJP flops if required. This needs to be done separately because calls to `_pure_nnx_vjp`
438-
# can result in tracing the jitted functions a second time if there's shared structure.
439-
# This would add items to `tracer_args`, resulting in duplicate rows in the table.
440-
# if compute_vjp_flops:
441-
# for i, row in enumerate(rows):
442-
# object, method_name, args, kwargs = tracer_args[i]
443-
# def do_vjp(*args, **kwargs):
444-
# primals, f_vjp = _pure_nnx_vjp(jits[(type(object), method_name)], *args, **kwargs)
445-
# return f_vjp(primals)
446-
# row.vjp_flops = _get_flops(jax.jit(do_vjp).lower(object, *args, **kwargs))
441+
compute_flops, inputs_repr, vjp_flops)
442+
for (_, object, name, lowered, inputs_repr, vjp_flops) in sorted(tracer_args, key=lambda x: x[0])]
447443

448444
# Restore the object's original methods
449445
_overwrite_methods(env)

tests/nnx/summary_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ def __call__(self, x1):
193193
).splitlines()
194194
self.assertIn('flops', table_repr1[2])
195195
self.assertNotIn('vjp_flops', table_repr1[2])
196-
# table_repr2 = nnx.tabulate(
197-
# m, x, compute_flops=True, compute_vjp_flops=True
198-
# ).splitlines()
199-
# self.assertIn('vjp_flops', table_repr2[2])
196+
table_repr2 = nnx.tabulate(
197+
m, x, compute_flops=True, compute_vjp_flops=True
198+
).splitlines()
199+
self.assertIn('vjp_flops', table_repr2[2])
200200

201201
def test_nested(self):
202202
class Block(nnx.Module):

0 commit comments

Comments
 (0)