@@ -50,10 +50,12 @@ limitations under the License.
50
50
#include " tensorflow/core/framework/partial_tensor_shape.h"
51
51
#include " tensorflow/core/framework/tensor.h"
52
52
#include " tensorflow/core/framework/tensor.pb.h" // NOLINT
53
+ #include " tensorflow/core/framework/cpp_shape_inference.pb.h"
53
54
#include " tensorflow/core/framework/tensor_shape.h"
54
55
#include " tensorflow/core/framework/tensor_shape.pb.h"
55
56
#include " tensorflow/core/framework/types.h"
56
57
#include " tensorflow/core/framework/versions.pb.h"
58
+ #include " tensorflow/core/framework/shape_inference.h"
57
59
#include " tensorflow/core/graph/graph.h"
58
60
#include " tensorflow/core/graph/node_builder.h"
59
61
#include " tensorflow/core/graph/validate.h"
@@ -71,6 +73,8 @@ limitations under the License.
71
73
#include " tensorflow/core/platform/types.h"
72
74
#include " tensorflow/core/public/session.h"
73
75
#include " tensorflow/core/public/version.h"
76
+ #include " tensorflow/core/framework/full_type.pb.h"
77
+ #include " tensorflow/core/framework/attr_value_util.h"
74
78
75
79
// The implementation below is at the top level instead of the
76
80
// brain namespace because we are defining 'extern "C"' functions.
@@ -2614,6 +2618,144 @@ void TF_UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
2614
2618
}
2615
2619
}
2616
2620
2621
+ // TF Customized C APIs for Tensorflow.NET --------------------------
2622
+
2623
+ void TFC_AddControlInput (TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
2624
+ mutex_lock l (graph->mu );
2625
+ graph->graph .AddControlEdge (&input->node , &op->node );
2626
+ tensorflow::RecordMutation (graph, *op, " adding control input" );
2627
+ }
2628
+
2629
+ void TFC_SetAttr (TF_Graph* graph, TF_Operation* op, const char * attr_name,
2630
+ TF_Buffer* attr_value_proto, TF_Status* status) {
2631
+ using tensorflow::RecordMutation;
2632
+ tensorflow::AttrValue attr_val;
2633
+ if (!attr_val.ParseFromArray (attr_value_proto->data ,
2634
+ attr_value_proto->length )) {
2635
+ status->status =
2636
+ tensorflow::errors::InvalidArgument (" Invalid AttrValue proto" );
2637
+ return ;
2638
+ }
2639
+
2640
+ mutex_lock l (graph->mu );
2641
+ op->node .AddAttr (attr_name, attr_val);
2642
+ tensorflow::RecordMutation (graph, *op, " setting attribute" );
2643
+ }
2644
+
2645
+ void TFC_ClearAttr (TF_Graph* graph, TF_Operation* op, const char * attr_name,
2646
+ TF_Status* status) {
2647
+ mutex_lock l (graph->mu );
2648
+ op->node .ClearAttr (attr_name);
2649
+ tensorflow::RecordMutation (graph, *op, " clearing attribute" );
2650
+ }
2651
+
2652
+ void TFC_SetFullType (TF_Graph* graph, TF_Operation* op,
2653
+ const tensorflow::FullTypeDef& full_type) {
2654
+ mutex_lock l (graph->mu );
2655
+ *op->node .mutable_def ()->mutable_experimental_type () = full_type;
2656
+ tensorflow::RecordMutation (graph, *op, " setting fulltype" );
2657
+ }
2658
+
2659
+ void TFC_SetRequestedDevice (TF_Graph* graph, TF_Operation* op, const char * device) {
2660
+ mutex_lock l (graph->mu );
2661
+ op->node .set_requested_device (device);
2662
+ tensorflow::RecordMutation (graph, *op, " setting device" );
2663
+ }
2664
+
2665
+ void TFC_UpdateEdge (TF_Graph* graph, TF_Output new_src, TF_Input dst,
2666
+ TF_Status* status) {
2667
+ TF_UpdateEdge (graph, new_src, dst, status);
2668
+ }
2669
+
2670
+ void TFC_RemoveAllControlInputs (TF_Graph* graph, TF_Operation* op) {
2671
+ mutex_lock l (graph->mu );
2672
+ std::vector<const tensorflow::Edge*> control_edges;
2673
+ for (const tensorflow::Edge* edge : op->node .in_edges ()) {
2674
+ if (!edge->IsControlEdge ()) continue ;
2675
+ control_edges.push_back (edge);
2676
+ }
2677
+ for (const tensorflow::Edge* edge : control_edges) {
2678
+ graph->graph .RemoveControlEdge (edge);
2679
+ }
2680
+ }
2681
+
2682
+ void TFC_SetRequireShapeInferenceFns (TF_Graph* graph, bool require) {
2683
+ mutex_lock l (graph->mu );
2684
+ graph->refiner .set_require_shape_inference_fns (require);
2685
+ }
2686
+
2687
+ void TFC_ExtendSession (TF_Session* session, TF_Status* status) {
2688
+ ExtendSessionGraphHelper (session, status);
2689
+ session->extend_before_run = false ;
2690
+ }
2691
+
2692
+ const char * TFC_GetHandleShapeAndType (TF_Graph* graph, TF_Output output) {
2693
+ Node* node = &output.oper ->node ;
2694
+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
2695
+ handle_data.set_is_set (true );
2696
+ {
2697
+ mutex_lock l (graph->mu );
2698
+ tensorflow::shape_inference::InferenceContext* ic =
2699
+ graph->refiner .GetContext (node);
2700
+ CHECK (ic != nullptr );
2701
+ CHECK_LT (output.index , ic->num_outputs ());
2702
+ const auto * shapes_and_types =
2703
+ ic->output_handle_shapes_and_types (output.index );
2704
+ if (shapes_and_types == nullptr ) return " " ;
2705
+
2706
+ for (const auto & p : *shapes_and_types) {
2707
+ auto * out_shape_and_type = handle_data.add_shape_and_type ();
2708
+ ic->ShapeHandleToProto (p.shape , out_shape_and_type->mutable_shape ());
2709
+ out_shape_and_type->set_dtype (p.dtype );
2710
+ *out_shape_and_type->mutable_type () = p.type ;
2711
+ }
2712
+ }
2713
+ string result;
2714
+ handle_data.SerializeToString (&result);
2715
+ return result.c_str ();
2716
+ }
2717
+
2718
+ void TFC_SetHandleShapeAndType (TF_Graph* graph, TF_Output output, const void * proto,
2719
+ size_t proto_len, TF_Status* status) {
2720
+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
2721
+ if (!handle_data.ParseFromArray (proto, proto_len)) {
2722
+ status->status = tensorflow::errors::InvalidArgument (
2723
+ " Couldn't deserialize HandleData proto" );
2724
+ return ;
2725
+ }
2726
+ DCHECK (handle_data.is_set ());
2727
+
2728
+ tensorflow::mutex_lock l (graph->mu );
2729
+ tensorflow::shape_inference::InferenceContext* ic =
2730
+ graph->refiner .GetContext (&output.oper ->node );
2731
+
2732
+ std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
2733
+ for (const auto & shape_and_type_proto : handle_data.shape_and_type ()) {
2734
+ tensorflow::shape_inference::ShapeHandle shape;
2735
+ status->status =
2736
+ ic->MakeShapeFromShapeProto (shape_and_type_proto.shape (), &shape);
2737
+ if (TF_GetCode (status) != TF_OK) return ;
2738
+ shapes_and_types.emplace_back (shape, shape_and_type_proto.dtype (),
2739
+ shape_and_type_proto.type ());
2740
+ }
2741
+ ic->set_output_handle_shapes_and_types (output.index , shapes_and_types);
2742
+ }
2743
+
2744
+ void TFC_AddWhileInputHack (TF_Graph* graph, TF_Output new_src, TF_Operation* dst,
2745
+ TF_Status* status) {
2746
+ mutex_lock l (graph->mu );
2747
+ status->status = graph->graph .AddWhileInputHack (&new_src.oper ->node ,
2748
+ new_src.index , &dst->node );
2749
+ if (TF_GetCode (status) == TF_OK) {
2750
+ // This modification only updates the destination node for
2751
+ // the purposes of running this graph in a session. Thus, we don't
2752
+ // record the source node as being modified.
2753
+ tensorflow::RecordMutation (graph, *dst, " adding input tensor" );
2754
+ }
2755
+ }
2756
+
2757
+ // -------------------------------------------------------------------
2758
+
2617
2759
// TF_Server functions ----------------------------------------------
2618
2760
2619
2761
#if !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)
0 commit comments