Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] [ir] [cuda] Add clz instruction #8276

Merged
merged 11 commits into from
Oct 31, 2023
12 changes: 12 additions & 0 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,6 +1120,18 @@ def py_ifte(cond, x1, x2):
return _ternary_operation(_ti_core.expr_ifte, py_ifte, cond, x1, x2)


def clz(a):
"""Count the number of leading zeros for a 32bit integer"""

def _clz(x):
for i in range(32):
if 2**i > x:
return 32 - i
return 0

return _unary_operation(_ti_core.expr_clz, _clz, a)


@writeback_binary
def atomic_add(x, y):
"""Atomically compute `x + y`, store the result in `x`,
Expand Down
6 changes: 6 additions & 0 deletions python/taichi/math/mathimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,12 +812,18 @@ def popcnt(x):
return ops.popcnt(x)


@func
def clz(x):
return ops.clz(x)


__all__ = [
"acos",
"asin",
"atan2",
"ceil",
"clamp",
"clz",
"cos",
"cross",
"degrees",
Expand Down
9 changes: 9 additions & 0 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,15 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
} else {
TI_NOT_IMPLEMENTED
}
} else if (op == UnaryOpType::clz) {
if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
stmt->ret_type = PrimitiveType::i32;
llvm_val[stmt] = call("__nv_clz", input);
} else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) {
llvm_val[stmt] = call("__nv_clzll", input);
} else {
TI_NOT_IMPLEMENTED
}
} else if (op == UnaryOpType::log) {
if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
// logf has fast-math option
Expand Down
6 changes: 6 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) {
llvm_val[stmt] =
builder->CreateIntrinsic(llvm::Intrinsic::ctpop, {input_type}, {input});
}
else if (op == UnaryOpType::clz) {
llvm_val[stmt] = builder->CreateIntrinsic(
llvm::Intrinsic::ctlz, {input_type},
{input,
llvm::ConstantInt::get(llvm::Type::getInt1Ty(*llvm_context), 0)});
}
else {
TI_P(unary_op_type_name(op));
TI_NOT_IMPLEMENTED
Expand Down
6 changes: 6 additions & 0 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,12 @@ class TaskCodegen : public IRVisitor {
ir_->store_variable(val, v);
} else if (stmt->op_type == UnaryOpType::popcnt) {
val = ir_->popcnt(operand_val);
} else if (stmt->op_type == UnaryOpType::clz) {
uint32_t FindMSB_id = 74;
spirv::Value msb = ir_->call_glsl450(dst_type, FindMSB_id, operand_val);
spirv::Value bitcnt = ir_->int_immediate_number(ir_->i32_type(), 32);
spirv::Value one = ir_->int_immediate_number(ir_->i32_type(), 1);
val = ir_->sub(ir_->sub(bitcnt, msb), one);
}
#define UNARY_OP_TO_SPIRV(op, instruction, instruction_id, max_bits) \
else if (stmt->op_type == UnaryOpType::op) { \
Expand Down
1 change: 1 addition & 0 deletions taichi/inc/unary_op.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ PER_UNARY_OP(rcp)
PER_UNARY_OP(exp)
PER_UNARY_OP(log)
PER_UNARY_OP(popcnt)
PER_UNARY_OP(clz)
PER_UNARY_OP(rsqrt)
PER_UNARY_OP(bit_not)
PER_UNARY_OP(logic_not)
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/expression_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ DEFINE_EXPRESSION_FUNC_UNARY(rsqrt)
DEFINE_EXPRESSION_FUNC_UNARY(exp)
DEFINE_EXPRESSION_FUNC_UNARY(log)
DEFINE_EXPRESSION_FUNC_UNARY(popcnt)
DEFINE_EXPRESSION_FUNC_UNARY(clz)
DEFINE_EXPRESSION_FUNC_UNARY(logic_not)
DEFINE_EXPRESSION_OP_UNARY(~, bit_not)
DEFINE_EXPRESSION_OP_UNARY(-, neg)
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ UnaryOpStmt *IRBuilder::create_popcnt(Stmt *value) {
return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::popcnt, value));
}

UnaryOpStmt *IRBuilder::create_clz(Stmt *value) {
return insert(Stmt::make_typed<UnaryOpStmt>(UnaryOpType::clz, value));
}

BinaryOpStmt *IRBuilder::create_add(Stmt *l, Stmt *r) {
return insert(Stmt::make_typed<BinaryOpStmt>(BinaryOpType::add, l, r));
}
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ class IRBuilder {
UnaryOpStmt *create_exp(Stmt *value);
UnaryOpStmt *create_log(Stmt *value);
UnaryOpStmt *create_popcnt(Stmt *value);
UnaryOpStmt *create_clz(Stmt *value);

// Binary operations. Returns the result.
BinaryOpStmt *create_add(Stmt *l, Stmt *r);
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,7 @@ void export_lang(py::module &m) {
DEFINE_EXPRESSION_OP(exp)
DEFINE_EXPRESSION_OP(log)
DEFINE_EXPRESSION_OP(popcnt)
DEFINE_EXPRESSION_OP(clz)

DEFINE_EXPRESSION_OP(select)
DEFINE_EXPRESSION_OP(ifte)
Expand Down
1 change: 1 addition & 0 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def _get_expected_matrix_apis():
"cinv",
"clamp",
"clog",
"clz",
"cmul",
"cos",
"cpow",
Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_unary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,22 @@ def test_u64(x: ti.uint64) -> ti.int32:
assert test_i64(10000) == 5


@test_utils.test(arch=[ti.cpu, ti.metal, ti.cuda, ti.vulkan])
def test_clz():
@ti.kernel
def test_i32(x: ti.int32) -> ti.int32:
return ti.math.clz(x)

# assert test_i32(0) == 32
assert test_i32(1) == 31
assert test_i32(2) == 30
assert test_i32(3) == 30
assert test_i32(4) == 29
assert test_i32(5) == 29
assert test_i32(1023) == 22
assert test_i32(1024) == 21


@test_utils.test(arch=[ti.metal])
def test_popcnt():
@ti.kernel
Expand Down
Loading