Skip to content

Commit

Permalink
fix device_type2sub_tsk_gph_builder_
Browse files Browse the repository at this point in the history
  • Loading branch information
Flowingsun007 committed Nov 12, 2024
1 parent f0d13a6 commit cfde327
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions oneflow/core/graph/task_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -880,13 +880,14 @@ DEFINE_BLD_SUB_TASK_GRAPH_METHOD(BldSubTskGphByBoxing) {
if (device_type != DeviceType::kCPU
&& device_type2sub_tsk_gph_builder_.find(device_type)
!= device_type2sub_tsk_gph_builder_.end()) {
status = CHECK_JUST( // NOLINT
auto maybe_status = // NOLINT
device_type2sub_tsk_gph_builder_ // NOLINT
.at(device_type) // NOLINT
->Build(sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, // NOLINT
&sorted_ctrl_tasks, src_parallel_desc, dst_parallel_desc, lbi, // NOLINT
blob_desc, src_nd_sbp, dst_nd_sbp, // NOLINT
*(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get()))); // NOLINT
*(CHECK_JUST(src_op_node->op().GetOpTimeShape()).get())); // NOLINT
if (maybe_status.IsOk()) { status = CHECK_JUST(maybe_status); }
} else {
status = CHECK_JUST(hierarchical_sub_tsk_gph_builder_->Build(
sub_tsk_gph_builder_ctx_.get(), in_nodes, &out_nodes, &sorted_ctrl_tasks,
Expand Down Expand Up @@ -1052,6 +1053,12 @@ Maybe<void> GlobalTaskGraph::Init() {
OpGraph* op_graph = Singleton<OpGraph>::Get();
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
boxing_logger_ = CreateBoxingLogger();
// Register the corresponding task graph builder based on the device type and store them to map
const auto* global_device_type_create_sub_tsk_gph_builder_fn =
GlobalDeviceType2CreateSubTskGphBuilderFn();
for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) {
device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second());
}
hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());
HashMap<const OpNode*, std::vector<CompTaskNode*>> op_node2sorted_comp_tasks;

Expand Down Expand Up @@ -1088,6 +1095,13 @@ Maybe<void> BoxingTaskGraph::Init(
OpGraph* op_graph = Singleton<OpGraph>::Get();
sub_tsk_gph_builder_ctx_.reset(new SubTskGphBuilderCtx(this));
boxing_logger_ = CreateBoxingLogger();
// Register the corresponding task graph builder based on the device type and store them to map
const auto* global_device_type_create_sub_tsk_gph_builder_fn =
GlobalDeviceType2CreateSubTskGphBuilderFn();
for (const auto& pair : *global_device_type_create_sub_tsk_gph_builder_fn) {
device_type2sub_tsk_gph_builder_.emplace(pair.first, pair.second());
}

hierarchical_sub_tsk_gph_builder_.reset(new DispatchHierarchicalSubTskGphBuilder());

const auto& TryCreateSortedCompTaskNodes = [&](const OpNode* op_node) {
Expand Down

0 comments on commit cfde327

Please sign in to comment.