Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Langjian/dist broadcast #383

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
19 changes: 18 additions & 1 deletion src/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,18 @@ static Status TranslateDepthwiseConv2dNativeOp(
return Status::OK();
}

// Translates a TF "HorovodBroadcast" node into an nGraph
// BroadcastDistributed op.
//
// op               - the TF node being translated (its name is reused for the
//                    nGraph node so later lookups can find it).
// static_input_map - unused here; kept for the uniform translator signature.
// ng_op_map        - map from TF node names to their nGraph outputs; the
//                    input is read from it and the result is stored into it.
//
// Returns Status::OK() on success, or the error from input lookup.
static Status TranslateBroadcastDistributedOp(
    const Node* op, const std::vector<const Tensor*>& static_input_map,
    Builder::OpMap& ng_op_map) {
  // Fetch the single nGraph input feeding this TF node.
  shared_ptr<ng::Node> ng_inp;
  TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_inp));

  // Wrap the input in a BroadcastDistributed op and register the result
  // under the TF node's name for downstream translations.
  SaveNgOp(ng_op_map, op->name(),
           ConstructNgNode<ng::op::BroadcastDistributed>(op->name(), ng_inp));
  return Status::OK();
}

static Status TranslateExpandDimsOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
Expand Down Expand Up @@ -4425,6 +4437,7 @@ const static std::map<
{"Greater", TranslateBinaryOp<ngraph::op::Greater>},
{"GreaterEqual", TranslateBinaryOp<ngraph::op::GreaterEq>},
{"HorovodAllreduce", TranslateAllreduceOp},
{"HorovodBroadcast", TranslateBroadcastDistributedOp},
{"Identity", TranslateIdentityOp},
{"L2Loss", TranslateL2LossOp},
{"Less", TranslateBinaryOp<ngraph::op::Less>},
Expand Down Expand Up @@ -4550,6 +4563,9 @@ Status Builder::TranslateGraph(
if (n->type_string() == "HorovodAllreduce") {
NGRAPH_VLOG(1) << "[NGRAPH_TF RANK: " << rank_id << "]: " << n->name();
}
if (n->type_string() == "HorovodBroadcast") {
NGRAPH_VLOG(1) << "[NGRAPH_TF RANK: " << rank_id << "]: " << n->name();
}
#endif
}
}
Expand Down Expand Up @@ -4649,7 +4665,8 @@ Status Builder::TranslateGraph(
ng_function = make_shared<ng::Function>(ng_result_list, ng_parameter_list);

#if defined NGRAPH_DISTRIBUTED
AllreduceOpControlOrder(ng_function);
OpControlOrder(ng_function, "AllReduce");
OpControlOrder(ng_function, "BroadcastDistributed");
#endif

//
Expand Down
3 changes: 3 additions & 0 deletions src/ngraph_mark_for_clustering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ Status MarkForClustering(Graph* graph,
#if defined NGRAPH_DISTRIBUTED
confirmation_function_map["HorovodAllreduce"] =
SimpleConfirmationFunction();
confirmation_function_map["HorovodBroadcast"] =
SimpleConfirmationFunction();
#endif
confirmation_function_map["Identity"] = SimpleConfirmationFunction();
confirmation_function_map["L2Loss"] = SimpleConfirmationFunction();
Expand Down Expand Up @@ -437,6 +439,7 @@ Status MarkForClustering(Graph* graph,
type_constraint_map["GreaterEqual"]["T"] = NGraphDTypes();
#if defined NGRAPH_DISTRIBUTED
type_constraint_map["HorovodAllreduce"]["T"] = NGraphNumericDTypes();
type_constraint_map["HorovodBroadcast"]["T"] = NGraphNumericDTypes();
#endif
type_constraint_map["Identity"]["T"] = NGraphDTypes();
type_constraint_map["L2Loss"]["T"] = NGraphNumericDTypes();
Expand Down
23 changes: 13 additions & 10 deletions src/ngraph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#if defined NGRAPH_DISTRIBUTED
#include "ngraph/distributed.hpp"
#endif

using namespace std;
namespace ng = ngraph;
Expand Down Expand Up @@ -356,24 +359,24 @@ bool DumpTrackedGraphs() {
std::getenv("NGRAPH_TF_DUMP_TRACKED_GRAPHS") != nullptr;
}

void AllreduceOpControlOrder(
const std::shared_ptr<ngraph::Function>& ng_function) {
void OpControlOrder(const std::shared_ptr<ngraph::Function>& ng_function,
const std::string& op_name) {
// Walk the function's ops in execution order and collect those whose
// description matches op_name into a vector.
ng::NodeVector allreduce_op_list;
ng::NodeVector op_list;
for (const shared_ptr<ng::Node>& node : ng_function->get_ordered_ops()) {
if (node->description() == "AllReduce") {
allreduce_op_list.push_back(node);
if (node->description() == op_name) {
op_list.push_back(node);
}
// Sort the collected ops by their TF (friendly) names.
std::sort(allreduce_op_list.begin(), allreduce_op_list.end(),
std::sort(op_list.begin(), op_list.end(),
[](const shared_ptr<ng::Node>& x, const shared_ptr<ng::Node>& y) {
return x->get_friendly_name() < y->get_friendly_name();
});
// Chain the sorted ops with control dependencies so they execute in a
// deterministic order.
if (allreduce_op_list.size() > 1) {
for (size_t i = 1; i < allreduce_op_list.size(); ++i) {
auto pre_node = allreduce_op_list[i - 1];
auto cur_node = allreduce_op_list[i];
if (op_list.size() > 1) {
for (size_t i = 1; i < op_list.size(); ++i) {
auto pre_node = op_list[i - 1];
auto cur_node = op_list[i];
cur_node->add_control_dependency(pre_node);
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/ngraph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ bool DumpEncapsulatedGraphs();
bool DumpTrackedGraphs();

// Insert control dependencies between ops of the given type (e.g. AllReduce,
// BroadcastDistributed) to ensure a deterministic execution order
void AllreduceOpControlOrder(const std::shared_ptr<ngraph::Function>&);
void OpControlOrder(const std::shared_ptr<ngraph::Function>&,
const std::string&);
} // namespace ngraph_bridge

} // namespace tensorflow
Expand Down