[CINN] Fixed longlong2int induced minmax type mismatch #74240
+221
−4
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
PR Category
CINN
PR Types
Bug fixes
Description
longlong2int 是一个 int类型计算的 optimization pass,包括了:
OpDataTypePromote
(显式在CodeGenAndJit
中调用)TryElevateIntxxtoIntxx
(32->64或者反过来都有):这个函数在很多类型、函数中都隐式调用了。包括 ir::BinaryNode 构造时特别是第二点,导致type cast问题涉及的范围很广(甚至在 IrPrinter 生成代码的环节都能出现 type cast,比如 load/store node调用 index 方法时)。问题的主要表现是:当被编译操作中含有
ir::Min
/ir::Max
两种 node 时,可能会因为动态 shape 的存在而导致 operand type 不匹配。究其原因:ir::LoweredFunc
)的ir::Argument
与在 kernel 内部的动态 shape 表达式(ir::_Var_
)不是强关联的,导致同时有 int32 / int64 版本ir::Argument
的 LoweredFunc,但 func body 内部类型并没有正确设置。具体表现如下:这一问题的解决,可以说目前我的方案还比较暴力(但可能很难找到更简单的方法了):
ir::Min
/ir::Max
node 时,对于左右operands进行 ir-visiting:当左右operand都是int类型输入,并且都包含ir::Argument中已经记录的symbol名称时,确定 最大的 int类型 bit 数(32/64)。举个例子:
cinn_min((S0 + 1), 1024ll)
,ir::Argument
中的S0
是 int32类型的ir::_Var_
,其left operand 包含动态shape symbol,并且查哈希表知,最大 bit 是32,而 right operand 不包含动态 shape symbol,故为0,则这个 min node 应该被整体 cast 为 int32 类型的(否则在 codegen 带入时,S0
是 int,左边整体是int)。之所以说目前没有更简单的方法,就是因为 longlong2int 相关的改动设计的API太多了,从 PostProcess 到 CodeGen 的 string code printing 过程都有int类型的相互变换。
感兴趣可以测试一下本 PR 提供的单测:
test_unifying_minmax_type.py
。本 PR 前此单测没有一个可以通过,并且原始的 bug 也影响了一些 CINN 算子的支持(比如gather_nd
我尝试将min/max 进行简单的 type unify,统一到int64,引入了一定的性能问题,见 #73940),后续会对对应算子支持进行进一步简化。Pcard-89620