diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index f6699533ee4d5..b38846973ff06 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -32,7 +32,7 @@ void re_id(IRNode *root); void flag_access(IRNode *root); void eliminate_immutable_local_vars(IRNode *root); bool scalarize(IRNode *root, bool half2_optimization_enabled = false); -void lower_matrix_ptr(IRNode *root); +void lower_matrix_ptr(IRNode *root, bool force_scalarize = false); bool die(IRNode *root); bool simplify(IRNode *root, const CompileConfig &config); bool cfg_optimization( diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 0507b9fd52344..b8dc6b31bf94c 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -69,9 +69,16 @@ void compile_to_offloads(IRNode *ir, } // Removes MatrixOfMatrixPtrStmt & MatrixOfGlobalPtrStmt - irpass::lower_matrix_ptr(ir); + irpass::lower_matrix_ptr(ir, config.force_scalarize_matrix); print("Matrix ptr lowered"); + if (config.force_scalarize_matrix) { + irpass::scalarize(ir, false /*half2_optimization_enabled*/); + + irpass::die(ir); + print("Scalarized"); + } + irpass::full_simplify( ir, config, {false, /*autodiff_enabled*/ autodiff_mode != AutodiffMode::kNone, @@ -86,10 +93,6 @@ void compile_to_offloads(IRNode *ir, irpass::analysis::gather_meshfor_relation_types(ir); } - if (config.force_scalarize_matrix) { - irpass::scalarize(ir, false /*half2_optimization_enabled*/); - } - if (config.debug && autodiff_mode == AutodiffMode::kCheckAutodiffValid) { // Check whether the kernel obeys the autodiff limitation e.g., gloabl data // access rule @@ -366,7 +369,7 @@ void compile_function(IRNode *ir, } // Removes MatrixOfMatrixPtrStmt & MatrixOfGlobalPtrStmt - irpass::lower_matrix_ptr(ir); + irpass::lower_matrix_ptr(ir, config.force_scalarize_matrix); print("Matrix ptr lowered"); irpass::demote_atomics(ir, config); diff --git a/taichi/transforms/lower_matrix_ptr.cpp b/taichi/transforms/lower_matrix_ptr.cpp index e4deb59192e84..65c2ba602f400 100644 --- a/taichi/transforms/lower_matrix_ptr.cpp +++ b/taichi/transforms/lower_matrix_ptr.cpp @@ -593,13 +593,15 @@ class RemoveMatrixOfPtr : public BasicStmtVisitor { namespace irpass { -void lower_matrix_ptr(IRNode *root) { +void lower_matrix_ptr(IRNode *root, bool force_scalarize) { TI_AUTO_PROF; - GatherValidAOSGlobalPtrStmt gather_valid_aos_global_ptr_pass(root); + if (!force_scalarize) { + GatherValidAOSGlobalPtrStmt gather_valid_aos_global_ptr_pass(root); - LowerAOSGlobalPtrStmt lower_aos_global_ptr_stmt_pass( - root, gather_valid_aos_global_ptr_pass.invalid_aos_global_ptr_stmts_); + LowerAOSGlobalPtrStmt lower_aos_global_ptr_stmt_pass( + root, gather_valid_aos_global_ptr_pass.invalid_aos_global_ptr_stmts_); + } ScalarizeMatrixPtr scalarize_matrix_ptr_pass(root); LowerMatrixPtr::run(root);