Skip to content

Commit cfde327

Browse files
committed
fix device_type2sub_tsk_gph_builder_
1 parent f0d13a6 commit cfde327

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

oneflow/core/graph/task_graph.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -880,13 +880,14 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
880880
if (device_type != DeviceType::kCPU
881881
&& device_type2sub_tsk_gph_builder_.find(device_type)
882882
!= device_type2sub_tsk_gph_builder_.end()) {
883-
status = CHECK_JUST( // NOLINT
883+
auto maybe_status = // NOLINT
884884
device_type2sub_tsk_gph_builder_ // NOLINT
885885
.at(device_type) // NOLINT
886886
->Build(sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, // NOLINT
887887
&sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, // NOLINT
888888
blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT
889-
*(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get()))); // NOLINT
889+
*(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())); // NOLINT
890+
if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); }
890891
} else {
891892
status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build(
892893
sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks,
@@ -1052,6 +1053,12 @@ Maybe<void> GlobalTaskGraph::Init() {
10521053
OpGraph* op_graph = Singleton<OpGraph>::Get();
10531054
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
10541055
boxing_logger_ = CreateBoxingLogger();
1056+
// Register the corresponding task graph builder based on the device type and store them to map
1057+
const auto* global_device_type_create_sub_tsk_gph_builder_fn =
1058+
GlobalDeviceType2CreateSubTskGphBuilderFn();
1059+
for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) {
1060+
device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second());
1061+
}
10551062
hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());
10561063
HashMap<const OpNode*, std::vector<CompTaskNode*>> op_node2sorted_comp_tasks;
10571064

@@ -1088,6 +1095,13 @@ Maybe<void> BoxingTaskGraph::Init(
10881095
OpGraph* op_graph = Singleton<OpGraph>::Get();
10891096
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
10901097
boxing_logger_ = CreateBoxingLogger();
1098+
// Register the corresponding task graph builder based on the device type and store them to map
1099+
const auto* global_device_type_create_sub_tsk_gph_builder_fn =
1100+
GlobalDeviceType2CreateSubTskGphBuilderFn();
1101+
for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) {
1102+
device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second());
1103+
}
1104+
10911105
hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());
10921106

10931107
const auto& TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) {

0 commit comments

Comments
 (0)