@@ -880,13 +880,14 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
880
880
if (device_type != DeviceType::kCPU
881
881
&& device_type2sub_tsk_gph_builder_.find (device_type)
882
882
!= device_type2sub_tsk_gph_builder_.end ()) {
883
- status = CHECK_JUST ( // NOLINT
883
+ auto maybe_status = // NOLINT
884
884
device_type2sub_tsk_gph_builder_ // NOLINT
885
885
.at (device_type) // NOLINT
886
886
->Build (sub_tsk_gph_builder_ctx_.get (), in_nodes, &out_nodes, // NOLINT
887
887
&sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, // NOLINT
888
888
blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT
889
- *(CHECK_JUST (src_op_node->op ().GetOpTimeShape ()).get ()))); // NOLINT
889
+ *(CHECK_JUST (src_op_node->op ().GetOpTimeShape ()).get ())); // NOLINT
890
+ if (maybe_status.IsOk ()) { status = CHECK_JUST (maybe_status); }
890
891
} else {
891
892
status = CHECK_JUST (hierarchical_sub_tsk_gph_builder_->Build (
892
893
sub_tsk_gph_builder_ctx_.get (), in_nodes, &out_nodes, &sorted_ctrl_tasks,
@@ -1052,6 +1053,12 @@ Maybe<void> GlobalTaskGraph::Init() {
1052
1053
OpGraph* op_graph = Singleton<OpGraph>::Get ();
1053
1054
sub_tsk_gph_builder_ctx_.reset (new SubTskGphBuilderCtx (this ));
1054
1055
boxing_logger_ = CreateBoxingLogger ();
1056
+ // Register the corresponding task graph builder based on the device type and store them to map
1057
+ const auto * global_device_type_create_sub_tsk_gph_builder_fn =
1058
+ GlobalDeviceType2CreateSubTskGphBuilderFn ();
1059
+ for (const auto & pair : *global_device_type_create_sub_tsk_gph_builder_fn) {
1060
+ device_type2sub_tsk_gph_builder_.emplace (pair.first , pair.second ());
1061
+ }
1055
1062
hierarchical_sub_tsk_gph_builder_.reset (new DispatchHierarchicalSubTskGphBuilder ());
1056
1063
HashMap<const OpNode*, std::vector<CompTaskNode*>> op_node2sorted_comp_tasks;
1057
1064
@@ -1088,6 +1095,13 @@ Maybe<void> BoxingTaskGraph::Init(
1088
1095
OpGraph* op_graph = Singleton<OpGraph>::Get ();
1089
1096
sub_tsk_gph_builder_ctx_.reset (new SubTskGphBuilderCtx (this ));
1090
1097
boxing_logger_ = CreateBoxingLogger ();
1098
+ // Register the corresponding task graph builder based on the device type and store them to map
1099
+ const auto * global_device_type_create_sub_tsk_gph_builder_fn =
1100
+ GlobalDeviceType2CreateSubTskGphBuilderFn ();
1101
+ for (const auto & pair : *global_device_type_create_sub_tsk_gph_builder_fn) {
1102
+ device_type2sub_tsk_gph_builder_.emplace (pair.first , pair.second ());
1103
+ }
1104
+
1091
1105
hierarchical_sub_tsk_gph_builder_.reset (new DispatchHierarchicalSubTskGphBuilder ());
1092
1106
1093
1107
const auto & TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) {
0 commit comments