-
Notifications
You must be signed in to change notification settings - Fork 696
Open
Labels
err:performancePerformance issuesPerformance issues
Description
Hey, this is a cross-reference to this issue from the JAX repo.
The issue here seems to be that XLA chooses a suboptimal convolution compared to PyTorch, leading to a > 30x slowdown.
For PyTorch I get
{'eager_ms_per_iter': 0.464, 'compiled_ms_per_iter': 0.538}
And with this HLO I get:
{'stablehlo_execution_ms_per_iter': 41.291}
The HLO to reproduce is:
// Lowered module for the jitted gradient computation; it returns a constant i32 and the weight gradient.
module @jit_loss_fn attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
// %arg0 is the bf16 input (1x64x64x64x32); result[1].weight is the 1x1x64x3x3x3 bf16 gradient.
func.func public @main(%arg0: tensor<1x64x64x64x32xbf16>) -> (tensor<i32> {jax.result_info = "result[0]"}, tensor<1x1x64x3x3x3xbf16> {jax.result_info = "result[1].weight"}) {
// Insert a unit dimension at position 1 of the input.
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 2, 3, 4, 5] : (tensor<1x64x64x64x32xbf16>) -> tensor<1x1x64x64x64x32xbf16>
// 0x49800000 is the f32 bit pattern of 2^20 = 1048576 = 128*128*64, so %1 = 1/1048576.
// NOTE(review): presumably the cotangent of a mean over the 1x128x128x64 output - confirm.
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%cst_0 = stablehlo.constant dense<0x49800000> : tensor<f32>
%1 = stablehlo.divide %cst, %cst_0 : tensor<f32>
// Broadcast that scalar to 1x128x128x64, convert it to bf16, then add a leading unit dimension.
%2 = stablehlo.broadcast_in_dim %1, dims = [] : (tensor<f32>) -> tensor<1x128x128x64xf32>
%3 = stablehlo.convert %2 : (tensor<1x128x128x64xf32>) -> tensor<1x128x128x64xbf16>
%4 = stablehlo.broadcast_in_dim %3, dims = [1, 2, 3, 4] : (tensor<1x128x128x64xbf16>) -> tensor<1x1x128x128x64xbf16>
// Swap the two leading unit dimensions of %0, then drop one of them.
%5 = stablehlo.transpose %0, dims = [1, 0, 2, 3, 4, 5] : (tensor<1x1x64x64x64x32xbf16>) -> tensor<1x1x64x64x64x32xbf16>
%6 = stablehlo.reshape %5 : (tensor<1x1x64x64x64x32xbf16>) -> tensor<1x64x64x64x32xbf16>
// The convolution under discussion. With lhs dim_numbers [f, b, 0, 1, 2] the size-64 dimension of %6 is the batch
// dimension, and the rhs %4 supplies a 128x128x64 spatial window. The lhs is dilated by 2 and padded [1, 2] in each
// spatial dimension, producing a 3x3x3 spatial output.
%7 = stablehlo.convolution(%6, %4) dim_numbers = [f, b, 0, 1, 2]x[i, o, 0, 1, 2]->[f, b, 0, 1, 2], window = {stride = [1, 1, 1], pad = [[1, 2], [1, 2], [1, 2]], lhs_dilate = [2, 2, 2], rhs_dilate = [1, 1, 1], reverse = [false, false, false]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<1x64x64x64x32xbf16>, tensor<1x1x128x128x64xbf16>) -> tensor<1x64x3x3x3xbf16>
// Re-insert the extra unit dimension and swap the two leading unit dimensions back.
%8 = stablehlo.reshape %7 : (tensor<1x64x3x3x3xbf16>) -> tensor<1x1x64x3x3x3xbf16>
%9 = stablehlo.transpose %8, dims = [1, 0, 2, 3, 4, 5] : (tensor<1x1x64x3x3x3xbf16>) -> tensor<1x1x64x3x3x3xbf16>
// result[0] is the constant 0.
%c = stablehlo.constant dense<0> : tensor<i32>
return %c, %9 : tensor<i32>, tensor<1x1x64x3x3x3xbf16>
}
}
Reproduction Script
import time
import equinox as eqx
import jax
import jax.numpy as jnp
# Problem size for the reproduction: a 3-D transposed convolution from
# ``in_f`` input channels to ``ou_f`` output channels, evaluated in bfloat16.
in_f = 64
ou_f = 1
dtype = jnp.bfloat16
# Single PRNG key for parameter initialisation (the same key the model was
# previously given inline).
key = jax.random.key(0)

# The channel counts come from the constants above so the model and the input
# tensor below cannot drift apart.
model = eqx.nn.ConvTranspose(
    num_spatial_dims=3,
    in_channels=in_f,
    out_channels=ou_f,
    kernel_size=(3, 3, 3),
    stride=(2, 2, 2),
    padding=((1, 1), (1, 1), (1, 1)),
    output_padding=(1, 1, 1),
    dilation=(1, 1, 1),
    groups=1,
    use_bias=False,
    padding_mode="ZEROS",
    key=key,
    dtype=dtype,
)

batch_size = 1
spatial_shape = (64, 64, 32)
# All-zero input of shape (batch, channels, *spatial).
inp = jnp.zeros((batch_size, in_f, *spatial_shape), dtype=dtype)
def loss_fn(model, x):
    """Return the mean of ``model(x)`` as the loss.

    Args:
        model: Callable applied to ``x``; its result must provide ``.mean()``.
        x: Input forwarded to ``model`` unchanged.

    Returns:
        Whatever ``model(x).mean()`` evaluates to.
    """
    return model(x).mean()
# Build the gradient function of loss_fn.
grad_fn = eqx.filter_grad(loss_fn)
# vmap over axis 0 of the input only (the model is passed with in_axes=None), then JIT the result.
grad_fn_batched = eqx.filter_jit(eqx.filter_vmap(grad_fn, in_axes=(None, 0)))
# Lower first so the module text can be printed before compiling.
lowered = grad_fn_batched.lower(model, inp)
print(lowered.as_text())
compiled = lowered.compile()
# Run once and wait on every output leaf before any timing is taken.
out = compiled(model, inp)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), out)
def bench(f, iters=20):
    """Return the mean wall-clock time of ``f()`` in milliseconds.

    Args:
        f: Zero-argument callable returning a pytree whose leaves expose
            ``block_until_ready()``.
        iters: Number of timed calls; must be at least 1.

    Returns:
        Elapsed milliseconds per call, averaged over ``iters`` calls.

    Raises:
        ValueError: If ``iters`` is less than 1.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")

    t0 = time.perf_counter()
    for _ in range(iters):
        y = f()
        # Wait on every leaf inside the loop so each call has fully finished
        # before the next one starts and before the clock is read.
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), y)
    t1 = time.perf_counter()
    return (t1 - t0) * 1000.0 / iters
# Time the already-compiled executable on the same model and input used for the warm-up call.
ms = bench(lambda: compiled(model, inp))
print({"stablehlo_execution_ms_per_iter": round(ms, 3)})
And the PyTorch variant:
import time, copy, torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity
# Layer hyper-parameters: 64 input channels, one output channel, and the same
# value along each of the three spatial axes for every per-axis setting.
in_f, ou_f = 64, 1
kernel_size = (3,) * 3
stride = (2,) * 3
padding = (1,) * 3
output_padding = (1,) * 3
dilation = (1,) * 3
groups = 1

# Input geometry.
batch_size = 1
spatial_shape = (64, 64, 32)

torch.manual_seed(0)

# Use bfloat16 on CUDA when it is available; otherwise fall back to float32 on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.bfloat16
else:
    device, dtype = "cpu", torch.float32
def make_model(bias=True):
    """Build the ``nn.ConvTranspose3d`` layer under test.

    The layer is configured from the module-level hyper-parameters and moved
    to the module-level ``device`` and ``dtype``.

    Args:
        bias: Whether the layer has a learnable bias. Defaults to ``True``,
            which is what this benchmark has always used. The JAX
            reproduction builds its layer with ``use_bias=False``; pass
            ``bias=False`` here for a like-for-like comparison.

    Returns:
        The constructed ``nn.ConvTranspose3d`` module.
    """
    return nn.ConvTranspose3d(
        in_channels=in_f,
        out_channels=ou_f,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    ).to(device=device, dtype=dtype)
model_eager = make_model()
# Compile a deep copy so the eager and compiled modules do not share parameters.
model_compiled = torch.compile(copy.deepcopy(model_eager))
inp = torch.zeros((batch_size, in_f, *spatial_shape), device=device, dtype=dtype)

# One forward/backward pass up front: prints the loss value and the shapes of
# the parameter gradients that were produced.
for p in model_compiled.parameters():
    p.grad = None
y = model_compiled(inp).mean()
y.backward()
print(y.item())
print([tuple(p.grad.shape) for p in model_compiled.parameters()])

# Record a single forward/backward step with the profiler; CUDA activity is
# captured only when the benchmark is actually running on a GPU.
trace_path = "/tmp/torch-trace.json"
with profile(
    activities=[ProfilerActivity.CPU] + ([ProfilerActivity.CUDA] if device == "cuda" else []),
    with_stack=True,
    record_shapes=True,
    profile_memory=True,
) as prof:
    for p in model_compiled.parameters():
        p.grad = None
    _y = model_compiled(inp).mean()
    _y.backward()
# Export after the profiler context has closed.
prof.export_chrome_trace(trace_path)
print(trace_path)
def bench(model, iters=20, warmup=5):
    """Return the mean forward+backward time of ``model`` in milliseconds.

    Each step clears the parameter gradients, evaluates
    ``model(inp).mean()`` on the module-level ``inp`` and back-propagates it.

    Args:
        model: Module to benchmark; called as ``model(inp)``.
        iters: Number of timed steps; must be at least 1.
        warmup: Number of untimed steps executed before timing starts.

    Returns:
        Wall-clock milliseconds per timed step.

    Raises:
        ValueError: If ``iters`` is less than 1.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")

    if device == "cuda":
        torch.cuda.synchronize()
    for _ in range(warmup):
        for p in model.parameters():
            p.grad = None
        loss = model(inp).mean()
        loss.backward()
    # Synchronise so pending warm-up work is not charged to the timed region.
    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        for p in model.parameters():
            p.grad = None
        loss = model(inp).mean()
        loss.backward()
    # Synchronise again so all timed GPU work has finished before reading the clock.
    if device == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    return (t1 - t0) * 1000.0 / iters
# Benchmark the eager and compiled modules with the default iteration and warm-up counts.
eager_ms = bench(model_eager)
compiled_ms = bench(model_compiled)
print({"eager_ms_per_iter": round(eager_ms, 3), "compiled_ms_per_iter": round(compiled_ms, 3)})
Metadata
Metadata
Assignees
Labels
err:performancePerformance issuesPerformance issues