The following code in the iris/x/ kernels produces incorrect results for many datatypes. For example, fp16 gets accumulated into int32, leading to large RMSE in the output:
# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
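The condition looks inverted: presumably the intent was an int32 accumulator for int8 and a float32 accumulator for everything else. As a rough illustration of why this matters, here is a NumPy emulation (not the actual Triton kernel) that casts each partial product to the accumulator dtype before summing. `pick_acc_dtype`, `matmul_with_acc`, and `rmse` are hypothetical helpers written for this sketch, and the real kernel may fail in a different way, but the effect of an integer accumulator on fp16 data is similar.

```python
import numpy as np


def pick_acc_dtype(out_dtype):
    """Hypothetical fix: int32 accumulator only for int8, float32 otherwise."""
    return np.int32 if np.dtype(out_dtype) == np.int8 else np.float32


def matmul_with_acc(a, b, acc_dtype):
    """Emulate a K-loop that casts every partial product to acc_dtype."""
    acc = np.zeros((a.shape[0], b.shape[1]), dtype=acc_dtype)
    for k in range(a.shape[1]):
        partial = np.outer(a[:, k].astype(np.float32), b[k, :].astype(np.float32))
        # For an integer accumulator this cast truncates the fractional part.
        acc += partial.astype(acc_dtype)
    return acc.astype(np.float32)


def rmse(x, ref):
    diff = x.astype(np.float64) - ref
    return float(np.sqrt(np.mean(diff ** 2)))


rng = np.random.default_rng(0)
a = rng.uniform(-1.0, 1.0, size=(16, 32)).astype(np.float16)
b = rng.uniform(-1.0, 1.0, size=(32, 16)).astype(np.float16)
ref = a.astype(np.float64) @ b.astype(np.float64)

# Current behaviour for fp16 (emulated): int32 accumulator.
rmse_bad = rmse(matmul_with_acc(a, b, np.int32), ref)
# Behaviour with the flipped condition: float32 accumulator.
rmse_good = rmse(matmul_with_acc(a, b, pick_acc_dtype(np.float16)), ref)

print(f"int32 accumulator RMSE:   {rmse_bad:.6f}")
print(f"float32 accumulator RMSE: {rmse_good:.6f}")
```

If that reading is right, the fix in the kernel would be to flip the comparison (`==` instead of `!=`), though other output types such as fp64 may need their own handling.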