Description
When using jax.profiler, I found that setting os.environ['XLA_FLAGS'] = '--xla_hlo_profile=true' raises an error.
I am trying to profile the duration of fine-grained HLO operations, e.g. add.3.3 in a GPT2 model on GPU using jax.profiler.
%fused_add.1 {
  ...
  %add.3.3 = f32[256,768]{1,0} ...}
  ...
}
%command_buffer {
  ...
  %loop_add_fusion.1 = ..., calls=%fused_add.1
  ...
}
ENTRY %main.332 {
  ...
  ROOT %call.4 = ..., to_apply=%command_buffer
}
However, the trace.json.gz file generated by jax.profiler only gives me durations for fused operations such as fused_add.1, not for the individual HLO ops inside them, such as add.3.3.

Is it possible to get the duration for each fine-grained HLO operation in JAX? Or is this not achievable on GPU due to asynchronous parallel execution, which only provides the execution duration of the wrapped fusion HLO op?
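For reference, the per-fusion durations mentioned above can be read out of the trace file with a sketch like the one below. It assumes the file uses the Chrome trace-event JSON layout (a top-level "traceEvents" list whose complete events have "ph" set to "X" and a "dur" field in microseconds), and that the file sits somewhere under the log directory with a name ending in .trace.json.gz. The helper names total_durations and load_latest_trace are hypothetical, not part of jax.profiler.

```python
import glob
import gzip
import json
import os
from collections import defaultdict


def total_durations(trace):
    """Sum the "dur" field (microseconds) of complete ("X") events by name."""
    totals = defaultdict(float)
    for event in trace.get("traceEvents", []):
        if event.get("ph") == "X" and "dur" in event:
            totals[event.get("name", "<unnamed>")] += event["dur"]
    return dict(totals)


def load_latest_trace(log_dir):
    """Load the newest *.trace.json.gz under log_dir (file layout is an assumption)."""
    pattern = os.path.join(log_dir, "**", "*.trace.json.gz")
    paths = sorted(glob.glob(pattern, recursive=True), key=os.path.getmtime)
    if not paths:
        raise FileNotFoundError(f"no trace file found under {log_dir}")
    with gzip.open(paths[-1], "rt") as f:
        return json.load(f)


# Tiny synthetic trace so the sketch runs without a profiler capture.
example = {
    "traceEvents": [
        {"ph": "X", "name": "loop_add_fusion.1", "dur": 3.0},
        {"ph": "X", "name": "loop_add_fusion.1", "dur": 2.0},
        {"ph": "M", "name": "process_name"},  # metadata event, ignored
    ]
}
print(total_durations(example))
```

With a real capture, total_durations(load_latest_trace("./flax_gpt2")) would list whatever event names the profiler recorded; in my case those are the fusion-level names only.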
Code Example:
import jax
import jax.numpy as jnp
import jax.profiler
from transformers import FlaxGPT2Model, GPT2Config
import numpy as np
import os
os.environ['XLA_FLAGS'] = '--xla_hlo_profile=true'
config = GPT2Config(
    vocab_size=50257,
    n_positions=1024,
    n_embd=768,
    n_layer=1,
    n_head=12,
)
model = FlaxGPT2Model(config)
batch_size = 2
seq_length = 128
input_ids = jnp.ones((batch_size, seq_length), dtype=jnp.int32)
attention_mask = jnp.ones((batch_size, seq_length), dtype=jnp.int32)
rng = jax.random.PRNGKey(0)
params = model.init_weights(rng, input_shape=(batch_size, seq_length))
def gpt2_forward(params, input_ids, attention_mask):
    outputs = model(input_ids, attention_mask=attention_mask, params=params)
    return outputs.last_hidden_state
jit_forward = jax.jit(gpt2_forward)
_ = jit_forward(params, input_ids, attention_mask)
gpt2_lowered = jax.jit(gpt2_forward).lower(params, input_ids, attention_mask)
gpt2_compiled = gpt2_lowered.compile()
with open("./gpt2_hlo.txt", "w") as f:
    f.write(gpt2_compiled.as_text())
hlo_text = gpt2_compiled.as_text()
hlo_lines = hlo_text.split('\n')
with jax.profiler.trace("./flax_gpt2"):
    for i in range(2):
        result = jit_forward(params, input_ids, attention_mask)
        result.block_until_ready()
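To see which fine-grained ops each fused computation wraps, the HLO text can be grouped by computation. The following is a minimal sketch, assuming the text form where a computation starts on a line ending in "{", ends on a line containing only "}", and each instruction line has the shape "%name = ...". The function ops_by_computation and the sample text are hypothetical illustrations, not a JAX or XLA API, and this only recovers the op names, not their durations.

```python
def ops_by_computation(hlo_text):
    """Map each computation name to the instruction names defined inside it."""
    groups = {}
    current = None
    for raw in hlo_text.splitlines():
        line = raw.strip()
        if current is None:
            # A computation header ends with "{" and is not an instruction.
            if line.endswith("{") and " = " not in line:
                tokens = line.split()
                name = tokens[1] if tokens[0] == "ENTRY" else tokens[0]
                current = name.lstrip("%")
                groups[current] = []
        elif line == "}":
            current = None
        elif " = " in line:
            # The left-hand side may carry a ROOT prefix; keep only the name.
            lhs = line.split(" = ", 1)[0]
            groups[current].append(lhs.split()[-1].lstrip("%"))
    return groups


# Hypothetical HLO fragment shaped like the excerpt in the description.
sample = """
%fused_add.1 (p0: f32[256,768], p1: f32[256,768]) -> f32[256,768] {
  %p0 = f32[256,768]{1,0} parameter(0)
  %p1 = f32[256,768]{1,0} parameter(1)
  ROOT %add.3.3 = f32[256,768]{1,0} add(%p0, %p1)
}
"""
print(ops_by_computation(sample))
```

Applied to hlo_text from the snippet above, this would show which ops sit inside fused_add.1, which is the level of detail missing from the trace.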
System info
jax=0.6.2