diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py
index d192fea0c9fc6..e259886d92002 100644
--- a/python/taichi/lang/ops.py
+++ b/python/taichi/lang/ops.py
@@ -1120,6 +1120,21 @@ def py_ifte(cond, x1, x2):
     return _ternary_operation(_ti_core.expr_ifte, py_ifte, cond, x1, x2)
 
 
+def clz(a):
+    """Count the number of leading zeros for a 32bit integer"""
+
+    def _clz(x):
+        # Pure-Python counterpart of the native op. A negative 32-bit value
+        # has its sign bit set, so it has no leading zeros; without this
+        # guard `2**0 > x` would yield 32.
+        if x < 0:
+            return 0
+        # bit_length() is the smallest i with 2**i > x; clamp for x >= 2**32.
+        return max(32 - int(x).bit_length(), 0)
+
+    return _unary_operation(_ti_core.expr_clz, _clz, a)
+
+
 @writeback_binary
 def atomic_add(x, y):
     """Atomically compute `x + y`, store the result in `x`,
diff --git a/python/taichi/math/mathimpl.py b/python/taichi/math/mathimpl.py
index 0d2b6eb87c123..85fb4565a5629 100644
--- a/python/taichi/math/mathimpl.py
+++ b/python/taichi/math/mathimpl.py
@@ -812,12 +812,18 @@ def popcnt(x):
     return ops.popcnt(x)
 
 
+@func
+def clz(x):
+    return ops.clz(x)
+
+
 __all__ = [
     "acos",
     "asin",
     "atan2",
     "ceil",
     "clamp",
+    "clz",
     "cos",
     "cross",
     "degrees",
diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp
index 734b941d5fe8a..04112e92bcae6 100644
--- a/taichi/codegen/cuda/codegen_cuda.cpp
+++ b/taichi/codegen/cuda/codegen_cuda.cpp
@@ -286,6 +286,15 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
       } else {
         TI_NOT_IMPLEMENTED
       }
+    } else if (op == UnaryOpType::clz) {
+      if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) {
+        stmt->ret_type = PrimitiveType::i32;
+        llvm_val[stmt] = call("__nv_clz", input);
+      } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) {
+        llvm_val[stmt] = call("__nv_clzll", input);
+      } else {
+        TI_NOT_IMPLEMENTED
+      }
     } else if (op == UnaryOpType::log) {
       if (input_taichi_type->is_primitive(PrimitiveTypeID::f32)) {
         // logf has fast-math option
diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index 4ae01c67af1d3..85e9b1788098e 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -207,6 +207,12 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) {
     llvm_val[stmt] =
         builder->CreateIntrinsic(llvm::Intrinsic::ctpop, {input_type}, {input});
   }
+  else if (op == UnaryOpType::clz) {
+    llvm_val[stmt] = builder->CreateIntrinsic(
+        llvm::Intrinsic::ctlz, {input_type},
+        {input,
+         llvm::ConstantInt::get(llvm::Type::getInt1Ty(*llvm_context), 0)});
+  }
   else {
     TI_P(unary_op_type_name(op));
     TI_NOT_IMPLEMENTED
diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp
index 8d072707762d9..e1e1124fd58fe 100644
--- a/taichi/codegen/spirv/spirv_codegen.cpp
+++ b/taichi/codegen/spirv/spirv_codegen.cpp
@@ -923,6 +923,15 @@ class TaskCodegen : public IRVisitor {
       ir_->store_variable(val, v);
     } else if (stmt->op_type == UnaryOpType::popcnt) {
       val = ir_->popcnt(operand_val);
+    } else if (stmt->op_type == UnaryOpType::clz) {
+      // Use GLSL.std.450 FindUMsb (75), not FindSMsb (74): the signed variant
+      // reports the highest *zero* bit for negative inputs, which would make
+      // clz of a negative i32 non-zero instead of 0.
+      uint32_t FindUMsb_id = 75;
+      spirv::Value msb = ir_->call_glsl450(dst_type, FindUMsb_id, operand_val);
+      spirv::Value bitcnt = ir_->int_immediate_number(ir_->i32_type(), 32);
+      spirv::Value one = ir_->int_immediate_number(ir_->i32_type(), 1);
+      val = ir_->sub(ir_->sub(bitcnt, msb), one);
     }
 #define UNARY_OP_TO_SPIRV(op, instruction, instruction_id, max_bits) \
   else if (stmt->op_type == UnaryOpType::op) {                       \
diff --git a/taichi/inc/unary_op.inc.h b/taichi/inc/unary_op.inc.h
index d0ec0c790a7d4..b7fa8abc109b9 100644
--- a/taichi/inc/unary_op.inc.h
+++ b/taichi/inc/unary_op.inc.h
@@ -19,6 +19,7 @@ PER_UNARY_OP(rcp)
 PER_UNARY_OP(exp)
 PER_UNARY_OP(log)
 PER_UNARY_OP(popcnt)
+PER_UNARY_OP(clz)
 PER_UNARY_OP(rsqrt)
 PER_UNARY_OP(bit_not)
 PER_UNARY_OP(logic_not)
diff --git a/taichi/ir/expression_ops.h b/taichi/ir/expression_ops.h
index cb95181a9e101..48024b96d971d 100644
--- a/taichi/ir/expression_ops.h
+++ b/taichi/ir/expression_ops.h
@@ -87,6 +87,7 @@ DEFINE_EXPRESSION_FUNC_UNARY(rsqrt)
 DEFINE_EXPRESSION_FUNC_UNARY(exp)
 DEFINE_EXPRESSION_FUNC_UNARY(log)
 DEFINE_EXPRESSION_FUNC_UNARY(popcnt)
+DEFINE_EXPRESSION_FUNC_UNARY(clz)
 DEFINE_EXPRESSION_FUNC_UNARY(logic_not)
 DEFINE_EXPRESSION_OP_UNARY(~, bit_not)
 DEFINE_EXPRESSION_OP_UNARY(-, neg)
diff --git a/taichi/ir/ir_builder.cpp b/taichi/ir/ir_builder.cpp
index 9abf05109effa..b28412cd441f4 100644
--- a/taichi/ir/ir_builder.cpp
+++ b/taichi/ir/ir_builder.cpp
@@ -278,6 +278,10 @@ UnaryOpStmt *IRBuilder::create_popcnt(Stmt *value) {
   return insert(Stmt::make_typed(UnaryOpType::popcnt, value));
 }
 
+UnaryOpStmt *IRBuilder::create_clz(Stmt *value) {
+  return insert(Stmt::make_typed(UnaryOpType::clz, value));
+}
+
 BinaryOpStmt *IRBuilder::create_add(Stmt *l, Stmt *r) {
   return insert(Stmt::make_typed(BinaryOpType::add, l, r));
 }
diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h
index b705a0c8e1666..c585ed7be425e 100644
--- a/taichi/ir/ir_builder.h
+++ b/taichi/ir/ir_builder.h
@@ -180,6 +180,7 @@ class IRBuilder {
   UnaryOpStmt *create_exp(Stmt *value);
   UnaryOpStmt *create_log(Stmt *value);
   UnaryOpStmt *create_popcnt(Stmt *value);
+  UnaryOpStmt *create_clz(Stmt *value);
 
   // Binary operations. Returns the result.
   BinaryOpStmt *create_add(Stmt *l, Stmt *r);
diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp
index 1b425de5dad74..cfdc8e987ae2d 100644
--- a/taichi/python/export_lang.cpp
+++ b/taichi/python/export_lang.cpp
@@ -972,6 +972,7 @@ void export_lang(py::module &m) {
   DEFINE_EXPRESSION_OP(exp)
   DEFINE_EXPRESSION_OP(log)
   DEFINE_EXPRESSION_OP(popcnt)
+  DEFINE_EXPRESSION_OP(clz)
   DEFINE_EXPRESSION_OP(select)
   DEFINE_EXPRESSION_OP(ifte)
 
diff --git a/tests/python/test_api.py b/tests/python/test_api.py
index 651f53f2954cf..ef30eb4026ab1 100644
--- a/tests/python/test_api.py
+++ b/tests/python/test_api.py
@@ -288,6 +288,7 @@ def _get_expected_matrix_apis():
     "cinv",
     "clamp",
     "clog",
+    "clz",
     "cmul",
     "cos",
     "cpow",
diff --git a/tests/python/test_unary_ops.py b/tests/python/test_unary_ops.py
index 6fe014bee872b..f1770aa8e66cc 100644
--- a/tests/python/test_unary_ops.py
+++ b/tests/python/test_unary_ops.py
@@ -126,6 +126,23 @@ def test_u64(x: ti.uint64) -> ti.int32:
     assert test_i64(10000) == 5
 
 
+@test_utils.test(arch=[ti.cpu, ti.metal, ti.cuda, ti.vulkan])
+def test_clz():
+    @ti.kernel
+    def test_i32(x: ti.int32) -> ti.int32:
+        return ti.math.clz(x)
+
+    # assert test_i32(0) == 32
+    assert test_i32(1) == 31
+    assert test_i32(2) == 30
+    assert test_i32(3) == 30
+    assert test_i32(4) == 29
+    assert test_i32(5) == 29
+    assert test_i32(1023) == 22
+    assert test_i32(1024) == 21
+    assert test_i32(-1) == 0
+
+
 @test_utils.test(arch=[ti.metal])
 def test_popcnt():
     @ti.kernel