Skip to content

[CINN] Fixed longlong2int induced minmax type mismatch #74240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from

Conversation

Enigmatisms
Copy link
Contributor

@Enigmatisms Enigmatisms commented Jul 25, 2025

PR Category

CINN

PR Types

Bug fixes

Description

longlong2int 是一个 int类型计算的 optimization pass,包括了:

  • OpDataTypePromote (显式在 CodeGenAndJit中调用)
  • TryElevateIntxxtoIntxx(32->64或者反过来都有):这个函数在很多类型、函数中都隐式调用了。包括 ir::BinaryNode 构造时

特别是第二点,导致type cast问题涉及的范围很广(甚至在 IrPrinter 生成代码的环节都能出现 type cast,比如 load/store node调用 index 方法时)。问题的主要表现是:当被编译操作中含有 ir::Min / ir::Max 两种 node 时,可能会因为动态 shape 的存在而导致 operand type 不匹配。究其原因:

  • CINN 根据 predicate 对动态shape同一个 kernel 编译几个版本:比如 shape 大时,动态 shape 输入将会使用 int64,反之用 int32,kernel 在 host 端通过 branches 进行选择。
  • 输入参数的动态 shape 表达式(ir::LoweredFunc)的 ir::Argument 与在 kernel 内部的动态 shape 表达式(ir::_Var_)不是强关联的,导致同时有 int32 / int64 版本ir::Argument 的 LoweredFunc,但 func body 内部类型并没有正确设置。具体表现如下:
// func body IR: i32 version
var[some_index] = cinn_min(S0, 0);
// func body IR: i64 version
var[some_index] = cinn_min(S0, 0ll);
// CUDA code:
// 32bit ir::Argument
__global__ foo_predicate_le_int_max_kernel(int S0) {
    var[threadIdx.x] = cinn_min(S0, 0ll);    // 报错!
}
// 64bit ir::Argument
__global__ foo_predicate_gt_int_max_kernel(int64_t S0) {
    var[threadIdx.x] = cinn_min(S0, 0ll);    // 不报错
}

这一问题的解决,可以说目前我的方案还比较暴力(但可能很难找到更简单的方法了):

  • codegen 开始阶段(GPU only),通过哈希表记录变量名到动态shape symbol 类型的关系
  • 在codegen阶段,IrPrinter 打印 ir::Min / ir::Max node 时,对于左右operands进行 ir-visiting:当左右operand都是int类型输入,并且都包含ir::Argument中已经记录的symbol名称时,确定 最大的 int类型 bit 数(32/64)。
  • 根据最大bit数,将左右 operand 进行 cast,IrPrinter print cast 后(unified type)的 operands。

举个例子:cinn_min((S0 + 1), 1024ll)ir::Argument 中的 S0 是 int32类型的 ir::_Var_,其left operand 包含动态shape symbol,并且查哈希表知,最大 bit 是32,而 right operand 不包含动态 shape symbol,故为0,则这个 min node 应该被整体 cast 为 int32 类型的(否则在 codegen 带入时,S0 是 int,左边整体是int)。

之所以说目前没有更简单的方法,就是因为 longlong2int 相关的改动设计的API太多了,从 PostProcess 到 CodeGen 的 string code printing 过程都有int类型的相互变换。

感兴趣可以测试一下本 PR 提供的单测:test_unifying_minmax_type.py。本 PR 前此单测没有一个可以通过,并且原始的 bug 也影响了一些 CINN 算子的支持(比如 gather_nd 我尝试将min/max 进行简单的 type unify,统一到int64,引入了一定的性能问题,见 #73940),后续会对对应算子支持进行进一步简化。

Pcard-89620

Copy link

paddle-bot bot commented Jul 25, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant