Skip to content

Commit

Permalink
fix bugs found in xpu debugging (#1274)
Browse files Browse the repository at this point in the history
* fix matmul quantizer
* optimize output copy
* add pack rule for cast
* fix importer of onnx tensor > 2GB
* fix const size compuation if > 2GB
* fix null candidate of egraph
* fix pack eval when > 2GB
* fix cpu function builder with const > 2GB
* fix stackvm codegen with const > 2GB
* Add (De)Serialize for Tensor
* Simplify Pack evaluator
* Fix FunctionBuilder
* fix span utility
---------
Co-authored-by: sunnycase <sunnycase@live.cn>
  • Loading branch information
xhuohai authored Dec 2, 2024
1 parent 3b93a41 commit 492e867
Show file tree
Hide file tree
Showing 14 changed files with 210 additions and 24 deletions.
8 changes: 5 additions & 3 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/FunctionBuilder.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
using System.Drawing;
using System.Numerics;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
Expand Down Expand Up @@ -50,16 +52,16 @@ public unsafe ILinkableFunction Build(TIR.PrimFunction function)
ulong rdataPoolSize = ulong.MinValue;
foreach (var (@const, range) in function.SchedResult.Rdatas)
{
var bytes = ((TensorConst)@const).Value.BytesBuffer;
var tensor = ((TensorConst)@const).Value;
var size = range.Max - range.Min;
rdataPoolSize = System.Math.Max(range.Max, rdataPoolSize);
if ((uint)bytes.Length != size)
if ((ulong)tensor.Length * (ulong)tensor.ElementType.SizeInBytes != size)
{
throw new InvalidDataException("The Buffer Size Not Equal!");
}

_rdataWriter.Position(checked((long)range.Min));
_rdataWriter.Write(bytes);
tensor.Serialize(_rdataWriter.BaseStream);
}

// 3. build function.
Expand Down
7 changes: 6 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/CPU/Pack.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.Http.Headers;
using System.Numerics;
using System.Runtime.InteropServices;
using CommunityToolkit.HighPerformance;
using Nncase.CostModel;
using Nncase.IR;
using Nncase.IR.CPU;
Expand Down Expand Up @@ -41,7 +45,8 @@ public IValue Visit(IEvaluateContext context, Pack target)
input = input.Pack(lanes, axis);
}

return Value.FromTensor(Tensor.FromBytes(new VectorType(input.DataType.ToDataType(), target.Lanes), input.BytesBuffer.ToArray(), input.Shape.ToArray().SkipLast(target.Lanes.Count).Select(i => (int)i).ToArray()));
var dt = input.DataType.ToDataType();
return Value.FromTensor(input.ToTensor(new TensorType(new VectorType(input.DataType.ToDataType(), target.Lanes), new Shape(input.Shape.SkipLast(target.Lanes.Count).Select(i => (int)i)))));
}
}

Expand Down
46 changes: 46 additions & 0 deletions modules/Nncase.Modules.CPU/Passes/Rules/CPU/PackRule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,52 @@ void AddCandidate(int[] packAxes, int[] lanes)
}
}

public sealed class PackCast : PackRule
{
public PackCast(int rank, int lane)
: base(rank, lane)
{
}

public override Pattern Pattern { get; } = IsCast(
"target",
_ => true,
IsWildcard("input", e => e is not Call { Target: IR.CPU.Unpack }) with { TypePattern = IsFloat() & !IsVector() });

public override List<Expr> GetReplaceCandidates(IMatchResult result, RunPassContext context)
{
var rets = new List<Expr>();
var op = (IR.Tensors.Cast)result["target"];
var input = (Expr)result["input"];
var inShape = input.CheckedShape.ToValueArray();

void AddCandidate(int[] packedAxes, int[] lanes)
{
var packedInput = IR.F.CPU.Pack(PackUtility.PadForPack(input, inShape, packedAxes, lanes, 0f, out var padsInput), lanes, packedAxes);
var cast = IR.F.Tensors.Cast(packedInput, op.NewType, op.CastMode);
var post = PackUtility.SliceForPack(IR.F.CPU.Unpack(cast, lanes, packedAxes), inShape, padsInput);
if (cast.CheckedType is not InvalidType)
{
rets.Add(post);
}
}

for (int i = 0; i < input.CheckedShape.Count; i++)
{
AddCandidate(new[] { i }, new[] { Lane });
for (int j = i + 1; j < input.CheckedShape.Count; j++)
{
if (Rank > 1)
{
AddCandidate(new[] { i, j }, new[] { Lane, Lane });
}
}
}

return rets;
}
}

[RuleGenerator]
public sealed partial class FoldPackUnpack : RewriteRule<Pattern>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Numerics;
using System.Reactive;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using NetFabric.Hyperlinq;
Expand Down Expand Up @@ -403,7 +406,7 @@ private void MergeSnippetSet(List<TextSnippet> thenSet, List<TextSnippet> elseSe

private TextSnippet Visit(TensorConst expr, Tensor tensor)
{
var buffer = WriteRdata(tensor.BytesBuffer, _alignment);
var buffer = WriteRdata(tensor, _alignment);

// stack: dtype shape strides buffer
var snippet = BeginTextSnippet(expr);
Expand All @@ -428,11 +431,11 @@ private Symbol WriteRdata(DataType dataType)
return symbol;
}

private Symbol WriteRdata(ReadOnlySpan<byte> data, int alignment)
private Symbol WriteRdata(Tensor tensor, int alignment)
{
_context.RdataWriter.AlignPosition(alignment);
var symbol = AddSymbol(WellknownSectionNames.Rdata);
_context.RdataWriter.Write(data);
tensor.Serialize(_context.RdataWriter.BaseStream);
return symbol;
}

Expand Down
16 changes: 16 additions & 0 deletions src/Nncase.Core/ITensorInitializer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Nncase;

public interface ITensorInitializer
{
void Initialize<T>(Tensor<T> tensor)
where T : unmanaged, IEquatable<T>;
}
35 changes: 35 additions & 0 deletions src/Nncase.Core/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
using CommunityToolkit.HighPerformance;
using Nncase.Buffers;
using Nncase.IR;
using Nncase.TIR;

namespace Nncase;

Expand Down Expand Up @@ -58,6 +59,9 @@ public abstract partial class Tensor : IStructuralComparable, IStructuralEquatab
private static readonly MethodInfo _tensorCreateFromArrayFunc =
typeof(Tensor).GetMethod(nameof(CreateTensorFromArrayImpl), BindingFlags.Static | BindingFlags.NonPublic)!;

private static readonly MethodInfo _tensorCreateEmptyFunc =
typeof(Tensor).GetMethod(nameof(CreateTensorEmptyImpl), BindingFlags.Static | BindingFlags.NonPublic)!;

private static readonly MethodInfo _tensorCastFunc =
typeof(Tensor).GetMethod(nameof(Cast))!;

Expand Down Expand Up @@ -229,6 +233,13 @@ public static Tensor<int> FromRange(int start, int count)
return tensor;
}

public static Tensor From(DataType dataType, ITensorInitializer initializer, ReadOnlySpan<int> dimensions)
{
var tensor = Zeros(dataType, dimensions);
tensor.Initialize(initializer);
return tensor;
}

/// <summary>
/// Create tensor from a memory, Set the shape as [n].
/// </summary>
Expand Down Expand Up @@ -316,6 +327,13 @@ public static Tensor FromBytes(TensorType type, Memory<byte> buffer)
return FromBytes(type.DType, buffer, type.Shape.ToValueArray());
}

public static Tensor FromStream(DataType type, Stream stream, ReadOnlySpan<int> dimensions)
{
var tensor = Tensor.Zeros(type, dimensions);
tensor.Deserialize(stream);
return tensor;
}

/// <summary>
/// Create tensor from an array.
/// </summary>
Expand Down Expand Up @@ -392,6 +410,11 @@ public static Tensor Zeros<T>(ReadOnlySpan<int> dimensions)
return Tensor.FromScalar<T>(value, dimensions);
}

public static Tensor Zeros(DataType dataType, ReadOnlySpan<int> dimensions)
{
return (Tensor)_tensorCreateEmptyFunc.MakeGenericMethod(dataType.CLRType).Invoke(null, new object[] { dimensions.ToArray() })!;
}

/// <summary>
/// Return a tensor of given shape and type, filled with ones.
/// </summary>
Expand Down Expand Up @@ -442,6 +465,10 @@ public IEnumerator GetEnumerator()
/// <returns>String of this tensor.</returns>
public abstract string GetArrayString(bool includeWhitespace = true);

public abstract void Deserialize(Stream stream);

public abstract void Serialize(Stream stream);

int IStructuralComparable.CompareTo(object? other, IComparer comparer)
{
return CompareTo(other, comparer);
Expand Down Expand Up @@ -511,6 +538,8 @@ void IList.RemoveAt(int index)

private protected abstract void SetValueCore(int index, object? value);

private protected abstract void Initialize(ITensorInitializer initializer);

private static Tensor CreateTensorFromBytesImpl<T>(Memory<byte> buffer, int[] dimensions)
where T : unmanaged, IEquatable<T>
{
Expand All @@ -523,4 +552,10 @@ private static Tensor CreateTensorFromArrayImpl<T>(Array array, int[] dimensions
var mmgr = new ArrayMemoryManager<T>(array);
return new Tensor<T>(mmgr.Memory, dimensions);
}

private static Tensor CreateTensorEmptyImpl<T>(int[] dimensions)
where T : unmanaged, IEquatable<T>
{
return new Tensor<T>(dimensions);
}
}
15 changes: 15 additions & 0 deletions src/Nncase.Core/TensorOfT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using CommunityToolkit.HighPerformance;
using CommunityToolkit.HighPerformance.Helpers;
using NetFabric.Hyperlinq;
using Nncase.Buffers;
using Nncase.IR;
using Nncase.Utilities;

namespace Nncase;

Expand Down Expand Up @@ -325,6 +327,16 @@ public override string GetArrayString(bool includeWhitespace = true)
return builder.ToString();
}

public override void Deserialize(Stream stream)
{
SpanUtility.Deserialize(Buffer.Span, stream);
}

public override void Serialize(Stream stream)
{
SpanUtility.Serialize((ReadOnlySpan<T>)Buffer.Span, stream);
}

/// <summary>
/// Gets an enumerator that enumerates the elements of the <see cref="Tensor{T}"/>.
/// </summary>
Expand Down Expand Up @@ -544,6 +556,9 @@ private protected override void SetValueCore(int index, object? value)
SetValue(index, (T)Convert.ChangeType(value, typeof(T))!);
}

private protected override void Initialize(ITensorInitializer initializer) =>
initializer.Initialize(this);

private static void Indent(StringBuilder builder, int tabs, int spacesPerTab = 4)
{
for (int tab = 0; tab < tabs; tab++)
Expand Down
25 changes: 25 additions & 0 deletions src/Nncase.Core/Utilities/SpanUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Runtime.InteropServices;
using System.Text;
using System.Threading.Tasks;
using CommunityToolkit.HighPerformance;

namespace Nncase.Utilities;

Expand All @@ -21,4 +22,28 @@ public static ReadOnlySpan<TTo> UnsafeCast<TFrom, TTo>(ReadOnlySpan<TFrom> froms
ref var castFirst = ref Unsafe.As<TFrom, TTo>(ref first);
return MemoryMarshal.CreateReadOnlySpan(ref castFirst, froms.Length);
}

public static unsafe void Deserialize<T>(Span<T> span, Stream stream)
where T : unmanaged
{
var position = 0;
while (position < span.Length)
{
var length = Math.Min(span.Length - position, 1024 * 1024 * 1024 / sizeof(T));
stream.ReadExactly(span.Slice(position, length).AsBytes());
position += length;
}
}

public static unsafe void Serialize<T>(ReadOnlySpan<T> span, Stream stream)
where T : unmanaged
{
var position = 0;
while (position < span.Length)
{
var length = Math.Min(span.Length - position, 1024 * 1024 * 1024 / sizeof(T));
stream.Write(span.Slice(position, length).AsBytes());
position += length;
}
}
}
4 changes: 3 additions & 1 deletion src/Nncase.EGraph/Passes/RewriteProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ public IEGraph ERewrite(IEGraph eGraph, IEnumerable<IRewriteRule> rules, RunPass
{
var replacedExprs = (from result in results
let oldExpr = ((ENode)result.Root).Expr
from newExpr in rule.GetReplaceCandidates(result, context)
let candidates = rule.GetReplaceCandidates(result, context)
where candidates != null
from newExpr in candidates
where newExpr != null
select (oldExpr, eGraph.Find((ENode)result.Root), newExpr.InheritMetaData(oldExpr))).ToList();

Expand Down
20 changes: 18 additions & 2 deletions src/Nncase.Evaluator/Extension/OrtKIExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ public static class OrtKIExtensions

public static Tensor ToTensor(this OrtKISharp.Tensor tensor)
{
return Tensor.FromBytes(tensor.DataType.ToDataType(), tensor.BytesBuffer.ToArray(), tensor.Shape.ToInts());
return Tensor.From(tensor.DataType.ToDataType(), new TensorInitializerWithOrt(tensor), tensor.Shape.ToInts());
}

public static Tensor ToTensor(this OrtKISharp.Tensor tensor, TensorType tensorType)
{
return Tensor.FromBytes(tensorType.DType, tensor.BytesBuffer.ToArray(), tensorType.Shape.IsFixed ? tensorType.Shape : tensor.Shape.ToInts());
return Tensor.From(tensorType.DType, new TensorInitializerWithOrt(tensor), tensorType.Shape.IsFixed ? tensorType.Shape : tensor.Shape.ToInts());
}

public static TensorValue ToValue(this OrtKISharp.Tensor tensor)
Expand Down Expand Up @@ -133,4 +133,20 @@ private static OrtKISharp.Tensor ToOrtTensor(Tensor tensor, OrtDataType ortDataT
{
return OrtKISharp.Tensor.MakeTensor(tensor.PinBuffer(), ortDataType, shape.ToLongs());
}

private sealed class TensorInitializerWithOrt : ITensorInitializer
{
private readonly OrtKISharp.Tensor _tensor;

public TensorInitializerWithOrt(OrtKISharp.Tensor tensor)
{
_tensor = tensor;
}

public void Initialize<T>(Tensor<T> tensor)
where T : unmanaged, IEquatable<T>
{
_tensor.GetBuffer<T>().CopyTo(tensor.Buffer.Span);
}
}
}
9 changes: 4 additions & 5 deletions src/Nncase.Importer/Onnx/DataGatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,10 @@ private Tensor GetTensor(TensorProto tensor)
var externalData = tensor.ExternalData;
var location = Path.Join(parent, externalData[0].Value);
var offset = externalDataCount > 1L ? long.Parse(externalData[1].Value) : 0;
using var br = new BinaryReader(new FileStream(location, FileMode.Open));
var length = externalDataCount > 1 ? int.Parse(externalData[2].Value) : (int)br.BaseStream.Length;
br.BaseStream.Seek(offset, SeekOrigin.Begin);
var buffer = br.ReadBytes(length);
return Tensor.FromBytes(type, buffer, shape);
using var fs = new FileStream(location, FileMode.Open);
var length = externalDataCount > 1 ? long.Parse(externalData[2].Value) : fs.Length;
fs.Seek(offset, SeekOrigin.Begin);
return Tensor.FromStream(type, fs, shape);
}

return dt switch
Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Passes/DDrBufferSchdeulePass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ protected override TIR.MemSpan RewriteLeafMemSpan(TIR.MemSpan memSpan)
return memSpan;
}

private ulong ComputeSize(IValue v) => v.AsTensors().Select(t => (ulong)t.BytesBuffer.Length).Sum();
private ulong ComputeSize(IValue v) => v.AsTensors().Select(t => (ulong)t.Length * (ulong)t.ElementType.SizeInBytes).Sum();

private ulong ComputeSize(Const @const) => @const switch
{
TensorConst { Value: Tensor tc } => (ulong)tc.BytesBuffer.Length,
TensorConst { Value: Tensor tc } => (ulong)tc.Length * (ulong)tc.ElementType.SizeInBytes,
TupleConst tc => ComputeSize(tc.Value),
_ => throw new NotSupportedException(),
};
Expand Down
Loading

0 comments on commit 492e867

Please sign in to comment.