Skip to content

Commit

Permalink
Use interface static methods to define operators
Browse files Browse the repository at this point in the history
  • Loading branch information
aalmada committed Dec 9, 2023
1 parent 7814756 commit 9a3988f
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 5 deletions.
46 changes: 46 additions & 0 deletions src/NetFabric.Numerics.Tensors/Aggregate.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

using System;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace NetFabric.Numerics
{
public static partial class Tensor
{
/// <summary>
/// Aggregates the elements in the specified <see cref="ReadOnlySpan{T}"/> using the specified <see cref="IAggregationOperator{T}"/>.
/// </summary>
/// <typeparam name="T">The type of the elements in the span.</typeparam>
/// <typeparam name="TOperator">The type of the aggregation operator.</typeparam>
/// <param name="source">The source span.</param>
/// <returns>The aggregated value.</returns>
public static T Aggregate<T, TOperator>(ReadOnlySpan<T> source)
where T : struct
where TOperator : struct, IAggregationOperator<T>
{
var result = TOperator.Seed;
var resultVector = new Vector<T>(TOperator.Seed);
nint index = 0;

if (Vector.IsHardwareAccelerated &&
Vector<T>.IsSupported &&
source.Length >= Vector<T>.Count)
{
var sourceVectors = MemoryMarshal.Cast<T, Vector<T>>(source);

ref var sourceVectorsRef = ref MemoryMarshal.GetReference(sourceVectors);
for (nint indexVector = 0; indexVector < sourceVectors.Length; indexVector++)
resultVector = TOperator.Invoke(in resultVector, in Unsafe.Add(ref sourceVectorsRef, indexVector));

index = source.Length - source.Length % Vector<T>.Count;
}

ref var sourceRef = ref MemoryMarshal.GetReference(source);
for (; index < source.Length; index++)
result = TOperator.Invoke(in result, in Unsafe.Add(ref sourceRef, index));

return TOperator.ResultSelector(in result, in resultVector);
}
}
}
31 changes: 31 additions & 0 deletions src/NetFabric.Numerics.Tensors/ITensorOperation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,35 @@ public interface ITernaryTensorOperation<T>
{
void Apply(ref readonly T leftItem, ref readonly T rightItem, ref T destinationItem);
void Apply(ref readonly Vector<T> leftVector, ref readonly Vector<T> rightVector, ref Vector<T> destinationVector);
}

public interface IUnaryOperator<T>
where T : struct
{
static abstract T Invoke(ref readonly T x);
static abstract Vector<T> Invoke(ref readonly Vector<T> x);
}

public interface IBinaryOperator<T>
where T : struct
{
static abstract T Invoke(ref readonly T x, ref readonly T y);
static abstract Vector<T> Invoke(ref readonly Vector<T> x, ref readonly Vector<T> y);
}

public interface IAggregationOperator<T>
: IBinaryOperator<T>
where T : struct
{
static virtual T Seed
=> Throw.NotSupportedException<T>();

static abstract T ResultSelector(ref readonly T value, ref readonly Vector<T> vector);
}

public interface ITernaryOperator<T>
where T : struct
{
static abstract T Invoke(ref readonly T x, ref readonly T y, ref readonly T z);
static abstract Vector<T> Invoke(ref readonly Vector<T> x, ref readonly Vector<T> y, ref readonly Vector<T> z);
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,21 @@ void IUnary2DTensorOperation<T>.Apply(ref readonly T item1, ref readonly T item2
[MethodImpl(MethodImplOptions.AggressiveInlining)]
void IUnary2DTensorOperation<T>.Apply(ref readonly Vector<T> vector)
=> sumVector += vector;
}

public readonly struct SumOperator<T>
: IAggregationOperator<T>
where T : struct, IAdditiveIdentity<T, T>, IAdditionOperators<T, T, T>
{
public static T Seed
=> T.AdditiveIdentity;

public static T ResultSelector(ref readonly T value, ref readonly Vector<T> vector)
=> Vector.Sum(vector) + value;

public static T Invoke(ref readonly T x, ref readonly T y)
=> x + y;

public static Vector<T> Invoke(ref readonly Vector<T> x, ref readonly Vector<T> y)
=> x + y;
}
6 changes: 1 addition & 5 deletions src/NetFabric.Numerics.Tensors/Sum.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@ public static partial class Tensor
/// </remarks>
public static T Sum<T>(ReadOnlySpan<T> source)
where T : struct, IAdditionOperators<T, T, T>, IAdditiveIdentity<T, T>
{
var sum = new SumOperation<T>();
Apply(source, ref sum);
return sum.Result;
}
=> Aggregate<T, SumOperator<T>>(source);

/// <summary>
/// Computes the sum of a 2D span of values.
Expand Down

0 comments on commit 9a3988f

Please sign in to comment.