Skip to content

Commit

Permalink
[Lang] Migrate irpass::force_scalarize_matrix() beforehand (#8532)
Browse files Browse the repository at this point in the history
Issue: #

### Brief Summary

copilot:summary

### Walkthrough

copilot:walkthrough
  • Loading branch information
jim19930609 authored Jun 11, 2024
1 parent 9fcf4f0 commit 06826c9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void re_id(IRNode *root);
void flag_access(IRNode *root);
void eliminate_immutable_local_vars(IRNode *root);
bool scalarize(IRNode *root, bool half2_optimization_enabled = false);
void lower_matrix_ptr(IRNode *root);
void lower_matrix_ptr(IRNode *root, bool force_scalarize = false);
bool die(IRNode *root);
bool simplify(IRNode *root, const CompileConfig &config);
bool cfg_optimization(
Expand Down
15 changes: 9 additions & 6 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,16 @@ void compile_to_offloads(IRNode *ir,
}

// Removes MatrixOfMatrixPtrStmt & MatrixOfGlobalPtrStmt
irpass::lower_matrix_ptr(ir);
irpass::lower_matrix_ptr(ir, config.force_scalarize_matrix);
print("Matrix ptr lowered");

if (config.force_scalarize_matrix) {
irpass::scalarize(ir, false /*half2_optimization_enabled*/);

irpass::die(ir);
print("Scalarized");
}

irpass::full_simplify(
ir, config,
{false, /*autodiff_enabled*/ autodiff_mode != AutodiffMode::kNone,
Expand All @@ -86,10 +93,6 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::gather_meshfor_relation_types(ir);
}

if (config.force_scalarize_matrix) {
irpass::scalarize(ir, false /*half2_optimization_enabled*/);
}

if (config.debug && autodiff_mode == AutodiffMode::kCheckAutodiffValid) {
// Check whether the kernel obeys the autodiff limitation e.g., gloabl data
// access rule
Expand Down Expand Up @@ -366,7 +369,7 @@ void compile_function(IRNode *ir,
}

// Removes MatrixOfMatrixPtrStmt & MatrixOfGlobalPtrStmt
irpass::lower_matrix_ptr(ir);
irpass::lower_matrix_ptr(ir, config.force_scalarize_matrix);
print("Matrix ptr lowered");

irpass::demote_atomics(ir, config);
Expand Down
10 changes: 6 additions & 4 deletions taichi/transforms/lower_matrix_ptr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,13 +593,15 @@ class RemoveMatrixOfPtr : public BasicStmtVisitor {

namespace irpass {

void lower_matrix_ptr(IRNode *root) {
void lower_matrix_ptr(IRNode *root, bool force_scalarize) {
TI_AUTO_PROF;

GatherValidAOSGlobalPtrStmt gather_valid_aos_global_ptr_pass(root);
if (!force_scalarize) {
GatherValidAOSGlobalPtrStmt gather_valid_aos_global_ptr_pass(root);

LowerAOSGlobalPtrStmt lower_aos_global_ptr_stmt_pass(
root, gather_valid_aos_global_ptr_pass.invalid_aos_global_ptr_stmts_);
LowerAOSGlobalPtrStmt lower_aos_global_ptr_stmt_pass(
root, gather_valid_aos_global_ptr_pass.invalid_aos_global_ptr_stmts_);
}

ScalarizeMatrixPtr scalarize_matrix_ptr_pass(root);
LowerMatrixPtr::run(root);
Expand Down

0 comments on commit 06826c9

Please sign in to comment.