Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move nGraph rewrite pass registration to POST_PLACEMENT #691

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ngraph_bridge/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,14 @@ set(SRC
tf_graphcycles.cc
tf_deadness_analysis.cc
version.cc
grappler/ngraph_add_identityn.cc
)

message(STATUS "NGRAPH_TF_USE_GRAPPLER_OPTIMIZER: ${NGRAPH_TF_USE_GRAPPLER_OPTIMIZER}")
if(NGRAPH_TF_USE_GRAPPLER_OPTIMIZER)
list(REMOVE_ITEM SRC ngraph_rewrite_pass.cc)
list(APPEND SRC grappler/ngraph_optimizer.cc)
list(APPEND SRC grappler/ngraph_add_identityn.cc)
# list(APPEND SRC grappler/ngraph_add_identityn.cc)
add_definitions(-DNGRAPH_TF_USE_GRAPPLER_OPTIMIZER)
endif()

Expand Down
48 changes: 45 additions & 3 deletions ngraph_bridge/ngraph_mark_for_clustering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
* limitations under the License.
*******************************************************************************/

#include "tensorflow/core/common_runtime/build_graph_options.h"
#include "tensorflow/core/graph/graph.h"

#include "ngraph_bridge/default_opset.h"
#include "ngraph_bridge/grappler/ngraph_add_identityn.h"
#include "ngraph_bridge/ngraph_api.h"
#include "ngraph_bridge/ngraph_backend_manager.h"
#include "ngraph_bridge/ngraph_mark_for_clustering.h"
Expand Down Expand Up @@ -739,8 +741,7 @@ GetTFToNgOpMap() {
//
// Main entry point for the marking pass.
//
Status MarkForClustering(Graph* graph,
const std::set<string> skip_these_nodes) {
Status MarkForClustering(Graph* graph, std::set<string> skip_these_nodes) {
const TypeConstraintMap& type_constraint_map = GetTypeConstraintMap();

// confirmation_function_map is non-const unlike the other maps
Expand Down Expand Up @@ -800,9 +801,50 @@ Status MarkForClustering(Graph* graph,
vector<Node*> nodes_marked_for_clustering;

shared_ptr<Backend> op_backend = BackendManager::GetBackend();
#if !defined NGRAPH_TF_USE_GRAPPLER_OPTIMIZER
std::set<string> disabled_nodes = {};
// Find a list of nodes that are of the types that are disabled
for (auto itr : graph->nodes()) {
if (disabled_ops_set.find(itr->type_string()) != disabled_ops_set.end()) {
disabled_nodes.insert(itr->name());
}
}
// const BuildGraphOptions options;
// cout << "trying " << options.callable_options.fetch().size() << endl;
// cout << "trying " << options.callable_options.tensor_connection().size()
// <<endl;
std::set<string> fetch_nodes;
for (auto edge : graph->edges()) {
Node* src = edge->src();
Node* dst = edge->dst();
// Skip source/sink
if (dst->IsSink()) {
cout << "Skip this node " << src->type_string() << endl;
fetch_nodes.insert(src->name());
}
}
cout << "Total nodes " << graph->num_nodes() << endl;
cout << "OP nodes " << graph->num_op_nodes() << endl;

// nodes_to_add_identity_to = fetch_nodes - disabled_nodes
std::set<string> nodes_to_add_identity_to;
std::set_difference(fetch_nodes.begin(), fetch_nodes.end(),
disabled_nodes.begin(), disabled_nodes.end(),
std::inserter(nodes_to_add_identity_to,
nodes_to_add_identity_to.begin()));

// Rewrite graph to add IdentityN node so the fetch node can be encapsulated
// as well
// If the fetch node in question has 0 outputs or any of the outputs
// has ref type as a data type then don't add IdentityN node, but the fetch
// node will be skipped from marking and clustering.
TF_RETURN_IF_ERROR(AddIdentityN(graph, nodes_to_add_identity_to));
skip_these_nodes = nodes_to_add_identity_to;
#endif

for (auto node : graph->op_nodes()) {
cout << node->type_string() << endl;
bool mark_for_clustering = false;

do {
// check if output node
bool skip_it = false;
Expand Down
3 changes: 2 additions & 1 deletion ngraph_bridge/ngraph_rewrite_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ mutex NGraphRewritePass::s_serial_counter_mutex;
class NGraphEncapsulationPass : public NGraphRewritePass {
public:
Status Run(const GraphOptimizationPassOptions& options) override {
// cout << "trying " << options.sessioncallable_options.fetch().size();
// If we don't get a main graph, log that fact and bail.
if (options.graph == nullptr) {
NGRAPH_VLOG(0) << "NGraphEncapsulationPass: options.graph == nullptr";
Expand Down Expand Up @@ -151,6 +152,6 @@ class NGraphEncapsulationPass : public NGraphRewritePass {

} // namespace ngraph_bridge

REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, 0,
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0,
ngraph_bridge::NGraphEncapsulationPass);
} // namespace tensorflow
4 changes: 2 additions & 2 deletions test/tests_linux_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
# Read in one/more external manifest file(s)
# Path specified is relative to this file's path

tests_common.txt
#tests_common.txt

###################################################
[RUN]
# Specify tests/patterns/regex that should be included

MathOps.Abs1D
###################################################
[SKIP]
# Specify tests/patterns/regex that should be excluded/skipped
Expand Down