diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs index 67eef0a1b..7bc438c99 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs @@ -1,5 +1,7 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. +using System.Drawing; +using System.Numerics; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; @@ -50,16 +52,16 @@ public unsafe ILinkableFunction Build(TIR.PrimFunction function) ulong rdataPoolSize = ulong.MinValue; foreach (var (@const, range) in function.SchedResult.Rdatas) { - var bytes = ((TensorConst)@const).Value.BytesBuffer; + var tensor = ((TensorConst)@const).Value; var size = range.Max - range.Min; rdataPoolSize = System.Math.Max(range.Max, rdataPoolSize); - if ((uint)bytes.Length != size) + if ((ulong)tensor.Length * (ulong)tensor.ElementType.SizeInBytes != size) { throw new InvalidDataException("The Buffer Size Not Equal!"); } _rdataWriter.Position(checked((long)range.Min)); - _rdataWriter.Write(bytes); + tensor.Serialize(_rdataWriter.BaseStream); } // 3. build function. diff --git a/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs b/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs index b10945054..bef0440cf 100644 --- a/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs +++ b/modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs @@ -5,6 +5,10 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Net.Http.Headers; +using System.Numerics; +using System.Runtime.InteropServices; +using CommunityToolkit.HighPerformance; using Nncase.CostModel; using Nncase.IR; using Nncase.IR.CPU; @@ -41,7 +45,8 @@ public IValue Visit(IEvaluateContext context, Pack target) input = input.Pack(lanes, axis); } - return Value.FromTensor(Tensor.FromBytes(new VectorType(input.DataType.ToDataType(), target.Lanes), input.BytesBuffer.ToArray(), input.Shape.ToArray().SkipLast(target.Lanes.Count).Select(i => (int)i).ToArray())); + var dt = input.DataType.ToDataType(); + return Value.FromTensor(input.ToTensor(new TensorType(new VectorType(input.DataType.ToDataType(), target.Lanes), new Shape(input.Shape.SkipLast(target.Lanes.Count).Select(i => (int)i))))); } } diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs index 9026f35c7..10226c776 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs @@ -940,6 +940,52 @@ void AddCandidate(int[] packAxes, int[] lanes) } } +public sealed class PackCast : PackRule +{ + public PackCast(int rank, int lane) + : base(rank, lane) + { + } + + public override Pattern Pattern { get; } = IsCast( + "target", + _ => true, + IsWildcard("input", e => e is not Call { Target: IR.CPU.Unpack }) with { TypePattern = IsFloat() & !IsVector() }); + + public override List GetReplaceCandidates(IMatchResult result, RunPassContext context) + { + var rets = new List(); + var op = (IR.Tensors.Cast)result["target"]; + var input = (Expr)result["input"]; + var inShape = input.CheckedShape.ToValueArray(); + + void AddCandidate(int[] packedAxes, int[] lanes) + { + var packedInput = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var padsInput), lanes, packedAxes); + var cast = IR.F.Tensors.Cast(packedInput, op.NewType, op.CastMode); + var post = PackUtility.SliceForPack(IR.F.CPU.Unpack(cast, lanes, packedAxes), inShape, padsInput); + if (cast.CheckedType is not InvalidType) + { + rets.Add(post); + } + } + + for (int i = 0; i < input.CheckedShape.Count; i++) + { + AddCandidate(new[] { i }, new[] { Lane }); + for (int j = i + 1; j < input.CheckedShape.Count; j++) + { + if (Rank > 1) + { + AddCandidate(new[] { i, j }, new[] { Lane, Lane }); + } + } + } + + return rets; + } +} + [RuleGenerator] public sealed partial class FoldPackUnpack : RewriteRule { diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodegenVisitor.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodegenVisitor.cs index 2b6a4e6ed..cd472cce6 100644 --- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodegenVisitor.cs +++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodegenVisitor.cs @@ -5,7 +5,10 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; +using System.Numerics; using System.Reactive; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; using NetFabric.Hyperlinq; @@ -403,7 +406,7 @@ private void MergeSnippetSet(List thenSet, List elseSe private TextSnippet Visit(TensorConst expr, Tensor tensor) { - var buffer = WriteRdata(tensor.BytesBuffer, _alignment); + var buffer = WriteRdata(tensor, _alignment); // stack: dtype shape strides buffer var snippet = BeginTextSnippet(expr); @@ -428,11 +431,11 @@ private Symbol WriteRdata(DataType dataType) return symbol; } - private Symbol WriteRdata(ReadOnlySpan data, int alignment) + private Symbol WriteRdata(Tensor tensor, int alignment) { _context.RdataWriter.AlignPosition(alignment); var symbol = AddSymbol(WellknownSectionNames.Rdata); - _context.RdataWriter.Write(data); + tensor.Serialize(_context.RdataWriter.BaseStream); return symbol; } diff --git a/src/Nncase.Core/ITensorInitializer.cs b/src/Nncase.Core/ITensorInitializer.cs new file mode 100644 index 000000000..64fa0211b --- /dev/null +++ b/src/Nncase.Core/ITensorInitializer.cs @@ -0,0 +1,16 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Nncase; + +public interface ITensorInitializer +{ + void Initialize(Tensor tensor) + where T : unmanaged, IEquatable; +} diff --git a/src/Nncase.Core/Tensor.cs b/src/Nncase.Core/Tensor.cs index c4b961ae4..2dc85fbfc 100644 --- a/src/Nncase.Core/Tensor.cs +++ b/src/Nncase.Core/Tensor.cs @@ -14,6 +14,7 @@ using CommunityToolkit.HighPerformance; using Nncase.Buffers; using Nncase.IR; +using Nncase.TIR; namespace Nncase; @@ -58,6 +59,9 @@ public abstract partial class Tensor : IStructuralComparable, IStructuralEquatab private static readonly MethodInfo _tensorCreateFromArrayFunc = typeof(Tensor).GetMethod(nameof(CreateTensorFromArrayImpl), BindingFlags.Static | BindingFlags.NonPublic)!; + private static readonly MethodInfo _tensorCreateEmptyFunc = + typeof(Tensor).GetMethod(nameof(CreateTensorEmptyImpl), BindingFlags.Static | BindingFlags.NonPublic)!; + private static readonly MethodInfo _tensorCastFunc = typeof(Tensor).GetMethod(nameof(Cast))!; @@ -229,6 +233,13 @@ public static Tensor FromRange(int start, int count) return tensor; } + public static Tensor From(DataType dataType, ITensorInitializer initializer, ReadOnlySpan dimensions) + { + var tensor = Zeros(dataType, dimensions); + tensor.Initialize(initializer); + return tensor; + } + /// /// Create tensor from a memory, Set the shape as [n]. /// @@ -316,6 +327,13 @@ public static Tensor FromBytes(TensorType type, Memory buffer) return FromBytes(type.DType, buffer, type.Shape.ToValueArray()); } + public static Tensor FromStream(DataType type, Stream stream, ReadOnlySpan dimensions) + { + var tensor = Tensor.Zeros(type, dimensions); + tensor.Deserialize(stream); + return tensor; + } + /// /// Create tensor from an array. /// @@ -392,6 +410,11 @@ public static Tensor Zeros(ReadOnlySpan dimensions) return Tensor.FromScalar(value, dimensions); } + public static Tensor Zeros(DataType dataType, ReadOnlySpan dimensions) + { + return (Tensor)_tensorCreateEmptyFunc.MakeGenericMethod(dataType.CLRType).Invoke(null, new object[] { dimensions.ToArray() })!; + } + /// /// Return a tensor of given shape and type, filled with ones. /// @@ -442,6 +465,10 @@ public IEnumerator GetEnumerator() /// String of this tensor. public abstract string GetArrayString(bool includeWhitespace = true); + public abstract void Deserialize(Stream stream); + + public abstract void Serialize(Stream stream); + int IStructuralComparable.CompareTo(object? other, IComparer comparer) { return CompareTo(other, comparer); @@ -511,6 +538,8 @@ void IList.RemoveAt(int index) private protected abstract void SetValueCore(int index, object? value); + private protected abstract void Initialize(ITensorInitializer initializer); + private static Tensor CreateTensorFromBytesImpl(Memory buffer, int[] dimensions) where T : unmanaged, IEquatable { @@ -523,4 +552,10 @@ private static Tensor CreateTensorFromArrayImpl(Array array, int[] dimensions var mmgr = new ArrayMemoryManager(array); return new Tensor(mmgr.Memory, dimensions); } + + private static Tensor CreateTensorEmptyImpl(int[] dimensions) + where T : unmanaged, IEquatable + { + return new Tensor(dimensions); + } } diff --git a/src/Nncase.Core/TensorOfT.cs b/src/Nncase.Core/TensorOfT.cs index ba38ae4ef..bcda1c68e 100644 --- a/src/Nncase.Core/TensorOfT.cs +++ b/src/Nncase.Core/TensorOfT.cs @@ -11,10 +11,12 @@ using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; +using CommunityToolkit.HighPerformance; using CommunityToolkit.HighPerformance.Helpers; using NetFabric.Hyperlinq; using Nncase.Buffers; using Nncase.IR; +using Nncase.Utilities; namespace Nncase; @@ -325,6 +327,16 @@ public override string GetArrayString(bool includeWhitespace = true) return builder.ToString(); } + public override void Deserialize(Stream stream) + { + SpanUtility.Deserialize(Buffer.Span, stream); + } + + public override void Serialize(Stream stream) + { + SpanUtility.Serialize((ReadOnlySpan)Buffer.Span, stream); + } + /// /// Gets an enumerator that enumerates the elements of the . /// @@ -544,6 +556,9 @@ private protected override void SetValueCore(int index, object? value) SetValue(index, (T)Convert.ChangeType(value, typeof(T))!); } + private protected override void Initialize(ITensorInitializer initializer) => + initializer.Initialize(this); + private static void Indent(StringBuilder builder, int tabs, int spacesPerTab = 4) { for (int tab = 0; tab < tabs; tab++) diff --git a/src/Nncase.Core/Utilities/SpanUtility.cs b/src/Nncase.Core/Utilities/SpanUtility.cs index 849af1335..e0bcf790a 100644 --- a/src/Nncase.Core/Utilities/SpanUtility.cs +++ b/src/Nncase.Core/Utilities/SpanUtility.cs @@ -8,6 +8,7 @@ using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; +using CommunityToolkit.HighPerformance; namespace Nncase.Utilities; @@ -21,4 +22,28 @@ public static ReadOnlySpan UnsafeCast(ReadOnlySpan froms ref var castFirst = ref Unsafe.As(ref first); return MemoryMarshal.CreateReadOnlySpan(ref castFirst, froms.Length); } + + public static unsafe void Deserialize(Span span, Stream stream) + where T : unmanaged + { + var position = 0; + while (position < span.Length) + { + var length = Math.Min(span.Length - position, 1024 * 1024 * 1024 / sizeof(T)); + stream.ReadExactly(span.Slice(position, length).AsBytes()); + position += length; + } + } + + public static unsafe void Serialize(ReadOnlySpan span, Stream stream) + where T : unmanaged + { + var position = 0; + while (position < span.Length) + { + var length = Math.Min(span.Length - position, 1024 * 1024 * 1024 / sizeof(T)); + stream.Write(span.Slice(position, length).AsBytes()); + position += length; + } + } } diff --git a/src/Nncase.EGraph/Passes/RewriteProvider.cs b/src/Nncase.EGraph/Passes/RewriteProvider.cs index 631cbcfb6..a2e25d3cd 100644 --- a/src/Nncase.EGraph/Passes/RewriteProvider.cs +++ b/src/Nncase.EGraph/Passes/RewriteProvider.cs @@ -70,7 +70,9 @@ public IEGraph ERewrite(IEGraph eGraph, IEnumerable rules, RunPass { var replacedExprs = (from result in results let oldExpr = ((ENode)result.Root).Expr - from newExpr in rule.GetReplaceCandidates(result, context) + let candidates = rule.GetReplaceCandidates(result, context) + where candidates != null + from newExpr in candidates where newExpr != null select (oldExpr, eGraph.Find((ENode)result.Root), newExpr.InheritMetaData(oldExpr))).ToList(); diff --git a/src/Nncase.Evaluator/Extension/OrtKIExtensions.cs b/src/Nncase.Evaluator/Extension/OrtKIExtensions.cs index bf81e72ec..35291304d 100644 --- a/src/Nncase.Evaluator/Extension/OrtKIExtensions.cs +++ b/src/Nncase.Evaluator/Extension/OrtKIExtensions.cs @@ -53,12 +53,12 @@ public static class OrtKIExtensions public static Tensor ToTensor(this OrtKISharp.Tensor tensor) { - return Tensor.FromBytes(tensor.DataType.ToDataType(), tensor.BytesBuffer.ToArray(), tensor.Shape.ToInts()); + return Tensor.From(tensor.DataType.ToDataType(), new TensorInitializerWithOrt(tensor), tensor.Shape.ToInts()); } public static Tensor ToTensor(this OrtKISharp.Tensor tensor, TensorType tensorType) { - return Tensor.FromBytes(tensorType.DType, tensor.BytesBuffer.ToArray(), tensorType.Shape.IsFixed ? tensorType.Shape : tensor.Shape.ToInts()); + return Tensor.From(tensorType.DType, new TensorInitializerWithOrt(tensor), tensorType.Shape.IsFixed ? tensorType.Shape : tensor.Shape.ToInts()); } public static TensorValue ToValue(this OrtKISharp.Tensor tensor) @@ -133,4 +133,20 @@ private static OrtKISharp.Tensor ToOrtTensor(Tensor tensor, OrtDataType ortDataT { return OrtKISharp.Tensor.MakeTensor(tensor.PinBuffer(), ortDataType, shape.ToLongs()); } + + private sealed class TensorInitializerWithOrt : ITensorInitializer + { + private readonly OrtKISharp.Tensor _tensor; + + public TensorInitializerWithOrt(OrtKISharp.Tensor tensor) + { + _tensor = tensor; + } + + public void Initialize(Tensor tensor) + where T : unmanaged, IEquatable + { + _tensor.GetBuffer().CopyTo(tensor.Buffer.Span); + } + } } diff --git a/src/Nncase.Importer/Onnx/DataGatter.cs b/src/Nncase.Importer/Onnx/DataGatter.cs index 0cd5da981..7cca14cba 100644 --- a/src/Nncase.Importer/Onnx/DataGatter.cs +++ b/src/Nncase.Importer/Onnx/DataGatter.cs @@ -114,11 +114,10 @@ private Tensor GetTensor(TensorProto tensor) var externalData = tensor.ExternalData; var location = Path.Join(parent, externalData[0].Value); var offset = externalDataCount > 1L ? long.Parse(externalData[1].Value) : 0; - using var br = new BinaryReader(new FileStream(location, FileMode.Open)); - var length = externalDataCount > 1 ? int.Parse(externalData[2].Value) : (int)br.BaseStream.Length; - br.BaseStream.Seek(offset, SeekOrigin.Begin); - var buffer = br.ReadBytes(length); - return Tensor.FromBytes(type, buffer, shape); + using var fs = new FileStream(location, FileMode.Open); + var length = externalDataCount > 1 ? long.Parse(externalData[2].Value) : fs.Length; + fs.Seek(offset, SeekOrigin.Begin); + return Tensor.FromStream(type, fs, shape); } return dt switch diff --git a/src/Nncase.Passes/DDrBufferSchdeulePass.cs b/src/Nncase.Passes/DDrBufferSchdeulePass.cs index 015b038b9..dcc7e751a 100644 --- a/src/Nncase.Passes/DDrBufferSchdeulePass.cs +++ b/src/Nncase.Passes/DDrBufferSchdeulePass.cs @@ -174,11 +174,11 @@ protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan) return memSpan; } - private ulong ComputeSize(IValue v) => v.AsTensors().Select(t => (ulong)t.BytesBuffer.Length).Sum(); + private ulong ComputeSize(IValue v) => v.AsTensors().Select(t => (ulong)t.Length * (ulong)t.ElementType.SizeInBytes).Sum(); private ulong ComputeSize(Const @const) => @const switch { - TensorConst { Value: Tensor tc } => (ulong)tc.BytesBuffer.Length, + TensorConst { Value: Tensor tc } => (ulong)tc.Length * (ulong)tc.ElementType.SizeInBytes, TupleConst tc => ComputeSize(tc.Value), _ => throw new NotSupportedException(), }; diff --git a/src/Nncase.Passes/Rules/Lower/QuantizerMatmul.cs b/src/Nncase.Passes/Rules/Lower/QuantizerMatmul.cs index f5012e39a..5c8a51b57 100644 --- a/src/Nncase.Passes/Rules/Lower/QuantizerMatmul.cs +++ b/src/Nncase.Passes/Rules/Lower/QuantizerMatmul.cs @@ -20,17 +20,14 @@ namespace Nncase.Passes.Rules.Lower; public sealed partial class QuantizerMatmul : IRewriteRule { public IPattern Pattern { get; } - = IsRangeOfMarker( - "markerC", - IsMatMul( + = IsMatMul( "matmul", "call", _ => true, IsRangeOfMarker("markerA", IsWildcard("inputA"), IsTensorConst("scaleA")), - IsRangeOfMarker("markerB", IsWildcard("inputB"), IsTensorConst("scaleB"))), - IsWildcard("scaleC")); + IsRangeOfMarker("markerB", IsWildcard("inputB"), IsTensorConst("scaleB"))); - private Expr? GetReplace(Expr matmul, Call call, Expr inputA, Marker markerA, TensorConst scaleA, Expr inputB, Marker markerB, TensorConst scaleB, Marker markerC, RunPassContext context) + private Expr? GetReplace(Expr matmul, Call call, Expr inputA, Marker markerA, TensorConst scaleA, Expr inputB, Marker markerB, TensorConst scaleB, RunPassContext context) { if (inputA is not TensorConst && inputB is not TensorConst) { diff --git a/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs b/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs index da78650d3..c5602c210 100644 --- a/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs +++ b/src/Nncase.Schedule/Schedule/TileGraph/PrimGraphSolveResult.cs @@ -128,7 +128,32 @@ public Unit Visit(TileNode value, Context context) var srcBufView = IR.F.Buffer.BufferSubview(viewInfo.Buffer, viewInfo.Offsets, new IR.Tuple(viewInfo.Shape.Select(x => (Expr)x).ToArray())); if (kernelInfo.BufferInfos[bid.Index].State.HasFlag(MicroKernelBufferInfo.BufferState.Read)) { - letBuilder.Body(T.Memcopy(subViewVar, srcBufView)); + if (bid.Node.Op.GetType().Name.Contains("Matmul", StringComparison.Ordinal) && bid.IsOutput) + { + var kdim = bid.Node.WriteAccess.Domains.Length - 2; + var relatedWithK = bufferInfo.Masks[i].IsRelated(kdim); + var val = value; + bool isLoopRelated = false; + while (val.Parent is TileNode parent) + { + if (TileableNodeMemo.TryGetValue(val, out var m) && m.TileVars[kdim] != 1) + { + isLoopRelated = true; + break; + } + + val = parent; + } + + if (relatedWithK && isLoopRelated) + { + letBuilder.Body(T.Memcopy(subViewVar, srcBufView)); + } + } + else + { + letBuilder.Body(T.Memcopy(subViewVar, srcBufView)); + } } if (kernelInfo.BufferInfos[bid.Index].State.HasFlag(MicroKernelBufferInfo.BufferState.Write))