Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add customized C APIs for tf.net of v2.10. #2

Open
wants to merge 3 commits into
base: r2.10
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions tensorflow/c/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,12 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/cpp_shape_inference.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
Expand All @@ -71,6 +73,8 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"

// The implementation below is at the top level instead of the
// brain namespace because we are defining 'extern "C"' functions.
Expand Down Expand Up @@ -2614,6 +2618,147 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
}
}

// TF Customized C APIs for Tensorflow.NET --------------------------

void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
mutex_lock l(graph->mu);
graph->graph.AddControlEdge(&input->node, &op->node);
tensorflow::RecordMutation(graph, *op, "adding control input");
}

void TFC_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Buffer* attr_value_proto, TF_Status* status) {
using tensorflow::RecordMutation;
tensorflow::AttrValue attr_val;
if (!attr_val.ParseFromArray(attr_value_proto->data,
attr_value_proto->length)) {
status->status =
tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
return;
}

mutex_lock l(graph->mu);
op->node.AddAttr(attr_name, attr_val);
tensorflow::RecordMutation(graph, *op, "setting attribute");
}

void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
TF_Status* status) {
mutex_lock l(graph->mu);
op->node.ClearAttr(attr_name);
tensorflow::RecordMutation(graph, *op, "clearing attribute");
}

void TFC_SetFullType(TF_Graph* graph, TF_Operation* op,
const tensorflow::FullTypeDef& full_type) {
mutex_lock l(graph->mu);
*op->node.mutable_def()->mutable_experimental_type() = full_type;
tensorflow::RecordMutation(graph, *op, "setting fulltype");
}

void TFC_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
mutex_lock l(graph->mu);
op->node.set_requested_device(device);
tensorflow::RecordMutation(graph, *op, "setting device");
}

void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) {
TF_UpdateEdge(graph, new_src, dst, status);
}

void TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
mutex_lock l(graph->mu);
std::vector<const tensorflow::Edge*> control_edges;
for (const tensorflow::Edge* edge : op->node.in_edges()) {
if (!edge->IsControlEdge()) continue;
control_edges.push_back(edge);
}
for (const tensorflow::Edge* edge : control_edges) {
graph->graph.RemoveControlEdge(edge);
}
}

void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
mutex_lock l(graph->mu);
graph->refiner.set_require_shape_inference_fns(require);
}

void TFC_ExtendSession(TF_Session* session, TF_Status* status) {
ExtendSessionGraphHelper(session, status);
session->extend_before_run = false;
}

TF_Buffer* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node *node = &output.oper->node;
tensorflow::CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
{
mutex_lock l(graph->mu);
tensorflow::shape_inference::InferenceContext *ic =
graph->refiner.GetContext(node);
CHECK(ic != nullptr);
CHECK_LT(output.index, ic->num_outputs());
const auto *shapes_and_types =
ic->output_handle_shapes_and_types(output.index);
if (shapes_and_types == nullptr)
return nullptr;

for (const auto &p : *shapes_and_types) {
auto *out_shape_and_type = handle_data.add_shape_and_type();
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
out_shape_and_type->set_dtype(p.dtype);
*out_shape_and_type->mutable_type() = p.type;
}
}
string str_data;
handle_data.SerializeToString(&str_data);

TF_Buffer *result = TF_NewBufferFromString(str_data.c_str(), str_data.size());
return result;
}

void TFC_SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
size_t proto_len, TF_Status* status) {
tensorflow::CppShapeInferenceResult::HandleData handle_data;
if (!handle_data.ParseFromArray(proto, proto_len)) {
status->status = tensorflow::errors::InvalidArgument(
"Couldn't deserialize HandleData proto");
return;
}
DCHECK(handle_data.is_set());

tensorflow::mutex_lock l(graph->mu);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&output.oper->node);

std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
tensorflow::shape_inference::ShapeHandle shape;
status->status =
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
if (TF_GetCode(status) != TF_OK) return;
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
shape_and_type_proto.type());
}
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}

void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
TF_Status* status) {
mutex_lock l(graph->mu);
status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
new_src.index, &dst->node);
if (TF_GetCode(status) == TF_OK) {
// This modification only updates the destination node for
// the purposes of running this graph in a session. Thus, we don't
// record the source node as being modified.
tensorflow::RecordMutation(graph, *dst, "adding input tensor");
}
}

// -------------------------------------------------------------------

// TF_Server functions ----------------------------------------------

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
Expand Down
39 changes: 39 additions & 0 deletions tensorflow/c/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_tstring.h"
#include "tensorflow/core/framework/full_type.pb.h"

// --------------------------------------------------------------------------
// C API for TensorFlow.
Expand Down Expand Up @@ -1620,6 +1621,44 @@ TF_CAPI_EXPORT extern void TF_RegisterLogListener(
TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin(
const char* plugin_filename, TF_Status* status);

TF_CAPI_EXPORT extern void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);

TF_CAPI_EXPORT extern void TFC_SetAttr(TF_Graph* graph, TF_Operation* op,
const char* attr_name,
TF_Buffer* attr_value_proto,
TF_Status* status);

TF_CAPI_EXPORT extern void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op,
const char* attr_name,
TF_Status* status);

TF_CAPI_EXPORT extern void TFC_SetFullType(TF_Graph* graph, TF_Operation* op,
const tensorflow::FullTypeDef& full_type);

TF_CAPI_EXPORT extern void TFC_SetRequestedDevice(TF_Graph* graph,
TF_Operation* op,
const char* device);

TF_CAPI_EXPORT extern void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src,
TF_Input dst, TF_Status* status);

TF_CAPI_EXPORT extern void TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);

TF_CAPI_EXPORT extern void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require);

TF_CAPI_EXPORT extern void TFC_ExtendSession(TF_Session* session, TF_Status* status);

TF_CAPI_EXPORT extern TF_Buffer* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output);

TF_CAPI_EXPORT extern void TFC_SetHandleShapeAndType(TF_Graph* graph,
TF_Output output,
const void* proto,
size_t proto_len,
TF_Status* status);

TF_CAPI_EXPORT extern void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
TF_Status* status);

#ifdef __cplusplus
} /* end extern "C" */
#endif
Expand Down
1 change: 1 addition & 0 deletions tensorflow/c/version_script.lds
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ VERS_1.0 {
global:
*TF_*;
*TFE_*;
*TFC_*;

# Hide everything else.
local:
Expand Down
18 changes: 16 additions & 2 deletions tensorflow/core/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ exports_files(
srcs = [
"allocation_description.proto",
"api_def.proto",
"cpp_shape_inference.proto",
"attr_value.proto",
"cost_graph.proto",
"dataset_metadata.proto",
Expand Down Expand Up @@ -1402,7 +1403,7 @@ cc_library(
# protos from the same package, so we can build the protos here and then
# link them from core:protos_all without circular dependencies.

# Generate the C++ sources for some of the protos.
#Generate the C++ sources for some of the protos.
tf_generate_proto_text_sources(
name = "attr_value_proto_text",
srcs = [
Expand Down Expand Up @@ -1693,6 +1694,18 @@ tf_proto_library(
],
)

tf_proto_library(
name = "cpp_shape_inference_proto",
srcs = ["cpp_shape_inference.proto"],
cc_api_version = 2,
make_default_target_header_only = True,
protodeps = [
":full_type_proto",
":tensor_shape_proto",
":types_proto",
],
)

tf_proto_library(
name = "variable_proto",
srcs = ["variable.proto"],
Expand Down Expand Up @@ -1760,7 +1773,7 @@ tf_proto_library(
# ":function_proto",
# ],
# )
# copybara:uncomment_end
#copybara : uncomment_end

tf_proto_library(
name = "summary_proto",
Expand Down Expand Up @@ -1806,6 +1819,7 @@ tf_proto_library(
protodeps = [
":allocation_description_proto",
":api_def_proto",
":cpp_shape_inference_proto",
":attr_value_proto",
":cost_graph_proto",
":dataset_metadata_proto",
Expand Down
36 changes: 36 additions & 0 deletions tensorflow/core/framework/cpp_shape_inference.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
syntax = "proto3";

package tensorflow;

import "tensorflow/core/framework/full_type.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";

option cc_enable_arenas = true;
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto";

message CppShapeInferenceResult {
message HandleShapeAndType {
reserved 3;

TensorShapeProto shape = 1;
DataType dtype = 2;
FullTypeDef type = 4;
}
message HandleData {
bool is_set = 1;

// Only valid if <is_set>.
repeated HandleShapeAndType shape_and_type = 2;
}
TensorShapeProto shape = 1;

reserved 2; // was handle_shape
reserved 3; // was handle_dtype
HandleData handle_data = 4;
}

message CppShapeInferenceInputsNeeded {
repeated int32 input_tensors_needed = 1;
repeated int32 input_tensors_as_shapes_needed = 2;
}
2 changes: 1 addition & 1 deletion tensorflow/tools/def_file_filter/def_file_filter.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def main():
# Each symbols returned by undname matches the same position in candidates.
# We compare on undname but use the decorated name from candidates.
dupes = 0
proc = subprocess.Popen([UNDNAME, tmpfile.name], stdout=subprocess.PIPE)
proc = subprocess.Popen([UNDNAME, tmpfile.name], stdout=subprocess.PIPE, shell=True)
for idx, line in enumerate(io.TextIOWrapper(proc.stdout, encoding="utf-8")):
decorated = candidates[idx]
if decorated in taken:
Expand Down