Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit b32b52f

Browse files
committed
Add Horovod distributed broadcast op to the bridge
1 parent 0f287c5 commit b32b52f

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

src/ngraph_builder.cc

+11
Original file line numberDiff line numberDiff line change
@@ -1599,6 +1599,16 @@ static Status TranslateDepthwiseConv2dNativeOp(
15991599
return Status::OK();
16001600
}
16011601

1602+
static Status TranslateDistBroadcastOp(
1603+
const Node* op, const std::vector<const Tensor*>& static_input_map,
1604+
Builder::OpMap& ng_op_map) {
1605+
shared_ptr<ng::Node> ng_input;
1606+
TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_input));
1607+
1608+
SaveNgOp(ng_op_map, op->name(), make_shared<ng::op::DistBroadcast>(ng_input));
1609+
return Status::OK();
1610+
}
1611+
16021612
static Status TranslateExpandDimsOp(
16031613
const Node* op, const std::vector<const Tensor*>& static_input_map,
16041614
Builder::OpMap& ng_op_map) {
@@ -4070,6 +4080,7 @@ const static std::map<
40704080
{"Greater", TranslateBinaryOp<ngraph::op::Greater>},
40714081
{"GreaterEqual", TranslateBinaryOp<ngraph::op::GreaterEq>},
40724082
{"HorovodAllreduce", TranslateAllreduceOp},
4083+
{"HorovodBroadcast", TranslateDistBroadcastOp},
40734084
{"Identity", TranslateIdentityOp},
40744085
{"L2Loss", TranslateL2LossOp},
40754086
{"Less", TranslateBinaryOp<ngraph::op::Less>},

src/ngraph_mark_for_clustering.cc

+3
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,8 @@ Status MarkForClustering(Graph* graph) {
256256
#ifdef NGRAPH_DISTRIBUTED
257257
confirmation_function_map["HorovodAllreduce"] =
258258
SimpleConfirmationFunction();
259+
confirmation_function_map["HorovodBroadcast"] =
260+
SimpleConfirmationFunction();
259261
#endif
260262
confirmation_function_map["Identity"] = SimpleConfirmationFunction();
261263
confirmation_function_map["L2Loss"] = SimpleConfirmationFunction();
@@ -388,6 +390,7 @@ Status MarkForClustering(Graph* graph) {
388390
type_constraint_map["GreaterEqual"]["T"] = NGraphDTypes();
389391
#ifdef NGRAPH_DISTRIBUTED
390392
type_constraint_map["HorovodAllreduce"]["T"] = NGraphNumericDTypes();
393+
type_constraint_map["HorovodBroadcast"]["T"] = NGraphNumericDTypes();
391394
#endif
392395
type_constraint_map["Identity"]["T"] = NGraphDTypes();
393396
type_constraint_map["L2Loss"]["T"] = NGraphNumericDTypes();

0 commit comments

Comments
 (0)