Skip to content

Commit dfeb1ae

Browse files
thomasjpfansvekars
andauthored
Fixes backward definition in triton tutorial (#3282)
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 780d5cb commit dfeb1ae

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ def sin_triton(x):
257257
# Prefer this to using ``torch.autograd.Function`` (which has various composability footguns
258258
# with ``torch.compile``).
259259

260-
def backward(ctx, grad_output):
260+
def backward(ctx, grad):
261261
x, = ctx.saved_tensors
262-
return grad_input * x.cos()
262+
return grad * x.cos()
263263

264264
def setup_context(ctx, inputs, output):
265265
x, = inputs
@@ -293,9 +293,9 @@ def mycos(x: torch.Tensor) -> torch.Tensor:
293293
wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)
294294
return out
295295

296-
def backward(ctx, grad_output):
296+
def backward(ctx, grad):
297297
x, = ctx.saved_tensors
298-
return grad_input * mycos(x)
298+
return grad * mycos(x)
299299

300300
def setup_context(ctx, inputs, output):
301301
x, = inputs

0 commit comments

Comments
 (0)