From de39f739b08a47be4df465177d2d2326772278c4 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Tue, 18 Apr 2023 01:26:11 +0800 Subject: [PATCH] Add customized C APIs for tf.net. --- tensorflow/c/c_api.cc | 142 ++++++++++++++++++ tensorflow/c/c_api.h | 39 +++++ tensorflow/core/framework/BUILD | 18 ++- .../core/framework/cpp_shape_inference.proto | 36 +++++ 4 files changed, 233 insertions(+), 2 deletions(-) create mode 100644 tensorflow/core/framework/cpp_shape_inference.proto diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 7a7149c4fd96e9..d8f86cbd2134e9 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -50,10 +50,12 @@ limitations under the License. #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" // NOLINT +#include "tensorflow/core/framework/cpp_shape_inference.pb.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" @@ -71,6 +73,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/version.h" +#include "tensorflow/core/framework/full_type.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" // The implementation below is at the top level instead of the // brain namespace because we are defining 'extern "C"' functions. @@ -2614,6 +2618,144 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } } +// TF Customized C APIs for Tensorflow.NET -------------------------- + +void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { + mutex_lock l(graph->mu); + graph->graph.AddControlEdge(&input->node, &op->node); + tensorflow::RecordMutation(graph, *op, "adding control input"); +} + +void TFC_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Buffer* attr_value_proto, TF_Status* status) { + using tensorflow::RecordMutation; + tensorflow::AttrValue attr_val; + if (!attr_val.ParseFromArray(attr_value_proto->data, + attr_value_proto->length)) { + status->status = + tensorflow::errors::InvalidArgument("Invalid AttrValue proto"); + return; + } + + mutex_lock l(graph->mu); + op->node.AddAttr(attr_name, attr_val); + tensorflow::RecordMutation(graph, *op, "setting attribute"); +} + +void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Status* status) { + mutex_lock l(graph->mu); + op->node.ClearAttr(attr_name); + tensorflow::RecordMutation(graph, *op, "clearing attribute"); +} + +void TFC_SetFullType(TF_Graph* graph, TF_Operation* op, + const tensorflow::FullTypeDef& full_type) { + mutex_lock l(graph->mu); + *op->node.mutable_def()->mutable_experimental_type() = full_type; + tensorflow::RecordMutation(graph, *op, "setting fulltype"); +} + +void TFC_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { + mutex_lock l(graph->mu); + op->node.set_requested_device(device); + tensorflow::RecordMutation(graph, *op, "setting device"); +} + +void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status) { + TF_UpdateEdge(graph, new_src, dst, status); +} + +void 
TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
+  mutex_lock l(graph->mu);
+  std::vector<const tensorflow::Edge*> control_edges;
+  for (const tensorflow::Edge* edge : op->node.in_edges()) {
+    if (!edge->IsControlEdge()) continue;
+    control_edges.push_back(edge);
+  }
+  for (const tensorflow::Edge* edge : control_edges) {
+    graph->graph.RemoveControlEdge(edge);
+  }
+}
+
+void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
+  mutex_lock l(graph->mu);
+  graph->refiner.set_require_shape_inference_fns(require);
+}
+
+void TFC_ExtendSession(TF_Session* session, TF_Status* status) {
+  ExtendSessionGraphHelper(session, status);
+  session->extend_before_run = false;
+}
+
+const char* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+  Node* node = &output.oper->node;
+  tensorflow::CppShapeInferenceResult::HandleData handle_data;
+  handle_data.set_is_set(true);
+  {
+    mutex_lock l(graph->mu);
+    tensorflow::shape_inference::InferenceContext* ic =
+        graph->refiner.GetContext(node);
+    CHECK(ic != nullptr);
+    CHECK_LT(output.index, ic->num_outputs());
+    const auto* shapes_and_types =
+        ic->output_handle_shapes_and_types(output.index);
+    if (shapes_and_types == nullptr) return "";
+
+    for (const auto& p : *shapes_and_types) {
+      auto* out_shape_and_type = handle_data.add_shape_and_type();
+      ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
+      out_shape_and_type->set_dtype(p.dtype);
+      *out_shape_and_type->mutable_type() = p.type;
+    }
+  }
+  // Keep the serialized proto alive after this function returns; the
+  // pointer stays valid until the next call on the same thread.
+  static thread_local string result;
+  handle_data.SerializeToString(&result);
+  return result.c_str();
+}
+
+void TFC_SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
+                               size_t proto_len, TF_Status* status) {
+  tensorflow::CppShapeInferenceResult::HandleData handle_data;
+  if (!handle_data.ParseFromArray(proto, proto_len)) {
+    status->status = tensorflow::errors::InvalidArgument(
+        "Couldn't deserialize HandleData proto");
+    return;
+  }
+  DCHECK(handle_data.is_set());
+
+  tensorflow::mutex_lock l(graph->mu);
+  tensorflow::shape_inference::InferenceContext* ic =
+      graph->refiner.GetContext(&output.oper->node);
+
+  std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
+  for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
+    tensorflow::shape_inference::ShapeHandle shape;
+    status->status =
+        ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
+    if (TF_GetCode(status) != TF_OK) return;
+    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
+                                  shape_and_type_proto.type());
+  }
+  ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
+}
+
+void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
+                           TF_Status* status) {
+  mutex_lock l(graph->mu);
+  status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
+                                                  new_src.index, &dst->node);
+  if (TF_GetCode(status) == TF_OK) {
+    // This modification only updates the destination node for
+    // the purposes of running this graph in a session. Thus, we don't
+    // record the source node as being modified.
+    tensorflow::RecordMutation(graph, *dst, "adding input tensor");
+  }
+}
+
+// -------------------------------------------------------------------
+
 // TF_Server functions ----------------------------------------------
 
 #if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 9de98ee5a07fed..5e0363b596e567 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tstring.h" +#include "tensorflow/core/framework/full_type.pb.h" // -------------------------------------------------------------------------- // C API for TensorFlow. @@ -1620,6 +1621,44 @@ TF_CAPI_EXPORT extern void TF_RegisterLogListener( TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin( const char* plugin_filename, TF_Status* status); +TF_CAPI_EXPORT extern void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); + +TF_CAPI_EXPORT extern void TFC_SetAttr(TF_Graph* graph, TF_Operation* op, + const char* attr_name, + TF_Buffer* attr_value_proto, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op, + const char* attr_name, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFC_SetFullType(TF_Graph* graph, TF_Operation* op, + const tensorflow::FullTypeDef& full_type); + +TF_CAPI_EXPORT extern void TFC_SetRequestedDevice(TF_Graph* graph, + TF_Operation* op, + const char* device); + +TF_CAPI_EXPORT extern void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src, + TF_Input dst, TF_Status* status); + +TF_CAPI_EXPORT extern void TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); + +TF_CAPI_EXPORT extern void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require); + +TF_CAPI_EXPORT extern void TFC_ExtendSession(TF_Session* session, TF_Status* status); + +TF_CAPI_EXPORT extern const char* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output); + +TF_CAPI_EXPORT extern void TFC_SetHandleShapeAndType(TF_Graph* graph, + TF_Output output, + const void* proto, + size_t proto_len, + TF_Status* status); + +TF_CAPI_EXPORT extern void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/core/framework/BUILD b/tensorflow/core/framework/BUILD index eb45d6206f2fa3..937327f212ae27 100644 --- a/tensorflow/core/framework/BUILD +++ b/tensorflow/core/framework/BUILD @@ -116,6 +116,7 @@ exports_files( srcs = [ "allocation_description.proto", "api_def.proto", + "cpp_shape_inference.proto", "attr_value.proto", "cost_graph.proto", "dataset_metadata.proto", @@ -1402,7 +1403,7 @@ cc_library( # protos from the same package, so we can build the protos here and then # link them from core:protos_all without circular dependencies. -# Generate the C++ sources for some of the protos. +#Generate the C++ sources for some of the protos. 
 tf_generate_proto_text_sources(
     name = "attr_value_proto_text",
     srcs = [
@@ -1693,6 +1694,18 @@ tf_proto_library(
     ],
 )
 
+tf_proto_library(
+    name = "cpp_shape_inference_proto",
+    srcs = ["cpp_shape_inference.proto"],
+    cc_api_version = 2,
+    make_default_target_header_only = True,
+    protodeps = [
+        ":full_type_proto",
+        ":tensor_shape_proto",
+        ":types_proto",
+    ],
+)
+
 tf_proto_library(
     name = "variable_proto",
     srcs = ["variable.proto"],
@@ -1760,7 +1773,7 @@ tf_proto_library(
 #         ":function_proto",
 #     ],
 # )
-# copybara:uncomment_end
+#copybara : uncomment_end
 
 tf_proto_library(
     name = "summary_proto",
@@ -1806,6 +1819,7 @@ tf_proto_library(
     protodeps = [
         ":allocation_description_proto",
         ":api_def_proto",
+        ":cpp_shape_inference_proto",
         ":attr_value_proto",
         ":cost_graph_proto",
         ":dataset_metadata_proto",
diff --git a/tensorflow/core/framework/cpp_shape_inference.proto b/tensorflow/core/framework/cpp_shape_inference.proto
new file mode 100644
index 00000000000000..d2fd1f29f23b87
--- /dev/null
+++ b/tensorflow/core/framework/cpp_shape_inference.proto
@@ -0,0 +1,36 @@
+syntax = "proto3";
+
+package tensorflow;
+
+import "tensorflow/core/framework/full_type.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+option cc_enable_arenas = true;
+option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto";
+
+message CppShapeInferenceResult {
+  message HandleShapeAndType {
+    reserved 3;
+
+    TensorShapeProto shape = 1;
+    DataType dtype = 2;
+    FullTypeDef type = 4;
+  }
+  message HandleData {
+    bool is_set = 1;
+
+    // Only valid if <is_set>.
+    repeated HandleShapeAndType shape_and_type = 2;
+  }
+  TensorShapeProto shape = 1;
+
+  reserved 2;  // was handle_shape
+  reserved 3;  // was handle_dtype
+  HandleData handle_data = 4;
+}
+
+message CppShapeInferenceInputsNeeded {
+  repeated int32 input_tensors_needed = 1;
+  repeated int32 input_tensors_as_shapes_needed = 2;
+}
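
Usage note (not part of the patch): the sketch below shows how a C client of the patched library might exercise a few of the new TFC_* entry points declared in c_api.h. The op names "x" and "after_x", the device string, and the program itself are illustrative assumptions; only existing TF C API calls plus the TFC_ functions added above are used.

/* Minimal, hypothetical client of the new TFC_* entry points. */
#include <stdio.h>
#include "tensorflow/c/c_api.h"

int main(void) {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  /* A float placeholder and a NoOp that will carry a control dependency. */
  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "x");
  TF_SetAttrType(desc, "dtype", TF_FLOAT);
  TF_Operation* x = TF_FinishOperation(desc, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "%s\n", TF_Message(status));
    return 1;
  }

  desc = TF_NewOperation(graph, "NoOp", "after_x");
  TF_Operation* after_x = TF_FinishOperation(desc, status);
  if (TF_GetCode(status) != TF_OK) {
    fprintf(stderr, "%s\n", TF_Message(status));
    return 1;
  }

  /* Make after_x depend on x, pin it to a device, then drop the edge again. */
  TFC_AddControlInput(graph, after_x, x);
  TFC_SetRequestedDevice(graph, after_x, "/device:CPU:0");
  TFC_RemoveAllControlInputs(graph, after_x);

  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
  return 0;
}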