@@ -67,7 +67,7 @@ public static NDArray MakeNdarray(TensorProto tensor)
67
67
68
68
T [ ] ExpandArrayToSize < T > ( IList < T > src )
69
69
{
70
- if ( src . Count == 0 )
70
+ if ( src . Count == 0 )
71
71
{
72
72
return new T [ 0 ] ;
73
73
}
@@ -77,7 +77,7 @@ T[] ExpandArrayToSize<T>(IList<T> src)
77
77
var first_elem = src [ 0 ] ;
78
78
var last_elem = src [ src . Count - 1 ] ;
79
79
T [ ] res = new T [ num_elements ] ;
80
- for ( long i = 0 ; i < num_elements ; i ++ )
80
+ for ( long i = 0 ; i < num_elements ; i ++ )
81
81
{
82
82
if ( i < pre ) res [ i ] = first_elem ;
83
83
else if ( i >= num_elements - after ) res [ i ] = last_elem ;
@@ -121,7 +121,7 @@ T[] ExpandArrayToSize<T>(IList<T> src)
121
121
$ "https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.") ;
122
122
}
123
123
124
- if ( values . size == 0 )
124
+ if ( values . size == 0 )
125
125
{
126
126
return np . zeros ( shape , tensor_dtype ) ;
127
127
}
@@ -135,23 +135,47 @@ T[] ExpandArrayToSize<T>(IList<T> src)
135
135
TF_DataType . TF_QINT32
136
136
} ;
137
137
138
- private static TOut [ , ] ConvertArray2D < TIn , TOut > ( TIn [ , ] inputArray , Func < TIn , TOut > converter )
138
+ private static Array ConvertArray < TOut > ( Array inputArray , Func < object , TOut > converter )
139
139
{
140
- var rows = inputArray . GetLength ( 0 ) ;
141
- var cols = inputArray . GetLength ( 1 ) ;
142
- var outputArray = new TOut [ rows , cols ] ;
140
+ if ( inputArray == null )
141
+ throw new ArgumentNullException ( nameof ( inputArray ) ) ;
143
142
144
- for ( var i = 0 ; i < rows ; i ++ )
143
+ var elementType = typeof ( TOut ) ;
144
+ var lengths = new int [ inputArray . Rank ] ;
145
+ for ( var i = 0 ; i < inputArray . Rank ; i ++ )
145
146
{
146
- for ( var j = 0 ; j < cols ; j ++ )
147
- {
148
- outputArray [ i , j ] = converter ( inputArray [ i , j ] ) ;
149
- }
147
+ lengths [ i ] = inputArray . GetLength ( i ) ;
150
148
}
151
149
150
+ var outputArray = Array . CreateInstance ( elementType , lengths ) ;
151
+
152
+ FillArray ( inputArray , outputArray , converter , new int [ inputArray . Rank ] , 0 ) ;
153
+
152
154
return outputArray ;
153
155
}
154
156
157
+ private static void FillArray < TIn , TOut > ( Array inputArray , Array outputArray , Func < TIn , TOut > converter , int [ ] indices , int dimension )
158
+ {
159
+ if ( dimension == inputArray . Rank - 1 )
160
+ {
161
+ for ( int i = 0 ; i < inputArray . GetLength ( dimension ) ; i ++ )
162
+ {
163
+ indices [ dimension ] = i ;
164
+ var inputValue = ( TIn ) inputArray . GetValue ( indices ) ;
165
+ var convertedValue = converter ( inputValue ) ;
166
+ outputArray . SetValue ( convertedValue , indices ) ;
167
+ }
168
+ }
169
+ else
170
+ {
171
+ for ( int i = 0 ; i < inputArray . GetLength ( dimension ) ; i ++ )
172
+ {
173
+ indices [ dimension ] = i ;
174
+ FillArray ( inputArray , outputArray , converter , indices , dimension + 1 ) ;
175
+ }
176
+ }
177
+ }
178
+
155
179
/// <summary>
156
180
/// Create a TensorProto, invoked in graph mode
157
181
/// </summary>
@@ -171,24 +195,30 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
171
195
var origin_dtype = values . GetDataType ( ) ;
172
196
if ( dtype == TF_DataType . DtInvalid )
173
197
dtype = origin_dtype ;
174
- else if ( origin_dtype != dtype )
198
+ else if ( origin_dtype != dtype )
175
199
{
176
200
var new_system_dtype = dtype . as_system_dtype ( ) ;
177
-
178
- values = values switch
201
+
202
+ if ( dtype != TF_DataType . TF_STRING && dtype != TF_DataType . TF_VARIANT && dtype != TF_DataType . TF_RESOURCE )
203
+ {
204
+ if ( values is Array arrayValues )
205
+ {
206
+ values = dtype switch
207
+ {
208
+ TF_DataType . TF_INT32 => ConvertArray ( arrayValues , Convert . ToInt32 ) ,
209
+ TF_DataType . TF_FLOAT => ConvertArray ( arrayValues , Convert . ToSingle ) ,
210
+ TF_DataType . TF_DOUBLE => ConvertArray ( arrayValues , Convert . ToDouble ) ,
211
+ _ => values ,
212
+ } ;
213
+ } else
214
+ {
215
+ values = Convert . ChangeType ( values , new_system_dtype ) ;
216
+ }
217
+
218
+ } else
179
219
{
180
- long [ ] longValues when dtype == TF_DataType . TF_INT32 => longValues . Select ( x => ( int ) x ) . ToArray ( ) ,
181
- long [ ] longValues => values ,
182
- float [ ] floatValues when dtype == TF_DataType . TF_DOUBLE => floatValues . Select ( x => ( double ) x ) . ToArray ( ) ,
183
- float [ ] floatValues => values ,
184
- float [ , ] float2DValues when dtype == TF_DataType . TF_DOUBLE => ConvertArray2D ( float2DValues , Convert . ToDouble ) ,
185
- float [ , ] float2DValues => values ,
186
- double [ ] doubleValues when dtype == TF_DataType . TF_FLOAT => doubleValues . Select ( x => ( float ) x ) . ToArray ( ) ,
187
- double [ ] doubleValues => values ,
188
- double [ , ] double2DValues when dtype == TF_DataType . TF_FLOAT => ConvertArray2D ( double2DValues , Convert . ToSingle ) ,
189
- double [ , ] double2DValues => values ,
190
- _ => Convert . ChangeType ( values , new_system_dtype ) ,
191
- } ;
220
+
221
+ }
192
222
dtype = values . GetDataType ( ) ;
193
223
}
194
224
@@ -306,7 +336,7 @@ bool hasattr(Graph property, string attr)
306
336
307
337
if ( tensor is EagerTensor eagerTensor )
308
338
{
309
- if ( tensor . dtype == tf . int64 )
339
+ if ( tensor . dtype == tf . int64 )
310
340
return new Shape ( tensor . ToArray < long > ( ) ) ;
311
341
else
312
342
return new Shape ( tensor . ToArray < int > ( ) ) ;
@@ -481,7 +511,7 @@ bool hasattr(Graph property, string attr)
481
511
var d_ = new int [ value . size ] ;
482
512
foreach ( var ( index , d ) in enumerate ( value . ToArray < int > ( ) ) )
483
513
d_ [ index ] = d >= 0 ? d : - 1 ;
484
-
514
+
485
515
ret = ret . merge_with ( new Shape ( d_ ) ) ;
486
516
}
487
517
return ret ;
0 commit comments