Error in promotion of data type for accumulation #299

@aamarnat

Description

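This piece of code in the iris/x/ kernels produces incorrect results for many data types. E.g., fp16 gets accumulated into int32, leading to large RMSE errors in the output. The condition appears to be inverted: it selects int32 for every output type except int8.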

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32

E.g., used here:
https://github.com/ROCm/iris/blob/1a09fae0572c2e2484f5abb0c09214a02ed4500b/iris/x/all_gather_gemm.py#L119C1-L119C73
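
A possible direction for a fix, sketched in plain Python rather than Triton so it can be run anywhere. It assumes the intended rule is "integer accumulator only for int8, float32 otherwise", which in the kernel would presumably mean swapping `!=` for `==`. The helper names `select_acc_dtype` and `dot` are hypothetical and are not part of iris; the `dot` function only emulates integer accumulation by truncating each product, which is an illustration and not a claim about what the hardware does.

```python
def select_acc_dtype(out_dtype: str) -> str:
    """Hypothetical helper: int32 accumulator only for int8, float32 otherwise."""
    return "int32" if out_dtype == "int8" else "float32"


def dot(a, b, acc_dtype: str):
    """Dot product that emulates the accumulator type in plain Python."""
    acc = 0 if acc_dtype == "int32" else 0.0
    for x, y in zip(a, b):
        prod = x * y
        # An integer accumulator cannot hold the fractional part of a product,
        # so it is dropped here to mimic the precision loss.
        acc += int(prod) if acc_dtype == "int32" else prod
    return acc


# Values chosen to be exactly representable in fp16.
a = [0.5, 0.25, 0.75, 0.5]
b = [0.5, 0.5, 0.5, 0.5]

wrong = dot(a, b, "int32")                      # current behaviour for fp16
right = dot(a, b, select_acc_dtype("float16"))  # proposed behaviour
print(wrong, right)
```

With these inputs the integer accumulator loses every contribution, while the float accumulator keeps the full sum, which is consistent with the large RMSE reported above.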

Labels: iris (Iris project issue)
