Skip to content

Commit 7fb73cd

Browse files
authoredJun 22, 2024··
Merge pull request #1257 from novikov-alexander/alnovi/generic-cast
fix: More generic array cast
2 parents 0392027 + def5774 commit 7fb73cd

File tree

1 file changed

+59
-29
lines changed

1 file changed

+59
-29
lines changed
 

‎src/TensorFlowNET.Core/Tensors/tensor_util.cs

+59-29
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public static NDArray MakeNdarray(TensorProto tensor)
6767

6868
T[] ExpandArrayToSize<T>(IList<T> src)
6969
{
70-
if(src.Count == 0)
70+
if (src.Count == 0)
7171
{
7272
return new T[0];
7373
}
@@ -77,7 +77,7 @@ T[] ExpandArrayToSize<T>(IList<T> src)
7777
var first_elem = src[0];
7878
var last_elem = src[src.Count - 1];
7979
T[] res = new T[num_elements];
80-
for(long i = 0; i < num_elements; i++)
80+
for (long i = 0; i < num_elements; i++)
8181
{
8282
if (i < pre) res[i] = first_elem;
8383
else if (i >= num_elements - after) res[i] = last_elem;
@@ -121,7 +121,7 @@ T[] ExpandArrayToSize<T>(IList<T> src)
121121
$"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.");
122122
}
123123

124-
if(values.size == 0)
124+
if (values.size == 0)
125125
{
126126
return np.zeros(shape, tensor_dtype);
127127
}
@@ -135,23 +135,47 @@ T[] ExpandArrayToSize<T>(IList<T> src)
135135
TF_DataType.TF_QINT32
136136
};
137137

138-
private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter)
138+
private static Array ConvertArray<TOut>(Array inputArray, Func<object, TOut> converter)
139139
{
140-
var rows = inputArray.GetLength(0);
141-
var cols = inputArray.GetLength(1);
142-
var outputArray = new TOut[rows, cols];
140+
if (inputArray == null)
141+
throw new ArgumentNullException(nameof(inputArray));
143142

144-
for (var i = 0; i < rows; i++)
143+
var elementType = typeof(TOut);
144+
var lengths = new int[inputArray.Rank];
145+
for (var i = 0; i < inputArray.Rank; i++)
145146
{
146-
for (var j = 0; j < cols; j++)
147-
{
148-
outputArray[i, j] = converter(inputArray[i, j]);
149-
}
147+
lengths[i] = inputArray.GetLength(i);
150148
}
151149

150+
var outputArray = Array.CreateInstance(elementType, lengths);
151+
152+
FillArray(inputArray, outputArray, converter, new int[inputArray.Rank], 0);
153+
152154
return outputArray;
153155
}
154156

157+
private static void FillArray<TIn, TOut>(Array inputArray, Array outputArray, Func<TIn, TOut> converter, int[] indices, int dimension)
158+
{
159+
if (dimension == inputArray.Rank - 1)
160+
{
161+
for (int i = 0; i < inputArray.GetLength(dimension); i++)
162+
{
163+
indices[dimension] = i;
164+
var inputValue = (TIn)inputArray.GetValue(indices);
165+
var convertedValue = converter(inputValue);
166+
outputArray.SetValue(convertedValue, indices);
167+
}
168+
}
169+
else
170+
{
171+
for (int i = 0; i < inputArray.GetLength(dimension); i++)
172+
{
173+
indices[dimension] = i;
174+
FillArray(inputArray, outputArray, converter, indices, dimension + 1);
175+
}
176+
}
177+
}
178+
155179
/// <summary>
156180
/// Create a TensorProto, invoked in graph mode
157181
/// </summary>
@@ -171,24 +195,30 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
171195
var origin_dtype = values.GetDataType();
172196
if (dtype == TF_DataType.DtInvalid)
173197
dtype = origin_dtype;
174-
else if(origin_dtype != dtype)
198+
else if (origin_dtype != dtype)
175199
{
176200
var new_system_dtype = dtype.as_system_dtype();
177-
178-
values = values switch
201+
202+
if (dtype != TF_DataType.TF_STRING && dtype != TF_DataType.TF_VARIANT && dtype != TF_DataType.TF_RESOURCE)
203+
{
204+
if (values is Array arrayValues)
205+
{
206+
values = dtype switch
207+
{
208+
TF_DataType.TF_INT32 => ConvertArray(arrayValues, Convert.ToInt32),
209+
TF_DataType.TF_FLOAT => ConvertArray(arrayValues, Convert.ToSingle),
210+
TF_DataType.TF_DOUBLE => ConvertArray(arrayValues, Convert.ToDouble),
211+
_ => values,
212+
};
213+
} else
214+
{
215+
values = Convert.ChangeType(values, new_system_dtype);
216+
}
217+
218+
} else
179219
{
180-
long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
181-
long[] longValues => values,
182-
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
183-
float[] floatValues => values,
184-
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
185-
float[,] float2DValues => values,
186-
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
187-
double[] doubleValues => values,
188-
double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle),
189-
double[,] double2DValues => values,
190-
_ => Convert.ChangeType(values, new_system_dtype),
191-
};
220+
221+
}
192222
dtype = values.GetDataType();
193223
}
194224

@@ -306,7 +336,7 @@ bool hasattr(Graph property, string attr)
306336

307337
if (tensor is EagerTensor eagerTensor)
308338
{
309-
if(tensor.dtype == tf.int64)
339+
if (tensor.dtype == tf.int64)
310340
return new Shape(tensor.ToArray<long>());
311341
else
312342
return new Shape(tensor.ToArray<int>());
@@ -481,7 +511,7 @@ bool hasattr(Graph property, string attr)
481511
var d_ = new int[value.size];
482512
foreach (var (index, d) in enumerate(value.ToArray<int>()))
483513
d_[index] = d >= 0 ? d : -1;
484-
514+
485515
ret = ret.merge_with(new Shape(d_));
486516
}
487517
return ret;

0 commit comments

Comments
 (0)
Please sign in to comment.