Issue with profiling HLO Ops on GPU #31669

@wenboqian

Description

When using jax.profiler, I found that setting os.environ['XLA_FLAGS'] = '--xla_hlo_profile=true' raises the error below:

[screenshot of the error]

I am trying to profile the duration of fine-grained HLO operations, e.g. add.3.3, in a GPT-2 model on GPU using jax.profiler. The relevant part of the optimized HLO looks like this:

%fused_add.1 {
  ...
  %add.3.3 = f32[256,768]{1,0} ...}
  ...
}
%command_buffer {
  ...
  %loop_add_fusion.1 = ..., calls=%fused_add.1
  ...
}
ENTRY %main.332 {
  ...
  ROOT %call.4 = ..., to_apply=%command_buffer
}
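
For reference, this structure can also be inspected by dumping the optimized HLO straight from XLA. A minimal sketch, assuming the standard XLA debug flags --xla_dump_to and --xla_dump_hlo_as_text:

import os

# Assumption: --xla_dump_to and --xla_dump_hlo_as_text are the standard XLA
# debug flags; they write every optimized HLO module (including fusion
# computations like %fused_add.1) as text files into the dump directory.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text"

import jax  # the flags must be set before the XLA backend initializes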

However, the trace.json.gz file generated by jax.profiler only contains durations for fused operations such as fused_add.1, not for the individual HLO ops such as add.3.3 inside fused_add.1.

[screenshot of the trace]
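
For reference, the durations that are present can be listed straight from the trace file. A minimal sketch, assuming the usual output layout under <logdir>/plugins/profile/ and the Chrome trace-event format (both are assumptions about the profiler's output, not a documented API):

import glob, gzip, json

# Assumed layout: jax.profiler.trace writes
# <logdir>/plugins/profile/<run>/<host>.trace.json.gz
path = glob.glob("./flax_gpt2/plugins/profile/*/*.trace.json.gz")[0]
with gzip.open(path, "rt") as f:
    trace = json.load(f)

# Chrome trace format: complete events ("ph" == "X") carry their duration
# in microseconds under "dur"; only fusion-level names appear here.
for ev in trace.get("traceEvents", []):
    if ev.get("ph") == "X" and "fusion" in ev.get("name", ""):
        print(ev["name"], ev.get("dur"), "us")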

Is it possible to get the duration of each fine-grained HLO operation in JAX? Or is this not achievable on GPU because of asynchronous parallel execution, so that only the execution duration of the enclosing fusion op is available?
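
As a partial workaround, regions of the model can be annotated so the kernels in the trace at least carry meaningful names, even though the timings stay kernel-granular. A minimal sketch using jax.named_scope (the function f and the scope name are made up for illustration):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # named_scope attaches "my_add" to the metadata of the ops built inside
    # it, which makes the corresponding kernels easier to spot in the trace
    with jax.named_scope("my_add"):
        y = x + 1.0
    return y * 2.0

with jax.profiler.trace("/tmp/scoped_trace"):
    f(jnp.ones((256, 768))).block_until_ready()

This does not split a fusion into its constituent HLO ops; it only makes the fused kernels easier to attribute.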

Code Example:

import os

# XLA flags are read when the XLA backend initializes, so set them before
# importing jax (or at least before running any computation).
os.environ['XLA_FLAGS'] = '--xla_hlo_profile=true'

import jax
import jax.numpy as jnp
import jax.profiler
from transformers import FlaxGPT2Model, GPT2Config

# Minimal 1-layer GPT-2 so compilation and the trace stay small
config = GPT2Config(
    vocab_size=50257,
    n_positions=1024,
    n_embd=768,
    n_layer=1,
    n_head=12,
)

model = FlaxGPT2Model(config)

batch_size = 2
seq_length = 128
input_ids = jnp.ones((batch_size, seq_length), dtype=jnp.int32)
attention_mask = jnp.ones((batch_size, seq_length), dtype=jnp.int32)

rng = jax.random.PRNGKey(0)
params = model.init_weights(rng, input_shape=(batch_size, seq_length))

def gpt2_forward(params, input_ids, attention_mask):
    outputs = model(input_ids, attention_mask=attention_mask, params=params)
    return outputs.last_hidden_state

jit_forward = jax.jit(gpt2_forward)

# Warm-up call so compilation happens outside the profiled region
_ = jit_forward(params, input_ids, attention_mask)

# Dump the optimized HLO so op names in the trace can be matched to the module
gpt2_lowered = jit_forward.lower(params, input_ids, attention_mask)
gpt2_compiled = gpt2_lowered.compile()

with open("./gpt2_hlo.txt", "w") as f:
    f.write(gpt2_compiled.as_text())
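
# --- Optional sketch (my own helper, not a JAX API): map each inner HLO op
# to its enclosing computation from the dumped HLO text, to help attribute
# fusion-level trace events. The parse assumes the textual layout shown above.
def op_to_computation(hlo_text):
    mapping, current = {}, None
    for line in hlo_text.splitlines():
        s = line.strip()
        if s.endswith("{"):
            # computation header, e.g. "%fused_add.1 (...) -> ... {"
            names = [t for t in s.split() if t.startswith("%")]
            current = names[0].lstrip("%") if names else None
        elif s == "}":
            current = None
        elif current is not None and (s.startswith("%") or s.startswith("ROOT %")):
            op = s.removeprefix("ROOT ").split()[0].lstrip("%")
            mapping[op] = current
    return mapping

# e.g. op_to_computation(gpt2_compiled.as_text()).get("add.3.3") -> "fused_add.1"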

with jax.profiler.trace("./flax_gpt2"):
    for _ in range(2):
        result = jit_forward(params, input_ids, attention_mask)
        result.block_until_ready()
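
Separately, since the entry computation wraps everything in a %command_buffer call, part of the coarseness likely comes from command-buffer (CUDA graph) capture. A hedged sketch for disabling it (I'm assuming --xla_gpu_enable_command_buffer accepts an empty list; the flag's semantics may differ across XLA versions):

import os

# Assumption: an empty --xla_gpu_enable_command_buffer disables command-buffer
# capture, so each fusion is launched (and traced) as a separate kernel rather
# than inside one captured %command_buffer call.
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_command_buffer="

import jax  # set before the backend initializes

Even with capture disabled, the GPU trace is still kernel-granular, so a fusion remains the smallest unit with its own duration.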

System info

jax==0.6.2
