Skip to content

Commit 4d23e8c

Browse files
tannergoodingViktorHofermichaelgsharp
committed
Adding a naive implementation of various primitive tensor operations (dotnet#91228)
* Adding a naive implementation of various primitive tensor operations * Adding tests covering the new tensor primitives APIs * Adding tensor primitives APIs to the ref assembly * Allow .NET Framework to build/run * Sync TFMs between ref and src, csproj simplication and clean-up * Apply suggestions from code review Co-authored-by: Viktor Hofer <[email protected]> * Don't use var * Fix the S.N.Tensors readme and remove the file marking it as non-shipping --------- Co-authored-by: Viktor Hofer <[email protected]> Co-authored-by: Michael Sharp <[email protected]>
1 parent 30e2d6b commit 4d23e8c

26 files changed

+1181
-23373
lines changed

src/libraries/System.Numerics.Tensors/Directory.Build.props

-8
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
# System.Numerics.Tensors
2-
This library has not been shipped publicly and is not accepting contributions at this time.
2+
3+
Provides APIs for performing primitive operations over tensors represented by spans of memory.

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.cs

+21-145
Original file line numberDiff line numberDiff line change
@@ -6,151 +6,27 @@
66

77
namespace System.Numerics.Tensors
88
{
9-
public static partial class ArrayTensorExtensions
9+
public static class TensorPrimitives
1010
{
11-
public static System.Numerics.Tensors.CompressedSparseTensor<T> ToCompressedSparseTensor<T>(this System.Array array, bool reverseStride = false) { throw null; }
12-
public static System.Numerics.Tensors.CompressedSparseTensor<T> ToCompressedSparseTensor<T>(this T[,,] array, bool reverseStride = false) { throw null; }
13-
public static System.Numerics.Tensors.CompressedSparseTensor<T> ToCompressedSparseTensor<T>(this T[,] array, bool reverseStride = false) { throw null; }
14-
public static System.Numerics.Tensors.CompressedSparseTensor<T> ToCompressedSparseTensor<T>(this T[] array) { throw null; }
15-
public static System.Numerics.Tensors.SparseTensor<T> ToSparseTensor<T>(this System.Array array, bool reverseStride = false) { throw null; }
16-
public static System.Numerics.Tensors.SparseTensor<T> ToSparseTensor<T>(this T[,,] array, bool reverseStride = false) { throw null; }
17-
public static System.Numerics.Tensors.SparseTensor<T> ToSparseTensor<T>(this T[,] array, bool reverseStride = false) { throw null; }
18-
public static System.Numerics.Tensors.SparseTensor<T> ToSparseTensor<T>(this T[] array) { throw null; }
19-
public static System.Numerics.Tensors.DenseTensor<T> ToTensor<T>(this System.Array array, bool reverseStride = false) { throw null; }
20-
public static System.Numerics.Tensors.DenseTensor<T> ToTensor<T>(this T[,,] array, bool reverseStride = false) { throw null; }
21-
public static System.Numerics.Tensors.DenseTensor<T> ToTensor<T>(this T[,] array, bool reverseStride = false) { throw null; }
22-
public static System.Numerics.Tensors.DenseTensor<T> ToTensor<T>(this T[] array) { throw null; }
23-
}
24-
public partial class CompressedSparseTensor<T> : System.Numerics.Tensors.Tensor<T>
25-
{
26-
public CompressedSparseTensor(System.Memory<T> values, System.Memory<int> compressedCounts, System.Memory<int> indices, int nonZeroCount, System.ReadOnlySpan<int> dimensions, bool reverseStride = false) : base (default(System.Array), default(bool)) { }
27-
public CompressedSparseTensor(System.ReadOnlySpan<int> dimensions, bool reverseStride = false) : base (default(System.Array), default(bool)) { }
28-
public CompressedSparseTensor(System.ReadOnlySpan<int> dimensions, int capacity, bool reverseStride = false) : base (default(System.Array), default(bool)) { }
29-
public int Capacity { get { throw null; } }
30-
public System.Memory<int> CompressedCounts { get { throw null; } }
31-
public System.Memory<int> Indices { get { throw null; } }
32-
public override T this[System.ReadOnlySpan<int> indices] { get { throw null; } set { } }
33-
public int NonZeroCount { get { throw null; } }
34-
public System.Memory<T> Values { get { throw null; } }
35-
public override System.Numerics.Tensors.Tensor<T> Clone() { throw null; }
36-
public override System.Numerics.Tensors.Tensor<TResult> CloneEmpty<TResult>(System.ReadOnlySpan<int> dimensions) { throw null; }
37-
public override T GetValue(int index) { throw null; }
38-
public override System.Numerics.Tensors.Tensor<T> Reshape(System.ReadOnlySpan<int> dimensions) { throw null; }
39-
public override void SetValue(int index, T value) { }
40-
public override System.Numerics.Tensors.CompressedSparseTensor<T> ToCompressedSparseTensor() { throw null; }
41-
public override System.Numerics.Tensors.DenseTensor<T> ToDenseTensor() { throw null; }
42-
public override System.Numerics.Tensors.SparseTensor<T> ToSparseTensor() { throw null; }
43-
}
44-
public partial class DenseTensor<T> : System.Numerics.Tensors.Tensor<T>
45-
{
46-
public DenseTensor(int length) : base (default(System.Array), default(bool)) { }
47-
public DenseTensor(System.Memory<T> memory, System.ReadOnlySpan<int> dimensions, bool reverseStride = false) : base (default(System.Array), default(bool)) { }
48-
public DenseTensor(System.ReadOnlySpan<int> dimensions, bool reverseStride = false) : base (default(System.Array), default(bool)) { }
49-
public System.Memory<T> Buffer { get { throw null; } }
50-
public override System.Numerics.Tensors.Tensor<T> Clone() { throw null; }
51-
public override System.Numerics.Tensors.Tensor<TResult> CloneEmpty<TResult>(System.ReadOnlySpan<int> dimensions) { throw null; }
52-
protected override void CopyTo(T[] array, int arrayIndex) { }
53-
public override T GetValue(int index) { throw null; }
54-
protected override int IndexOf(T item) { throw null; }
55-
public override System.Numerics.Tensors.Tensor<T> Reshape(System.ReadOnlySpan<int> dimensions) { throw null; }
56-
public override void SetValue(int index, T value) { }
57-
}
58-
public partial class SparseTensor<T> : System.Numerics.Tensors.Tensor<T>
59-
{
60-
public SparseTensor(System.ReadOnlySpan<int> dimensions, bool reverseStride = false, int capacity = 0) : base (default(System.Array), default(bool)) { }
61-
public int NonZeroCount { get { throw null; } }
62-
public override System.Numerics.Tensors.Tensor<T> Clone() { throw null; }
63-
public override System.Numerics.Tensors.Tensor<TResult> CloneEmpty<TResult>(System.ReadOnlySpan<int> dimensions) { throw null; }
64-
public override T GetValue(int index) { throw null; }
65-
public override System.Numerics.Tensors.Tensor<T> Reshape(System.ReadOnlySpan<int> dimensions) { throw null; }
66-
public override void SetValue(int index, T value) { }
67-
public override System.Numerics.Tensors.CompressedSparseTensor<T> ToCompressedSparseTensor() { throw null; }
68-
public override System.Numerics.Tensors.DenseTensor<T> ToDenseTensor() { throw null; }
69-
public override System.Numerics.Tensors.SparseTensor<T> ToSparseTensor() { throw null; }
70-
}
71-
public static partial class Tensor
72-
{
73-
public static System.Numerics.Tensors.Tensor<T> CreateFromDiagonal<T>(System.Numerics.Tensors.Tensor<T> diagonal) { throw null; }
74-
public static System.Numerics.Tensors.Tensor<T> CreateFromDiagonal<T>(System.Numerics.Tensors.Tensor<T> diagonal, int offset) { throw null; }
75-
public static System.Numerics.Tensors.Tensor<T> CreateIdentity<T>(int size) { throw null; }
76-
public static System.Numerics.Tensors.Tensor<T> CreateIdentity<T>(int size, bool columMajor) { throw null; }
77-
public static System.Numerics.Tensors.Tensor<T> CreateIdentity<T>(int size, bool columMajor, T oneValue) { throw null; }
78-
}
79-
public abstract partial class Tensor<T> : System.Collections.Generic.ICollection<T>, System.Collections.Generic.IEnumerable<T>, System.Collections.Generic.IList<T>, System.Collections.Generic.IReadOnlyCollection<T>, System.Collections.Generic.IReadOnlyList<T>, System.Collections.ICollection, System.Collections.IEnumerable, System.Collections.IList, System.Collections.IStructuralComparable, System.Collections.IStructuralEquatable
80-
{
81-
protected Tensor(System.Array fromArray, bool reverseStride) { }
82-
protected Tensor(int length) { }
83-
protected Tensor(System.ReadOnlySpan<int> dimensions, bool reverseStride) { }
84-
public System.ReadOnlySpan<int> Dimensions { get { throw null; } }
85-
public bool IsFixedSize { get { throw null; } }
86-
public bool IsReadOnly { get { throw null; } }
87-
public bool IsReversedStride { get { throw null; } }
88-
public virtual T this[params int[] indices] { get { throw null; } set { } }
89-
public virtual T this[System.ReadOnlySpan<int> indices] { get { throw null; } set { } }
90-
public long Length { get { throw null; } }
91-
public int Rank { get { throw null; } }
92-
public System.ReadOnlySpan<int> Strides { get { throw null; } }
93-
int System.Collections.Generic.ICollection<T>.Count { get { throw null; } }
94-
T System.Collections.Generic.IList<T>.this[int index] { get { throw null; } set { } }
95-
int System.Collections.Generic.IReadOnlyCollection<T>.Count { get { throw null; } }
96-
T System.Collections.Generic.IReadOnlyList<T>.this[int index] { get { throw null; } }
97-
int System.Collections.ICollection.Count { get { throw null; } }
98-
bool System.Collections.ICollection.IsSynchronized { get { throw null; } }
99-
object System.Collections.ICollection.SyncRoot { get { throw null; } }
100-
object? System.Collections.IList.this[int index] { get { throw null; } set { } }
101-
public abstract System.Numerics.Tensors.Tensor<T> Clone();
102-
public virtual System.Numerics.Tensors.Tensor<T> CloneEmpty() { throw null; }
103-
public virtual System.Numerics.Tensors.Tensor<T> CloneEmpty(System.ReadOnlySpan<int> dimensions) { throw null; }
104-
public virtual System.Numerics.Tensors.Tensor<TResult> CloneEmpty<TResult>() { throw null; }
105-
public abstract System.Numerics.Tensors.Tensor<TResult> CloneEmpty<TResult>(System.ReadOnlySpan<int> dimensions);
106-
public static int Compare(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) { throw null; }
107-
protected virtual bool Contains(T item) { throw null; }
108-
protected virtual void CopyTo(T[] array, int arrayIndex) { }
109-
public static bool Equals(System.Numerics.Tensors.Tensor<T> left, System.Numerics.Tensors.Tensor<T> right) { throw null; }
110-
public virtual void Fill(T value) { }
111-
public string GetArrayString(bool includeWhitespace = true) { throw null; }
112-
public System.Numerics.Tensors.Tensor<T> GetDiagonal() { throw null; }
113-
public System.Numerics.Tensors.Tensor<T> GetDiagonal(int offset) { throw null; }
114-
public System.Numerics.Tensors.Tensor<T> GetTriangle() { throw null; }
115-
public System.Numerics.Tensors.Tensor<T> GetTriangle(int offset) { throw null; }
116-
public System.Numerics.Tensors.Tensor<T> GetUpperTriangle() { throw null; }
117-
public System.Numerics.Tensors.Tensor<T> GetUpperTriangle(int offset) { throw null; }
118-
public abstract T GetValue(int index);
119-
protected virtual int IndexOf(T item) { throw null; }
120-
public abstract System.Numerics.Tensors.Tensor<T> Reshape(System.ReadOnlySpan<int> dimensions);
121-
public abstract void SetValue(int index, T value);
122-
public struct Enumerator : System.Collections.Generic.IEnumerator<T>
123-
{
124-
public T Current { get; private set; }
125-
object? System.Collections.IEnumerator.Current => throw null;
126-
public bool MoveNext() => throw null;
127-
public void Reset() { }
128-
public void Dispose() { }
129-
}
130-
public Enumerator GetEnumerator() => throw null;
131-
void System.Collections.Generic.ICollection<T>.Add(T item) { }
132-
void System.Collections.Generic.ICollection<T>.Clear() { }
133-
bool System.Collections.Generic.ICollection<T>.Contains(T item) { throw null; }
134-
void System.Collections.Generic.ICollection<T>.CopyTo(T[] array, int arrayIndex) { }
135-
bool System.Collections.Generic.ICollection<T>.Remove(T item) { throw null; }
136-
System.Collections.Generic.IEnumerator<T> System.Collections.Generic.IEnumerable<T>.GetEnumerator() { throw null; }
137-
int System.Collections.Generic.IList<T>.IndexOf(T item) { throw null; }
138-
void System.Collections.Generic.IList<T>.Insert(int index, T item) { }
139-
void System.Collections.Generic.IList<T>.RemoveAt(int index) { }
140-
void System.Collections.ICollection.CopyTo(System.Array array, int index) { }
141-
System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { throw null; }
142-
int System.Collections.IList.Add(object? value) { throw null; }
143-
void System.Collections.IList.Clear() { }
144-
bool System.Collections.IList.Contains(object? value) { throw null; }
145-
int System.Collections.IList.IndexOf(object? value) { throw null; }
146-
void System.Collections.IList.Insert(int index, object? value) { }
147-
void System.Collections.IList.Remove(object? value) { }
148-
void System.Collections.IList.RemoveAt(int index) { }
149-
int System.Collections.IStructuralComparable.CompareTo(object? other, System.Collections.IComparer comparer) { throw null; }
150-
bool System.Collections.IStructuralEquatable.Equals(object? other, System.Collections.IEqualityComparer comparer) { throw null; }
151-
int System.Collections.IStructuralEquatable.GetHashCode(System.Collections.IEqualityComparer comparer) { throw null; }
152-
public virtual System.Numerics.Tensors.CompressedSparseTensor<T> ToCompressedSparseTensor() { throw null; }
153-
public virtual System.Numerics.Tensors.DenseTensor<T> ToDenseTensor() { throw null; }
154-
public virtual System.Numerics.Tensors.SparseTensor<T> ToSparseTensor() { throw null; }
11+
public static void Add(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { throw null; }
12+
public static void Add(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { throw null; }
13+
public static void AddMultiply(System.ReadOnlySpan<float> x, float y, System.ReadOnlySpan<float> multiplier, System.Span<float> destination) { throw null; }
14+
public static void AddMultiply(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, float multiplier, System.Span<float> destination) { throw null; }
15+
public static void AddMultiply(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.ReadOnlySpan<float> multiplier, System.Span<float> destination) { throw null; }
16+
public static void Cosh(System.ReadOnlySpan<float> x, System.Span<float> destination) { throw null; }
17+
public static void Divide(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { throw null; }
18+
public static void Divide(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { throw null; }
19+
public static void Exp(System.ReadOnlySpan<float> x, System.Span<float> destination) { throw null; }
20+
public static void Log(System.ReadOnlySpan<float> x, System.Span<float> destination) { throw null; }
21+
public static void Multiply(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { throw null; }
22+
public static void Multiply(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { throw null; }
23+
public static void MultiplyAdd(System.ReadOnlySpan<float> x, float y, System.ReadOnlySpan<float> addend, System.Span<float> destination) { throw null; }
24+
public static void MultiplyAdd(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, float addend, System.Span<float> destination) { throw null; }
25+
public static void MultiplyAdd(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.ReadOnlySpan<float> addend, System.Span<float> destination) { throw null; }
26+
public static void Negate(System.ReadOnlySpan<float> x, System.Span<float> destination) { throw null; }
27+
public static void Subtract(System.ReadOnlySpan<float> x, float y, System.Span<float> destination) { throw null; }
28+
public static void Subtract(System.ReadOnlySpan<float> x, System.ReadOnlySpan<float> y, System.Span<float> destination) { throw null; }
29+
public static void Sinh(System.ReadOnlySpan<float> x, System.Span<float> destination) { throw null; }
30+
public static void Tanh(System.ReadOnlySpan<float> x, System.Span<float> destination) { throw null; }
15531
}
15632
}

src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.csproj

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
<Project Sdk="Microsoft.NET.Sdk">
2+
23
<PropertyGroup>
34
<TargetFrameworks>$(NetCoreAppCurrent);$(NetCoreAppPrevious);$(NetCoreAppMinimum);netstandard2.0;$(NetFrameworkMinimum)</TargetFrameworks>
45
</PropertyGroup>
@@ -10,4 +11,5 @@
1011
<ItemGroup Condition="'$(TargetFrameworkIdentifier)' != '.NETCoreApp'">
1112
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
1213
</ItemGroup>
14+
1315
</Project>

src/libraries/System.Numerics.Tensors/src/Properties/InternalsVisibleTo.cs

-6
This file was deleted.

src/libraries/System.Numerics.Tensors/src/Resources/Strings.resx

+4-46
Original file line numberDiff line numberDiff line change
@@ -117,52 +117,10 @@
117117
<resheader name="writer">
118118
<value>System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089</value>
119119
</resheader>
120-
<data name="ArrayMustContainElements" xml:space="preserve">
121-
<value>Array must contain elements.</value>
120+
<data name="Argument_DestinationTooShort" xml:space="preserve">
121+
<value>Destination is too short.</value>
122122
</data>
123-
<data name="CannotCompare" xml:space="preserve">
124-
<value>Cannot compare {0} to {1}.</value>
125-
</data>
126-
<data name="CannotCompareToWithDifferentDimension" xml:space="preserve">
127-
<value>Cannot compare {0} to {1} with different dimension {2}, {3} != {4}.</value>
128-
</data>
129-
<data name="CannotCompareWithDifferentDimension" xml:space="preserve">
130-
<value>Cannot compare {0} with different dimension {1}, {2} != {3}.</value>
131-
</data>
132-
<data name="CannotCompareWithRank" xml:space="preserve">
133-
<value>Cannot compare {0} with Rank {1} to {2} with Rank {3}.</value>
134-
</data>
135-
<data name="CannotComputeDiagonal" xml:space="preserve">
136-
<value>Cannot compute diagonal of {0} with Rank less than 2.</value>
137-
</data>
138-
<data name="CannotComputeDiagonalWithOffset" xml:space="preserve">
139-
<value>Cannot compute diagonal with offset {0}.</value>
140-
</data>
141-
<data name="MustHaveAtLeastOneDimension" xml:space="preserve">
142-
<value>Tensor {0} must have at least one dimension.</value>
143-
</data>
144-
<data name="CannotComputeTriangle" xml:space="preserve">
145-
<value>Cannot compute triangle of {0} with Rank less than 2.</value>
146-
</data>
147-
<data name="CannotReshapeArrayDueToMismatchInLengths" xml:space="preserve">
148-
<value>Cannot reshape array due to mismatch in lengths, currently {0} would become {1}.</value>
149-
</data>
150-
<data name="DimensionsMustBePositiveAndNonZero" xml:space="preserve">
151-
<value>Dimensions must be positive and non-zero.</value>
152-
</data>
153-
<data name="DimensionsMustContainElements" xml:space="preserve">
154-
<value>Dimensions must contain elements.</value>
155-
</data>
156-
<data name="LengthMustMatch" xml:space="preserve">
157-
<value>Length of {0} ({1}) must match product of {2} ({3}).</value>
158-
</data>
159-
<data name="NumberGreaterThenAvailableSpace" xml:space="preserve">
160-
<value>The number of elements in the Tensor is greater than the available space from index to the end of the destination array.</value>
161-
</data>
162-
<data name="OnlySingleDimensionalArraysSupported" xml:space="preserve">
163-
<value>Only single dimensional arrays are supported for the requested action.</value>
164-
</data>
165-
<data name="ValueIsNotOfType" xml:space="preserve">
166-
<value>The value "{0}" is not of type "{1}" and cannot be used in this generic collection.</value>
123+
<data name="Argument_SpansMustHaveSameLength" xml:space="preserve">
124+
<value>Length of '{0}' must be same as length of '{1}'.</value>
167125
</data>
168126
</root>

0 commit comments

Comments
 (0)