Skip to content

Massive performance issue on Convolution with JAX #34273

@ZagButNoZig

Description

@ZagButNoZig

Hey, this is a cross-reference to this issue from the JAX repo.
The problem seems to be that XLA chooses a suboptimal convolution implementation compared to PyTorch, leading to a > 30x slowdown.

For PyTorch I get

{'eager_ms_per_iter': 0.464, 'compiled_ms_per_iter': 0.538}

And with this HLO I get:

{'stablehlo_execution_ms_per_iter': 41.291}

The HLO to reproduce is:

// StableHLO module for the reproduction: takes one bf16 input tensor and
// returns a constant i32 (result[0]) plus the 1x1x64x3x3x3 bf16 tensor
// labelled "result[1].weight".
module @jit_loss_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x64x64x64x32xbf16>) -> (tensor<i32> {jax.result_info = "result[0]"}, tensor<1x1x64x3x3x3xbf16> {jax.result_info = "result[1].weight"}) {
    // Insert a new size-1 dimension at position 1 of the input.
    %0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 2, 3, 4, 5] : (tensor<1x64x64x64x32xbf16>) -> tensor<1x1x64x64x64x32xbf16>
    // 0x49800000 is the f32 bit pattern of 2^20 = 1048576 = 128*128*64, so
    // %1 = 1 / (number of elements in the 1x128x128x64 tensor built below).
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0x49800000> : tensor<f32>
    %1 = stablehlo.divide %cst, %cst_0 : tensor<f32>
    // Fill a 1x128x128x64 tensor with that scalar, cast it to bf16 and add a
    // leading size-1 dimension; this becomes the rhs of the convolution.
    %2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<f32>) -> tensor<1x128x128x64xf32>
    %3 = stablehlo.convert %2 : (tensor<1x128x128x64xf32>) -> tensor<1x128x128x64xbf16>
    %4 = stablehlo.broadcast_in_dim %3, dims = [1, 2, 3, 4] : (tensor<1x128x128x64xf32>) -> tensor<1x1x128x128x64xbf16>
    // Swap the two leading size-1 dimensions, then drop one of them.
    %5 = stablehlo.transpose %0, dims = [1, 0, 2, 3, 4, 5] : (tensor<1x1x64x64x64x32xbf16>) -> tensor<1x1x64x64x64x32xbf16>
    %6 = stablehlo.reshape %5 : (tensor<1x1x64x64x64x32xbf16>) -> tensor<1x64x64x64x32xbf16>
    // The convolution the issue is about. Per dim_numbers the lhs has
    // feature = 1, batch = 64 and spatial dims 64x64x32 (dilated by 2, padded
    // [1, 2] per axis); the rhs window spans 128x128x64, which yields a 3x3x3
    // spatial result.
    %7 = stablehlo.convolution(%6, %4) dim_numbers = [f, b, 0, 1, 2]x[i, o, 0, 1, 2]->[f, b, 0, 1, 2], window = {stride = [1, 1, 1], pad = [[1, 2], [1, 2], [1, 2]], lhs_dilate = [2, 2, 2], rhs_dilate = [1, 1, 1], reverse = [false, false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x64x64x64x32xbf16>, tensor<1x1x128x128x64xbf16>) -> tensor<1x64x3x3x3xbf16>
    // Re-add a leading size-1 dimension and swap the two leading dimensions
    // to produce the returned 1x1x64x3x3x3 layout.
    %8 = stablehlo.reshape %7 : (tensor<1x64x3x3x3xbf16>) -> tensor<1x1x64x3x3x3xbf16>
    %9 = stablehlo.transpose %8, dims = [1, 0, 2, 3, 4, 5] : (tensor<1x1x64x3x3x3xbf16>) -> tensor<1x1x64x3x3x3xbf16>
    // result[0] is a constant zero.
    %c = stablehlo.constant dense<0> : tensor<i32>
    return %c, %9 : tensor<i32>, tensor<1x1x64x3x3x3xbf16>
  }
}
Reproduction Script
import time
import equinox as eqx
import jax
import jax.numpy as jnp

# Problem size. These constants drive both the layer definition and the input
# tensor so the two cannot drift apart.
in_f = 64  # input channels
ou_f = 1  # output channels
dtype = jnp.bfloat16
# Single PRNG key for the script; same seed the layer was initialised with.
key = jax.random.key(0)

# Stride-2 3-D transposed convolution with a 3x3x3 kernel and no bias.
model = eqx.nn.ConvTranspose(
    num_spatial_dims=3,
    in_channels=in_f,
    out_channels=ou_f,
    kernel_size=(3, 3, 3),
    stride=(2, 2, 2),
    padding=((1, 1), (1, 1), (1, 1)),
    output_padding=(1, 1, 1),
    dilation=(1, 1, 1),
    groups=1,
    use_bias=False,
    padding_mode="ZEROS",
    key=key,
    dtype=dtype,
)

# All-zero input laid out as (batch, channels, *spatial).
batch_size = 1
spatial_shape = (64, 64, 32)
inp = jnp.zeros((batch_size, in_f, *spatial_shape), dtype=dtype)

def loss_fn(model, x):
    """Scalar objective: the mean over every element of ``model(x)``."""
    prediction = model(x)
    return prediction.mean()

# Gradient of the loss with respect to the model (first argument of loss_fn).
grad_fn = eqx.filter_grad(loss_fn)
# Map over the leading axis of the input only (the model is not batched),
# then JIT the whole thing.
grad_fn_batched = eqx.filter_jit(eqx.filter_vmap(grad_fn, in_axes=(None, 0)))

# Lower without executing and print the module text used in the bug report.
lowered = grad_fn_batched.lower(model, inp)
print(lowered.as_text())

# Compile ahead of time so compilation cost stays out of the timed region.
compiled = lowered.compile()

# Warm-up call; block on every output leaf so the work has really finished.
out = compiled(model, inp)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)

def bench(f, iters=20):
    """Return the mean wall-clock time of one ``f()`` call, in milliseconds.

    Args:
        f: Zero-argument callable. Its result may be any pytree; leaves that
            are not JAX arrays are ignored when waiting for completion.
        iters: Number of timed calls. Must be positive.

    Returns:
        Elapsed time divided by ``iters``, in milliseconds.

    Raises:
        ValueError: If ``iters`` is not positive (the average would be
            undefined or negative).
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")
    t0 = time.perf_counter()
    for _ in range(iters):
        # JAX dispatches work asynchronously, so wait for the whole result
        # pytree inside the timed region; otherwise only dispatch is measured.
        jax.block_until_ready(f())
    t1 = time.perf_counter()
    return (t1 - t0) * 1000.0 / iters

# Time the pre-compiled executable and report the mean milliseconds per call.
ms = bench(lambda: compiled(model, inp))
print({"stablehlo_execution_ms_per_iter": round(ms, 3)})

And the PyTorch variant:

import time, copy, torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# Layer geometry: 64 -> 1 channels, cubic 3x3x3 kernel, stride 2 on every axis.
in_f, ou_f = 64, 1
kernel_size = (3,) * 3
stride = (2,) * 3
padding = (1,) * 3
output_padding = (1,) * 3
dilation = (1,) * 3
groups = 1

# Input geometry.
batch_size = 1
spatial_shape = (64, 64, 32)

torch.manual_seed(0)

# Use bfloat16 on the GPU when one is available; otherwise fall back to
# float32 on the CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
else:
    device, dtype = "cpu", torch.float32

def make_model():
    """Build the ``ConvTranspose3d`` layer under test on ``device``/``dtype``.

    The layer geometry comes from the module-level configuration constants.
    The bias is disabled so this layer matches the JAX reproduction, which
    uses ``use_bias=False``; with a bias the two scripts would not time the
    same computation.

    Returns:
        A freshly initialised ``nn.ConvTranspose3d`` moved to the selected
        device and cast to the selected dtype.
    """
    return nn.ConvTranspose3d(
        in_channels=in_f,
        out_channels=ou_f,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
        bias=False,  # like-for-like with the JAX script (use_bias=False)
    ).to(device=device, dtype=dtype)

# Two models with identical weights: one run eagerly, one wrapped by
# torch.compile (deep-copied so the two do not share parameters).
model_eager = make_model()
model_compiled = torch.compile(copy.deepcopy(model_eager))

# All-zero input laid out as (batch, channels, *spatial).
inp = torch.zeros((batch_size, in_f, *spatial_shape), device=device, dtype=dtype)

# Smoke run: one forward/backward pass through the compiled model, then print
# the loss value and the shape of every parameter gradient.
for p in model_compiled.parameters():
    p.grad = None
y = model_compiled(inp).mean()
y.backward()
print(y.item())
print([tuple(p.grad.shape) for p in model_compiled.parameters()])

# Profile a single forward/backward step of the compiled model and export the
# result as a Chrome trace. CPU activity is always recorded; CUDA activity is
# added only when running on the GPU.
trace_path = "/tmp/torch-trace.json"
with profile(
    activities=[ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if device == "cuda" else []),
    with_stack=True,
    record_shapes=True,
    profile_memory=True,
) as prof:
    # Clear old gradients so the profiled step starts from a clean state.
    for p in model_compiled.parameters():
        p.grad = None
    _y = model_compiled(inp).mean()
    _y.backward()
prof.export_chrome_trace(trace_path)
print(trace_path)

def bench(model, iters=20, warmup=5):
    """Return the mean time of one forward/backward step, in milliseconds.

    Runs ``warmup`` untimed steps followed by ``iters`` timed steps on the
    module-level ``inp``. When ``device`` is ``"cuda"`` the GPU is
    synchronised before warm-up, before the clock starts and before it stops.

    Args:
        model: Module whose parameters receive gradients.
        iters: Number of timed steps.
        warmup: Number of untimed steps run first.

    Returns:
        Elapsed wall-clock time divided by ``iters``, in milliseconds.
    """
    on_gpu = device == "cuda"

    def sync():
        # Only meaningful on the GPU; a no-op on the CPU path.
        if on_gpu:
            torch.cuda.synchronize()

    def step():
        # Drop old gradients, then run one forward and one backward pass.
        for param in model.parameters():
            param.grad = None
        model(inp).mean().backward()

    sync()
    for _ in range(warmup):
        step()
    sync()

    start = time.perf_counter()
    for _ in range(iters):
        step()
    sync()
    stop = time.perf_counter()
    return (stop - start) * 1000.0 / iters

# Time the eager and the compiled model with the same harness and report the
# mean milliseconds per forward/backward step for each.
eager_ms = bench(model_eager)
compiled_ms = bench(model_compiled)
print({"eager_ms_per_iter": round(eager_ms, 3), "compiled_ms_per_iter": round(compiled_ms, 3)})

Metadata

Metadata

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions