Description
Describe the issue
The Local Response Normalization (LRN) operator on the CPUExecutionProvider shows a major performance regression starting in onnxruntime v1.21.0.
- Latency remained consistent (~7–8 ms) up to v1.20.0.
- From v1.21.0 onward, latency degraded to ~29 ms (~3.8× slowdown).
- The regression persists through the latest release (v1.23.0).
- Culprit commit (via bisect): 7c0c6fbe (Eigen dependency update).
- Because LRN relies on channel-wise reductions and exponentiation, it is sensitive to changes in vectorization and reduction paths (a reference sketch of the computation follows this list).
- The Eigen update likely affected these optimizations, leading to the observed slowdown.
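For reference, this is what the operator computes per the ONNX LRN specification. The sketch below is a minimal NumPy re-statement for illustration only (the helper name `lrn_reference` is ours); it is not the ORT kernel, but it shows the channel-wise square-sum reduction and the power the bullets above refer to.

```python
import numpy as np

def lrn_reference(x, size=5, alpha=1e-4, beta=0.75, bias=1.0):
    # Reference LRN over an (N, C, H, W) tensor, following the ONNX operator spec.
    _, c, _, _ = x.shape
    sq = x.astype(np.float64) ** 2
    y = np.empty(x.shape, dtype=np.float64)
    lo_half = (size - 1) // 2        # channels included below the center channel
    hi_half = size - 1 - lo_half     # channels included above the center channel
    for ci in range(c):
        lo = max(0, ci - lo_half)
        hi = min(c, ci + hi_half + 1)
        square_sum = sq[:, lo:hi].sum(axis=1)                              # channel-wise reduction
        y[:, ci] = x[:, ci] / (bias + (alpha / size) * square_sum) ** beta # exponentiation
    return y.astype(x.dtype)
```

Its output should match the Y produced by the repro model below up to floating-point tolerance.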
Additional Observation
LRN operator latency by version (individual-operator timing):

| Version | LRN (ms) |
| --- | --- |
| onnxruntime-v1.18.0 | 8.00 |
| onnxruntime-v1.19.0 | 7.77 |
| onnxruntime-v1.20.0 | 7.59 |
| onnxruntime-v1.21.0 | 29.14 |
| onnxruntime-v1.22.0 | 29.18 |
| onnxruntime-v1.23.0 | 29.13 |
Interestingly, the regression depends on the beta parameter (a minimal sweep sketch follows the table below):
- With the typical setting beta=0.75, latency regresses heavily (v1.20.0 median ~7.3 ms → v1.21.0 median ~29 ms).
- However, with beta=1.0, latency actually improves in v1.21.0+ (v1.20.0 median ~1.76 ms → v1.21.0 median ~0.49 ms).
| Version | LRN (ms) |
| --- | --- |
| onnxruntime-v1.20.0 | 0.490 |
| onnxruntime-v1.21.0 | 1.763 |
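To check the beta dependence directly, the full repro script below can be reduced to a small sweep. This is a sketch: the `time_lrn` helper, the fixed shape, and the thread/run counts are ours, and pinning `ir_version` mirrors the fallback logic in the full script.

```python
import time
import numpy as np
import onnxruntime as ort
from onnx import helper, TensorProto

def time_lrn(beta, runs=500, warmup=50):
    # Build a single-node LRN model with the given beta.
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 64, 56, 56])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 64, 56, 56])
    node = helper.make_node("LRN", ["X"], ["Y"], size=5, alpha=1e-4, beta=beta, bias=1.0)
    model = helper.make_model(helper.make_graph([node], "lrn_g", [X], [Y]),
                              opset_imports=[helper.make_opsetid("", 13)])
    model.ir_version = 10  # pin IR for older ORT builds, as in the full repro below

    so = ort.SessionOptions()
    so.intra_op_num_threads = 1
    sess = ort.InferenceSession(model.SerializeToString(), sess_options=so,
                                providers=["CPUExecutionProvider"])

    x = np.random.default_rng(0).standard_normal((1, 64, 56, 56), dtype=np.float32)
    for _ in range(warmup):
        sess.run(None, {"X": x})

    ts = []
    for _ in range(runs):
        t0 = time.perf_counter()
        sess.run(None, {"X": x})
        ts.append((time.perf_counter() - t0) * 1000.0)
    ts.sort()
    return ts[len(ts) // 2]  # median latency in ms

for b in (0.75, 1.0):
    print(f"beta={b}: median {time_lrn(b):.3f} ms")
```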
We performed additional analysis to further investigate this issue.

Observation (with source line citations):
In commit 1442fe0, the LRN kernel computes the exponent parameter as `f.b = -beta_` (see L116). Later, the output is formed via the Eigen expression `ym = a2 * a.pow(b);` (see L48), where `b` is the exponent used by Eigen's tensor math. This change in how the exponent is passed and applied (`f.b = -beta_`, later consumed by `a.pow(b)`) likely alters Eigen's internal evaluation/optimization path, leading to the observed performance difference.
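Written out with the LRN definition above, and assuming our reading of `a2`, `a`, and `b` in the kernel is correct, the per-element expression is roughly

$$
Y = X \cdot \Big(\mathrm{bias} + \tfrac{\alpha}{\mathrm{size}} \sum_j X_j^2\Big)^{-\beta},
$$

with the sum running over the neighboring channels in the window. Eigen's `pow` therefore receives the exponent $b = -\beta$: exactly $-1$ (a reciprocal) for $\beta = 1.0$, but a general non-integer power for $\beta = 0.75$. If the updated Eigen treats these two exponents differently, that would be consistent with the beta-dependent behaviour reported above.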
To reproduce
Running the script below under different ONNX Runtime versions (e.g., v1.20.0, v1.21.0, and v1.23.0) will show the performance difference. The regression consistently appears from v1.21.0 onwards.
```python
#!/usr/bin/env python3
import time
import numpy as np
import onnx
import onnxruntime as ort
from onnx import helper, TensorProto
from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument

# ---------------- Model configuration ----------------
N, C, H, W = 1, 64, 56, 56
LRN_SIZE = 5
LRN_ALPHA = 1e-4
LRN_BETA = 0.75
LRN_BIAS = 1.0
OPSET = 13
WARMUP = 200
RUNS = 1000
THREADS = 1


def make_lrn_model_bytes(ir_version: int):
    X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [N, C, H, W])
    Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [N, C, H, W])
    node = helper.make_node("LRN", ["X"], ["Y"],
                            size=LRN_SIZE, alpha=LRN_ALPHA, beta=LRN_BETA, bias=LRN_BIAS)
    graph = helper.make_graph([node], "lrn_graph", [X], [Y])
    model = helper.make_model(
        graph,
        opset_imports=[helper.make_opsetid("", OPSET)],
        producer_name="lrn-repro",
        ir_version=ir_version,
    )
    onnx.checker.check_model(model)
    return model.SerializeToString()


def build_session_with_fallback():
    so = ort.SessionOptions()
    so.intra_op_num_threads = THREADS
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    for ir in (10, 9):
        try:
            model_bytes = make_lrn_model_bytes(ir)
            return ort.InferenceSession(model_bytes, sess_options=so,
                                        providers=["CPUExecutionProvider"]), ir
        except InvalidArgument as e:
            last_err = e
            continue
    raise last_err


def benchmark():
    sess, ir_used = build_session_with_fallback()
    rng = np.random.default_rng(0)
    x = rng.standard_normal((N, C, H, W), dtype=np.float32)
    # Warmup
    for _ in range(WARMUP):
        sess.run(None, {"X": x})
    # Timed runs
    ts = []
    for _ in range(RUNS):
        t0 = time.perf_counter()
        sess.run(None, {"X": x})
        ts.append((time.perf_counter() - t0) * 1000.0)
    ts.sort()
    median = ts[len(ts) // 2]
    p90 = ts[int(len(ts) * 0.9)]
    p99 = ts[int(len(ts) * 0.99)]
    print("Providers:", sess.get_providers())
    print(f"IR version used: {ir_used}")
    print(f"Input: (N,C,H,W)=({N},{C},{H},{W}) | LRN(size={LRN_SIZE}, alpha={LRN_ALPHA}, beta={LRN_BETA}, bias={LRN_BIAS})")
    print(f"Threads={THREADS} | Runs={RUNS} | Warmup={WARMUP}")
    print(f"Latency (ms): median={median:.3f}, p90={p90:.3f}, p99={p99:.3f}")


if __name__ == "__main__":
    benchmark()
```
Results

```
(venv_onnxruntime-Release-v1.20.0) root@kayle:/app/regression_experiment# python lrn.py
Providers: ['CPUExecutionProvider']
IR version used: 10
Input: (N,C,H,W)=(1,64,56,56) | LRN(size=5, alpha=0.0001, beta=0.75, bias=1.0)
Threads=1 | Runs=1000 | Warmup=200
Latency (ms): median=1.763, p90=1.768, p99=1.831

(venv_onnxruntime-Release-v1.21.0) root@kayle:/app/regression_experiment# python lrn.py
Providers: ['CPUExecutionProvider']
IR version used: 10
Input: (N,C,H,W)=(1,64,56,56) | LRN(size=5, alpha=0.0001, beta=0.75, bias=1.0)
Threads=1 | Runs=1000 | Warmup=200
Latency (ms): median=7.361, p90=7.397, p99=7.481
```
Urgency
No response
Platform
Linux
OS Version
Ubuntu 24.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
v1.20.0, 1442fe0, v1.21.0
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response
Model File
No response
Is this a quantized model?
Yes