Add customized C APIs for tf.net.
AsakusaRinne committed Apr 17, 2023
1 parent fdfc646 commit de39f73
Showing 4 changed files with 233 additions and 2 deletions.
142 changes: 142 additions & 0 deletions tensorflow/c/c_api.cc
@@ -50,10 +50,12 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/cpp_shape_inference.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
@@ -71,6 +73,8 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"

// The implementation below is at the top level instead of the
// brain namespace because we are defining 'extern "C"' functions.
@@ -2614,6 +2618,144 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
  }
}

// Customized C APIs for TensorFlow.NET ------------------------------------

void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op,
                         TF_Operation* input) {
  mutex_lock l(graph->mu);
  graph->graph.AddControlEdge(&input->node, &op->node);
  tensorflow::RecordMutation(graph, *op, "adding control input");
}

void TFC_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
                 TF_Buffer* attr_value_proto, TF_Status* status) {
  tensorflow::AttrValue attr_val;
  if (!attr_val.ParseFromArray(attr_value_proto->data,
                               attr_value_proto->length)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
    return;
  }

  mutex_lock l(graph->mu);
  op->node.AddAttr(attr_name, attr_val);
  tensorflow::RecordMutation(graph, *op, "setting attribute");
}

void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
                   TF_Status* status) {
  mutex_lock l(graph->mu);
  op->node.ClearAttr(attr_name);
  tensorflow::RecordMutation(graph, *op, "clearing attribute");
}

void TFC_SetFullType(TF_Graph* graph, TF_Operation* op,
                     const tensorflow::FullTypeDef& full_type) {
  mutex_lock l(graph->mu);
  *op->node.mutable_def()->mutable_experimental_type() = full_type;
  tensorflow::RecordMutation(graph, *op, "setting fulltype");
}

void TFC_SetRequestedDevice(TF_Graph* graph, TF_Operation* op,
                            const char* device) {
  mutex_lock l(graph->mu);
  op->node.set_requested_device(device);
  tensorflow::RecordMutation(graph, *op, "setting device");
}

void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
                    TF_Status* status) {
  TF_UpdateEdge(graph, new_src, dst, status);
}

void TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
  mutex_lock l(graph->mu);
  std::vector<const tensorflow::Edge*> control_edges;
  for (const tensorflow::Edge* edge : op->node.in_edges()) {
    if (!edge->IsControlEdge()) continue;
    control_edges.push_back(edge);
  }
  for (const tensorflow::Edge* edge : control_edges) {
    graph->graph.RemoveControlEdge(edge);
  }
}

void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
  mutex_lock l(graph->mu);
  graph->refiner.set_require_shape_inference_fns(require);
}

void TFC_ExtendSession(TF_Session* session, TF_Status* status) {
  ExtendSessionGraphHelper(session, status);
  session->extend_before_run = false;
}

const char* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
  Node* node = &output.oper->node;
  tensorflow::CppShapeInferenceResult::HandleData handle_data;
  handle_data.set_is_set(true);
  {
    mutex_lock l(graph->mu);
    tensorflow::shape_inference::InferenceContext* ic =
        graph->refiner.GetContext(node);
    CHECK(ic != nullptr);
    CHECK_LT(output.index, ic->num_outputs());
    const auto* shapes_and_types =
        ic->output_handle_shapes_and_types(output.index);
    if (shapes_and_types == nullptr) return "";

    for (const auto& p : *shapes_and_types) {
      auto* out_shape_and_type = handle_data.add_shape_and_type();
      ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
      out_shape_and_type->set_dtype(p.dtype);
      *out_shape_and_type->mutable_type() = p.type;
    }
  }
  // Serialize into thread-local storage so the returned pointer stays valid
  // after this function returns; a pointer into a local string would dangle.
  // The result is only valid until the next call on the same thread.
  static thread_local string result;
  handle_data.SerializeToString(&result);
  return result.c_str();
}

void TFC_SetHandleShapeAndType(TF_Graph* graph, TF_Output output,
                               const void* proto, size_t proto_len,
                               TF_Status* status) {
  tensorflow::CppShapeInferenceResult::HandleData handle_data;
  if (!handle_data.ParseFromArray(proto, proto_len)) {
    status->status = tensorflow::errors::InvalidArgument(
        "Couldn't deserialize HandleData proto");
    return;
  }
  DCHECK(handle_data.is_set());

  tensorflow::mutex_lock l(graph->mu);
  tensorflow::shape_inference::InferenceContext* ic =
      graph->refiner.GetContext(&output.oper->node);

  std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
  for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
    tensorflow::shape_inference::ShapeHandle shape;
    status->status =
        ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
    if (TF_GetCode(status) != TF_OK) return;
    shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
                                  shape_and_type_proto.type());
  }
  ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
}

void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src,
                           TF_Operation* dst, TF_Status* status) {
  mutex_lock l(graph->mu);
  status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
                                                  new_src.index, &dst->node);
  if (TF_GetCode(status) == TF_OK) {
    // This modification only updates the destination node for
    // the purposes of running this graph in a session. Thus, we don't
    // record the source node as being modified.
    tensorflow::RecordMutation(graph, *dst, "adding input tensor");
  }
}

// -------------------------------------------------------------------

// TF_Server functions ----------------------------------------------

#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
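To make the intended call pattern easier to follow, here is a minimal usage sketch in plain C/C++. It is not part of the commit; it assumes a TensorFlow build that exports the TFC_* symbols added above, and the MakePlaceholder helper plus the chosen attribute and device names are purely illustrative.

// Illustrative only: exercises TFC_AddControlInput, TFC_SetRequestedDevice and
// TFC_ClearAttr from a plain C/C++ client of the TensorFlow C API.
#include <stdio.h>

#include "tensorflow/c/c_api.h"

static TF_Operation* MakePlaceholder(TF_Graph* graph, const char* name,
                                     TF_Status* status) {
  TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
  TF_SetAttrType(desc, "dtype", TF_FLOAT);
  return TF_FinishOperation(desc, status);
}

int main() {
  TF_Status* status = TF_NewStatus();
  TF_Graph* graph = TF_NewGraph();

  TF_Operation* a = MakePlaceholder(graph, "a", status);
  TF_Operation* b = MakePlaceholder(graph, "b", status);

  // Make "b" wait for "a" by adding a control edge, then pin "b" to the CPU.
  TFC_AddControlInput(graph, b, a);
  TFC_SetRequestedDevice(graph, b, "/device:CPU:0");

  // Drop the optional "shape" attribute of "b" again, if it was ever set.
  TFC_ClearAttr(graph, b, "shape", status);

  printf("last status: %s\n", TF_Message(status));
  TF_DeleteGraph(graph);
  TF_DeleteStatus(status);
  return 0;
}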
39 changes: 39 additions & 0 deletions tensorflow/c/c_api.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/c/tf_tstring.h"
#include "tensorflow/core/framework/full_type.pb.h"

// --------------------------------------------------------------------------
// C API for TensorFlow.
@@ -1620,6 +1621,44 @@ TF_CAPI_EXPORT extern void TF_RegisterLogListener(
TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin(
const char* plugin_filename, TF_Status* status);

// Customized C APIs for TensorFlow.NET ------------------------------------

// Adds a control dependency on `input` to `op`.
TF_CAPI_EXPORT extern void TFC_AddControlInput(TF_Graph* graph,
                                               TF_Operation* op,
                                               TF_Operation* input);

// Sets the attribute `attr_name` of `op` from a serialized AttrValue proto.
TF_CAPI_EXPORT extern void TFC_SetAttr(TF_Graph* graph, TF_Operation* op,
                                       const char* attr_name,
                                       TF_Buffer* attr_value_proto,
                                       TF_Status* status);

// Removes the attribute `attr_name` from `op`, if present.
TF_CAPI_EXPORT extern void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op,
                                         const char* attr_name,
                                         TF_Status* status);

// Sets the experimental full type information of `op`.
TF_CAPI_EXPORT extern void TFC_SetFullType(
    TF_Graph* graph, TF_Operation* op,
    const tensorflow::FullTypeDef& full_type);

// Sets the requested device of `op`.
TF_CAPI_EXPORT extern void TFC_SetRequestedDevice(TF_Graph* graph,
                                                  TF_Operation* op,
                                                  const char* device);

// Same behavior as TF_UpdateEdge, re-exported under the TFC_ prefix.
TF_CAPI_EXPORT extern void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src,
                                          TF_Input dst, TF_Status* status);

// Removes all incoming control edges of `op`.
TF_CAPI_EXPORT extern void TFC_RemoveAllControlInputs(TF_Graph* graph,
                                                      TF_Operation* op);

// Controls whether the graph's shape refiner requires every op to have a
// registered shape inference function.
TF_CAPI_EXPORT extern void TFC_SetRequireShapeInferenceFns(TF_Graph* graph,
                                                           bool require);

// Extends the session with any graph changes made since the last run and
// clears the extend-before-run flag.
TF_CAPI_EXPORT extern void TFC_ExtendSession(TF_Session* session,
                                             TF_Status* status);

// Returns a serialized CppShapeInferenceResult.HandleData proto describing the
// handle shapes/dtypes of `output`, or an empty string if none are set.
TF_CAPI_EXPORT extern const char* TFC_GetHandleShapeAndType(TF_Graph* graph,
                                                            TF_Output output);

// Sets the handle shapes/dtypes of `output` from a serialized
// CppShapeInferenceResult.HandleData proto.
TF_CAPI_EXPORT extern void TFC_SetHandleShapeAndType(TF_Graph* graph,
                                                     TF_Output output,
                                                     const void* proto,
                                                     size_t proto_len,
                                                     TF_Status* status);

// Adds `new_src` as an input of the already-finished node `dst`. Only intended
// for use while constructing while loops.
TF_CAPI_EXPORT extern void TFC_AddWhileInputHack(TF_Graph* graph,
                                                 TF_Output new_src,
                                                 TF_Operation* dst,
                                                 TF_Status* status);

#ifdef __cplusplus
} /* end extern "C" */
#endif
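The TFC_SetAttr declaration above expects the attribute value as a serialized AttrValue proto wrapped in a TF_Buffer. The helper below is a hedged sketch of that plumbing, not part of the commit; SetIntAttr is a hypothetical name.

#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/attr_value.pb.h"

// Sets an integer-valued attribute on `op` by serializing an AttrValue proto
// and handing it to TFC_SetAttr. TF_NewBufferFromString copies the bytes, so
// the temporary string can go out of scope afterwards.
void SetIntAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
                long long value, TF_Status* status) {
  tensorflow::AttrValue attr;
  attr.set_i(value);

  std::string serialized;
  attr.SerializeToString(&serialized);

  TF_Buffer* buf = TF_NewBufferFromString(serialized.data(), serialized.size());
  TFC_SetAttr(graph, op, attr_name, buf, status);
  TF_DeleteBuffer(buf);
}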
18 changes: 16 additions & 2 deletions tensorflow/core/framework/BUILD
@@ -116,6 +116,7 @@ exports_files(
    srcs = [
        "allocation_description.proto",
        "api_def.proto",
        "attr_value.proto",
        "cost_graph.proto",
        "cpp_shape_inference.proto",
        "dataset_metadata.proto",
@@ -1402,7 +1403,7 @@ cc_library(
# protos from the same package, so we can build the protos here and then
# link them from core:protos_all without circular dependencies.

-# Generate the C++ sources for some of the protos.
+#Generate the C++ sources for some of the protos.
tf_generate_proto_text_sources(
    name = "attr_value_proto_text",
    srcs = [
@@ -1693,6 +1694,18 @@ tf_proto_library(
    ],
)

tf_proto_library(
    name = "cpp_shape_inference_proto",
    srcs = ["cpp_shape_inference.proto"],
    cc_api_version = 2,
    make_default_target_header_only = True,
    protodeps = [
        ":full_type_proto",
        ":tensor_shape_proto",
        ":types_proto",
    ],
)

tf_proto_library(
    name = "variable_proto",
    srcs = ["variable.proto"],
@@ -1760,7 +1773,7 @@ tf_proto_library(
# ":function_proto",
# ],
# )
-# copybara:uncomment_end
+#copybara : uncomment_end

tf_proto_library(
    name = "summary_proto",
@@ -1806,6 +1819,7 @@ tf_proto_library(
    protodeps = [
        ":allocation_description_proto",
        ":api_def_proto",
        ":attr_value_proto",
        ":cost_graph_proto",
        ":cpp_shape_inference_proto",
        ":dataset_metadata_proto",
36 changes: 36 additions & 0 deletions tensorflow/core/framework/cpp_shape_inference.proto
@@ -0,0 +1,36 @@
syntax = "proto3";

package tensorflow;

import "tensorflow/core/framework/full_type.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";

option cc_enable_arenas = true;
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto";

message CppShapeInferenceResult {
  message HandleShapeAndType {
    reserved 3;

    TensorShapeProto shape = 1;
    DataType dtype = 2;
    FullTypeDef type = 4;
  }
  message HandleData {
    bool is_set = 1;

    // Only valid if <is_set>.
    repeated HandleShapeAndType shape_and_type = 2;
  }
  TensorShapeProto shape = 1;

  reserved 2;  // was handle_shape
  reserved 3;  // was handle_dtype
  HandleData handle_data = 4;
}

message CppShapeInferenceInputsNeeded {
  repeated int32 input_tensors_needed = 1;
  repeated int32 input_tensors_as_shapes_needed = 2;
}
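As a hedged illustration of how this message is meant to be used together with TFC_SetHandleShapeAndType, the sketch below (not part of the commit) builds a HandleData proto in C++ and attaches it to a resource-typed output. AnnotateResourceOutput and the float [2, 3] shape are invented for the example, and the include path assumes the header generated from this proto.

#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/cpp_shape_inference.pb.h"
#include "tensorflow/core/framework/types.pb.h"

// Records that the resource handle produced by `out` refers to a float tensor
// of shape [2, 3], using the serialized HandleData wire format expected by
// TFC_SetHandleShapeAndType.
void AnnotateResourceOutput(TF_Graph* graph, TF_Output out, TF_Status* status) {
  tensorflow::CppShapeInferenceResult::HandleData handle_data;
  handle_data.set_is_set(true);

  auto* entry = handle_data.add_shape_and_type();
  entry->set_dtype(tensorflow::DT_FLOAT);
  entry->mutable_shape()->add_dim()->set_size(2);
  entry->mutable_shape()->add_dim()->set_size(3);

  std::string serialized;
  handle_data.SerializeToString(&serialized);
  TFC_SetHandleShapeAndType(graph, out, serialized.data(), serialized.size(),
                            status);
}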
