Skip to content

Commit de39f73

Browse files
committed
Add customized C APIs for tf.net.
1 parent fdfc646 commit de39f73

File tree

4 files changed

+233
-2
lines changed

4 files changed

+233
-2
lines changed

tensorflow/c/c_api.cc

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,12 @@ limitations under the License.
5050
#include "tensorflow/core/framework/partial_tensor_shape.h"
5151
#include "tensorflow/core/framework/tensor.h"
5252
#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
53+
#include "tensorflow/core/framework/cpp_shape_inference.pb.h"
5354
#include "tensorflow/core/framework/tensor_shape.h"
5455
#include "tensorflow/core/framework/tensor_shape.pb.h"
5556
#include "tensorflow/core/framework/types.h"
5657
#include "tensorflow/core/framework/versions.pb.h"
58+
#include "tensorflow/core/framework/shape_inference.h"
5759
#include "tensorflow/core/graph/graph.h"
5860
#include "tensorflow/core/graph/node_builder.h"
5961
#include "tensorflow/core/graph/validate.h"
@@ -71,6 +73,8 @@ limitations under the License.
7173
#include "tensorflow/core/platform/types.h"
7274
#include "tensorflow/core/public/session.h"
7375
#include "tensorflow/core/public/version.h"
76+
#include "tensorflow/core/framework/full_type.pb.h"
77+
#include "tensorflow/core/framework/attr_value_util.h"
7478

7579
// The implementation below is at the top level instead of the
7680
// brain namespace because we are defining 'extern "C"' functions.
@@ -2614,6 +2618,144 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
26142618
}
26152619
}
26162620

2621+
// TF Customized C APIs for Tensorflow.NET --------------------------
2622+
2623+
void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
2624+
mutex_lock l(graph->mu);
2625+
graph->graph.AddControlEdge(&input->node, &op->node);
2626+
tensorflow::RecordMutation(graph, *op, "adding control input");
2627+
}
2628+
2629+
void TFC_SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
2630+
TF_Buffer* attr_value_proto, TF_Status* status) {
2631+
using tensorflow::RecordMutation;
2632+
tensorflow::AttrValue attr_val;
2633+
if (!attr_val.ParseFromArray(attr_value_proto->data,
2634+
attr_value_proto->length)) {
2635+
status->status =
2636+
tensorflow::errors::InvalidArgument("Invalid AttrValue proto");
2637+
return;
2638+
}
2639+
2640+
mutex_lock l(graph->mu);
2641+
op->node.AddAttr(attr_name, attr_val);
2642+
tensorflow::RecordMutation(graph, *op, "setting attribute");
2643+
}
2644+
2645+
void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
2646+
TF_Status* status) {
2647+
mutex_lock l(graph->mu);
2648+
op->node.ClearAttr(attr_name);
2649+
tensorflow::RecordMutation(graph, *op, "clearing attribute");
2650+
}
2651+
2652+
void TFC_SetFullType(TF_Graph* graph, TF_Operation* op,
2653+
const tensorflow::FullTypeDef& full_type) {
2654+
mutex_lock l(graph->mu);
2655+
*op->node.mutable_def()->mutable_experimental_type() = full_type;
2656+
tensorflow::RecordMutation(graph, *op, "setting fulltype");
2657+
}
2658+
2659+
void TFC_SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
2660+
mutex_lock l(graph->mu);
2661+
op->node.set_requested_device(device);
2662+
tensorflow::RecordMutation(graph, *op, "setting device");
2663+
}
2664+
2665+
void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
2666+
TF_Status* status) {
2667+
TF_UpdateEdge(graph, new_src, dst, status);
2668+
}
2669+
2670+
void TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) {
2671+
mutex_lock l(graph->mu);
2672+
std::vector<const tensorflow::Edge*> control_edges;
2673+
for (const tensorflow::Edge* edge : op->node.in_edges()) {
2674+
if (!edge->IsControlEdge()) continue;
2675+
control_edges.push_back(edge);
2676+
}
2677+
for (const tensorflow::Edge* edge : control_edges) {
2678+
graph->graph.RemoveControlEdge(edge);
2679+
}
2680+
}
2681+
2682+
void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require) {
2683+
mutex_lock l(graph->mu);
2684+
graph->refiner.set_require_shape_inference_fns(require);
2685+
}
2686+
2687+
void TFC_ExtendSession(TF_Session* session, TF_Status* status) {
2688+
ExtendSessionGraphHelper(session, status);
2689+
session->extend_before_run = false;
2690+
}
2691+
2692+
const char* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output) {
2693+
Node* node = &output.oper->node;
2694+
tensorflow::CppShapeInferenceResult::HandleData handle_data;
2695+
handle_data.set_is_set(true);
2696+
{
2697+
mutex_lock l(graph->mu);
2698+
tensorflow::shape_inference::InferenceContext* ic =
2699+
graph->refiner.GetContext(node);
2700+
CHECK(ic != nullptr);
2701+
CHECK_LT(output.index, ic->num_outputs());
2702+
const auto* shapes_and_types =
2703+
ic->output_handle_shapes_and_types(output.index);
2704+
if (shapes_and_types == nullptr) return "";
2705+
2706+
for (const auto& p : *shapes_and_types) {
2707+
auto* out_shape_and_type = handle_data.add_shape_and_type();
2708+
ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
2709+
out_shape_and_type->set_dtype(p.dtype);
2710+
*out_shape_and_type->mutable_type() = p.type;
2711+
}
2712+
}
2713+
string result;
2714+
handle_data.SerializeToString(&result);
2715+
return result.c_str();
2716+
}
2717+
2718+
void TFC_SetHandleShapeAndType(TF_Graph* graph, TF_Output output, const void* proto,
2719+
size_t proto_len, TF_Status* status) {
2720+
tensorflow::CppShapeInferenceResult::HandleData handle_data;
2721+
if (!handle_data.ParseFromArray(proto, proto_len)) {
2722+
status->status = tensorflow::errors::InvalidArgument(
2723+
"Couldn't deserialize HandleData proto");
2724+
return;
2725+
}
2726+
DCHECK(handle_data.is_set());
2727+
2728+
tensorflow::mutex_lock l(graph->mu);
2729+
tensorflow::shape_inference::InferenceContext* ic =
2730+
graph->refiner.GetContext(&output.oper->node);
2731+
2732+
std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
2733+
for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
2734+
tensorflow::shape_inference::ShapeHandle shape;
2735+
status->status =
2736+
ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
2737+
if (TF_GetCode(status) != TF_OK) return;
2738+
shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype(),
2739+
shape_and_type_proto.type());
2740+
}
2741+
ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
2742+
}
2743+
2744+
void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
2745+
TF_Status* status) {
2746+
mutex_lock l(graph->mu);
2747+
status->status = graph->graph.AddWhileInputHack(&new_src.oper->node,
2748+
new_src.index, &dst->node);
2749+
if (TF_GetCode(status) == TF_OK) {
2750+
// This modification only updates the destination node for
2751+
// the purposes of running this graph in a session. Thus, we don't
2752+
// record the source node as being modified.
2753+
tensorflow::RecordMutation(graph, *dst, "adding input tensor");
2754+
}
2755+
}
2756+
2757+
// -------------------------------------------------------------------
2758+
26172759
// TF_Server functions ----------------------------------------------
26182760

26192761
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

tensorflow/c/c_api.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include "tensorflow/c/tf_status.h"
2525
#include "tensorflow/c/tf_tensor.h"
2626
#include "tensorflow/c/tf_tstring.h"
27+
#include "tensorflow/core/framework/full_type.pb.h"
2728

2829
// --------------------------------------------------------------------------
2930
// C API for TensorFlow.
@@ -1620,6 +1621,44 @@ TF_CAPI_EXPORT extern void TF_RegisterLogListener(
16201621
TF_CAPI_EXPORT extern void TF_RegisterFilesystemPlugin(
16211622
const char* plugin_filename, TF_Status* status);
16221623

1624+
TF_CAPI_EXPORT extern void TFC_AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input);
1625+
1626+
TF_CAPI_EXPORT extern void TFC_SetAttr(TF_Graph* graph, TF_Operation* op,
1627+
const char* attr_name,
1628+
TF_Buffer* attr_value_proto,
1629+
TF_Status* status);
1630+
1631+
TF_CAPI_EXPORT extern void TFC_ClearAttr(TF_Graph* graph, TF_Operation* op,
1632+
const char* attr_name,
1633+
TF_Status* status);
1634+
1635+
TF_CAPI_EXPORT extern void TFC_SetFullType(TF_Graph* graph, TF_Operation* op,
1636+
const tensorflow::FullTypeDef& full_type);
1637+
1638+
TF_CAPI_EXPORT extern void TFC_SetRequestedDevice(TF_Graph* graph,
1639+
TF_Operation* op,
1640+
const char* device);
1641+
1642+
TF_CAPI_EXPORT extern void TFC_UpdateEdge(TF_Graph* graph, TF_Output new_src,
1643+
TF_Input dst, TF_Status* status);
1644+
1645+
TF_CAPI_EXPORT extern void TFC_RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op);
1646+
1647+
TF_CAPI_EXPORT extern void TFC_SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
1648+
1649+
TF_CAPI_EXPORT extern void TFC_ExtendSession(TF_Session* session, TF_Status* status);
1650+
1651+
TF_CAPI_EXPORT extern const char* TFC_GetHandleShapeAndType(TF_Graph* graph, TF_Output output);
1652+
1653+
TF_CAPI_EXPORT extern void TFC_SetHandleShapeAndType(TF_Graph* graph,
1654+
TF_Output output,
1655+
const void* proto,
1656+
size_t proto_len,
1657+
TF_Status* status);
1658+
1659+
TF_CAPI_EXPORT extern void TFC_AddWhileInputHack(TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
1660+
TF_Status* status);
1661+
16231662
#ifdef __cplusplus
16241663
} /* end extern "C" */
16251664
#endif

tensorflow/core/framework/BUILD

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ exports_files(
116116
srcs = [
117117
"allocation_description.proto",
118118
"api_def.proto",
119+
"cpp_shape_inference.proto",
119120
"attr_value.proto",
120121
"cost_graph.proto",
121122
"dataset_metadata.proto",
@@ -1402,7 +1403,7 @@ cc_library(
14021403
# protos from the same package, so we can build the protos here and then
14031404
# link them from core:protos_all without circular dependencies.
14041405

1405-
# Generate the C++ sources for some of the protos.
1406+
#Generate the C++ sources for some of the protos.
14061407
tf_generate_proto_text_sources(
14071408
name = "attr_value_proto_text",
14081409
srcs = [
@@ -1693,6 +1694,18 @@ tf_proto_library(
16931694
],
16941695
)
16951696

1697+
tf_proto_library(
1698+
name = "cpp_shape_inference_proto",
1699+
srcs = ["cpp_shape_inference.proto"],
1700+
cc_api_version = 2,
1701+
make_default_target_header_only = True,
1702+
protodeps = [
1703+
":full_type_proto",
1704+
":tensor_shape_proto",
1705+
":types_proto",
1706+
],
1707+
)
1708+
16961709
tf_proto_library(
16971710
name = "variable_proto",
16981711
srcs = ["variable.proto"],
@@ -1760,7 +1773,7 @@ tf_proto_library(
17601773
# ":function_proto",
17611774
# ],
17621775
# )
1763-
# copybara:uncomment_end
1776+
#copybara : uncomment_end
17641777

17651778
tf_proto_library(
17661779
name = "summary_proto",
@@ -1806,6 +1819,7 @@ tf_proto_library(
18061819
protodeps = [
18071820
":allocation_description_proto",
18081821
":api_def_proto",
1822+
":cpp_shape_inference_proto",
18091823
":attr_value_proto",
18101824
":cost_graph_proto",
18111825
":dataset_metadata_proto",
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
syntax = "proto3";
2+
3+
package tensorflow;
4+
5+
import "tensorflow/core/framework/full_type.proto";
6+
import "tensorflow/core/framework/tensor_shape.proto";
7+
import "tensorflow/core/framework/types.proto";
8+
9+
option cc_enable_arenas = true;
10+
option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/python/framework/cpp_shape_inference_go_proto";
11+
12+
message CppShapeInferenceResult {
13+
message HandleShapeAndType {
14+
reserved 3;
15+
16+
TensorShapeProto shape = 1;
17+
DataType dtype = 2;
18+
FullTypeDef type = 4;
19+
}
20+
message HandleData {
21+
bool is_set = 1;
22+
23+
// Only valid if <is_set>.
24+
repeated HandleShapeAndType shape_and_type = 2;
25+
}
26+
TensorShapeProto shape = 1;
27+
28+
reserved 2; // was handle_shape
29+
reserved 3; // was handle_dtype
30+
HandleData handle_data = 4;
31+
}
32+
33+
message CppShapeInferenceInputsNeeded {
34+
repeated int32 input_tensors_needed = 1;
35+
repeated int32 input_tensors_as_shapes_needed = 2;
36+
}

0 commit comments

Comments
 (0)