Skip to content

Commit

Permalink
Merge branch 'master' into release/2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyang2057 committed Jan 24, 2024
2 parents c0f32e3 + de6c89d commit 83d5055
Show file tree
Hide file tree
Showing 19 changed files with 314 additions and 81 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/compiler-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,9 @@ jobs:
cache-dependency-path: '**/requirements.test.txt'

- name: Install Python Packages
run: pip install -r requirements.test.txt
run:
python -m pip install --upgrade pip
pip install -r requirements.test.txt

- name: Create Test Environment
run: mkdir test_results
Expand Down
1 change: 1 addition & 0 deletions Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
<PackageVersion Include="Xunit.DependencyInjection" Version="8.3.0" />
<PackageVersion Include="xunit.runner.visualstudio" Version="2.5.0" />
<PackageVersion Include="xunit.v3.assert" Version="0.1.1-pre.239" />
<PackageVersion Include="Razor.Templating.Core" Version="1.9.0" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="StyleCop.Analyzers">
Expand Down
18 changes: 15 additions & 3 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
using Nncase.Passes.Rules.Neutral;
using Nncase.Passes.Rules.ShapeBucket;
using Nncase.Passes.Rules.ShapeExpr;
using Nncase.Passes.Rules.WithMarker;
using Nncase.Passes.Transforms;
using Nncase.Quantization;
using static Nncase.Passes.Rules.ShapeBucket.ShapeBucketRegister;
using CombinePadTranspose = Nncase.Passes.Rules.WithMarker.CombinePadTranspose;
using CombineReshapePad = Nncase.Passes.Rules.Neutral.CombineReshapePad;
using FoldConstCall = Nncase.Passes.Rules.Neutral.FoldConstCall;

namespace Nncase.Compiler;
Expand Down Expand Up @@ -97,6 +100,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.NormAxisReshape>();
p.Add<Passes.Rules.Neutral.NormAxisReduceArg>();
p.Add<Passes.Rules.Neutral.NormAxisSlice>();
p.Add<Passes.Rules.Neutral.SwapBinaryArgs>();
p.Add<Passes.Rules.Neutral.SqueezeTransposeShape>();
p.Add<Passes.Rules.Neutral.Squeeze5DTranspose>();
p.Add<Passes.Rules.Neutral.SqueezeBinaryShape>();
Expand All @@ -117,8 +121,6 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.FocusFull>();
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
p.Add<Passes.Rules.Neutral.SplitSpaceToBatch>();
p.Add<Passes.Rules.Neutral.SplitBatchToSpace>();
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.Neutral.FoldShapeOf>();
p.Add<Passes.Rules.Neutral.FoldTwoReshapes>();
Expand All @@ -131,20 +133,28 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldUnsqueezeSqueeze>();
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
p.Add<Passes.Rules.Neutral.FoldNopClamp>();
p.Add<Passes.Rules.ShapeBucket.FoldRepeatMarker>();
p.Add<Passes.Rules.Neutral.SqueezeToReshape>();
p.Add<Passes.Rules.Neutral.UnSqueezeToReshape>();
p.Add<Passes.Rules.ShapeExpr.GatherToGetItem>();
p.Add<Passes.Rules.ShapeExpr.FoldGetItemShapeOf>();
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
p.Add<Passes.Rules.Neutral.FoldDilatedConv2D>();
});

passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
{
p.Add<Passes.Rules.Neutral.FoldConstCall>();
p.Add<Passes.Rules.Neutral.FoldNopTranspose>();
p.Add<Passes.Rules.Neutral.FoldTwoTransposes>();
p.Add<FoldRepeatMarker>();
p.Add<Passes.Rules.WithMarker.FoldTransposeActTranspose>();
p.Add<Passes.Rules.WithMarker.FoldTransposeBinaryActTranspose>();
p.Add<Passes.Rules.WithMarker.CombineReshapePad>();
p.Add<Passes.Rules.WithMarker.CombineTransposePad>();
p.Add<Passes.Rules.WithMarker.CombinePadTranspose>();
p.Add<Passes.Rules.Neutral.CombineTransposeUnary>();
p.Add<Passes.Rules.Neutral.CombineTransposePad>();
p.Add<Passes.Rules.Neutral.CombinePadTranspose>();
Expand Down Expand Up @@ -179,12 +189,14 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldNopSlice>();
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.SpaceToBatchToPad>();
p.Add<Passes.Rules.Neutral.FoldConv2DAddMul>();
});

_compileSession.Target.RegisterTargetInDependentPass(passManager, _compileSession.CompileOptions);

passManager.AddWithName<DataflowPass>("BroadcastMarker").Configure(p =>
{
p.Add<FoldTransposeActTranspose>();
p.Add<BroadcastInputMarker>();
p.Add<BroadcastOutputMarker>();
});
Expand Down Expand Up @@ -220,8 +232,8 @@ public void RegisterShapeBucket(IPassManager p)
MergeOp(p, true);
ClearMarker(p);
MergeFusion(p, singleVar, true);
Bucket(p);
Rebuild(p, singleVar);
Bucket(p);
Simplify(p);
}
else
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Core/IR/NN/Conv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public sealed partial class Conv2D : Op
/// <summary>
/// Gets FusedClamp.
/// </summary>
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 7, "fused_clamp", HasShape(new Shape(2)) & HasDataType(DataTypes.Float32));
public static readonly ParameterInfo FusedClamp = new(typeof(Conv2D), 7, "fused_clamp", HasShape(new Shape(2)) & IsFloat());

public PadMode PadMode { get; }

Expand Down
14 changes: 13 additions & 1 deletion src/Nncase.EGraph/Passes/EGraphExtractors/SatExtractor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode,
}

var solver = new CpSolver();
int max_time = 600;
int max_time = 120;
if (System.Environment.GetEnvironmentVariable("SOLVE_MAX_TIME") is string s_solve_max_time)
{
try
Expand All @@ -91,6 +91,18 @@ public Expr Extract(EClass root, IEGraph eGraph, out IReadOnlyDictionary<ENode,
}

int processorCount = Math.Max(System.Environment.ProcessorCount / 2, 1);
if (System.Environment.GetEnvironmentVariable("SOLVE_PROCESSOR_COUNT") is string s_solve_processor_count)
{
try
{
var solve_processor_count = int.Parse(s_solve_processor_count);
processorCount = solve_processor_count;
}
catch (System.Exception)
{
}
}

solver.StringParameters = $"max_time_in_seconds:{max_time},num_workers:{processorCount}";

var enableDump = DumpScope.Current.IsEnabled(DumpFlags.EGraphCost);
Expand Down
4 changes: 3 additions & 1 deletion src/Nncase.Evaluator/NN/Activations.cs
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ private IRType Visit(TensorType input)
/// <summary>
/// Evaluator for <see cref="Sigmoid"/>.
/// </summary>
public class SwishEvaluator : IEvaluator<Swish>, ITypeInferencer<Swish>, ICostEvaluator<Swish>, IMetricEvaluator<Swish>
public class SwishEvaluator : IEvaluator<Swish>, ITypeInferencer<Swish>, ICostEvaluator<Swish>, IMetricEvaluator<Swish>, IShapeEvaluator<Swish>
{
/// <inheritdoc/>
public IValue Visit(IEvaluateContext context, Swish swish)
Expand Down Expand Up @@ -560,6 +560,8 @@ public Metric Visit(IMetricEvaluateContext context, Swish target)
};
}

public Expr Visit(IShapeEvaluateContext context, Swish target) => context.GetArgumentShape(target, Swish.Input);

private IRType Visit(IRType input)
{
if (input is DistributedType d && d.NdSBP.Any(s => s is SBPPartialSum))
Expand Down
3 changes: 2 additions & 1 deletion src/Nncase.Evaluator/NN/Conv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ public IValue Visit(IEvaluateContext context, Conv2D conv)
new long[] { kernelShape[2], kernelShape[3] },
ToOnnxPadFormat(pad),
stride);
return OrtKI.Clip(result, fusedClamp[0], fusedClamp[1]).ToValue();
var outType = input.ToTensor().ElementType;
return Value.FromTensor(OrtKI.Clip(result.ToTensor().Cast<float>().ToOrtTensor(), fusedClamp[0], fusedClamp[1]).ToTensor().CastTo(outType));
}

/// <inheritdoc/>
Expand Down
8 changes: 7 additions & 1 deletion src/Nncase.Evaluator/NN/LayerNorm.cs
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,13 @@ private Tensor LayerNormImpl(Tensor input, Tensor scale, Tensor bias, int axis,
}
}

return new Tensor<float>(outputArray, input.Shape);
var ret = new Tensor<float>(outputArray, input.Shape);
return input.ElementType switch
{
Float32Type => ret,
Float16Type => ret.Cast<Half>(CastMode.KDefault),
_ => throw new NotSupportedException("Not Supported Type of Layernorm!"),
};
}
#endif
}
4 changes: 3 additions & 1 deletion src/Nncase.Importer/Onnx/Transpose.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Linq;
using DryIoc.ImTools;
using LanguageExt.UnsafeValueAccess;
using Nncase.IR;
using Nncase.IR.Tensors;
Expand All @@ -16,7 +17,8 @@ public partial class OnnxImporter
private Expr VisitTranspose(NodeProto op)
{
var input = GetSingleInputExpr(op);
var perm = Tensor.From<long>(GetIntsAttribute(op, "perm"));
var defaultPerm = Enumerable.Range(0, input.CheckedShape.Rank).Reverse().ToArray();
var perm = Tensor.From(GetIntsAttribute(op, "perm", defaultPerm));
return F.Tensors.Transpose(input, perm).With(metadata: new IRMetadata() { OutputNames = op.Output, });
}
}
Expand Down
12 changes: 8 additions & 4 deletions src/Nncase.Importer/TFLite/SpaceToBatchND.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using Nncase.IR;
using Nncase.IR.Tensors;
using static Nncase.IR.F.NN;
Expand All @@ -14,15 +16,16 @@ private Expr VisitSpaceToBatchND(in tflite.Operator op)
{
var (input, blockShape) = GetInputExprs(op, 0, 1);
var paddings = GetInputExprs(op, 2);
if (input.CheckedShape.Rank == 3)
bool needUnsqueeze = input.CheckedShape.Rank == 3;
if (needUnsqueeze)
{
blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0);
paddings = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, paddings }), 0);
input = Unsqueeze(input, new[] { -3 });
}

var stb = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(input), blockShape, paddings));
if (input.CheckedShape.Rank == 3)
if (needUnsqueeze)
{
return Squeeze(stb, new[] { 1 });
}
Expand All @@ -34,15 +37,16 @@ private Expr VisitBatchToSpaceND(in tflite.Operator op)
{
var (input, blockShape) = GetInputExprs(op, 0, 1);
var crops = GetInputExprs(op, 2);
if (input.CheckedShape.Rank == 3)
bool needUnsqueeze = input.CheckedShape.Rank == 3;
if (needUnsqueeze)
{
blockShape = Concat(new IR.Tuple(new[] { new[] { 1 }, blockShape }), 0);
crops = Concat(new IR.Tuple(new[] { new[,] { { 0, 0 } }, crops }), 0);
input = Unsqueeze(input, new[] { -3 });
}

var bts = NCHWToNHWC(BatchToSpace(NHWCToNCHW(input), blockShape, crops));
if (input.CheckedShape.Rank == 3)
if (needUnsqueeze)
{
return Squeeze(bts, new[] { 1 });
}
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Rules/Neutral/AddRangeOfAndMarker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ private IR.Tuple WrapLSTMOutput(Call call, int outputSize, bool configExist, boo
{
var outputNames = new List<string>();
var getItem = IR.F.Tensors.GetItem(call, i);
outputNames.Add("LSTMOutput_" + call.Metadata.OutputNames?[i]);
outputNames.Add(call.Metadata.OutputNames?[i] ?? "LSTMOutput_" + i.ToString());
outputs[i].Metadata.OutputNames = outputNames;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public sealed partial class CombineQuantizeConcat : RewriteRule<Pattern>
}
else
{
if (user != tuple)
if (!ReferenceEquals(user, tuple))
{
return null;
}
Expand Down
23 changes: 14 additions & 9 deletions src/Nncase.Passes/Rules/Neutral/FoldBinary.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,23 @@ public sealed partial class FoldNopBinary : IRewriteRule
"binary",
x => x.BinaryOp is BinaryOp.Add or BinaryOp.Sub or BinaryOp.Mul or BinaryOp.Div or BinaryOp.Mod or BinaryOp.Pow,
IsWildcard("lhs"),
IsTensorConst("rhs", IsScalar()));
IsTensorConst("rhs"));

private Expr? GetReplace(Binary binary, Expr lhs, TensorConst rhs)
{
return (binary.BinaryOp, rhs.Value.ToScalar<float>()) switch
if (lhs.CheckedType is Nncase.IR.AnyType || lhs.CheckedShape == rhs.CheckedShape)
{
(BinaryOp.Add, 0f) => lhs,
(BinaryOp.Sub, 0f) => lhs,
(BinaryOp.Mul, 1f) => lhs,
(BinaryOp.Pow, 1f) => lhs,
(BinaryOp.Div, 1f) => lhs,
_ => null,
};
return binary.BinaryOp switch
{
BinaryOp.Add when rhs.Value.ToArray<float>().All(x => x == 0) => lhs,
BinaryOp.Sub when rhs.Value.ToArray<float>().All(x => x == 0) => lhs,
BinaryOp.Mul when rhs.Value.ToArray<float>().All(x => x == 1) => lhs,
BinaryOp.Pow when rhs.Value.ToArray<float>().All(x => x == 1) => lhs,
BinaryOp.Div when rhs.Value.ToArray<float>().All(x => x == 1) => lhs,
_ => null,
};
}

return null;
}
}
3 changes: 2 additions & 1 deletion src/Nncase.Passes/Rules/Neutral/FoldDilatedConv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ private static CallPattern Conv2DPattern() =>
(Conv2D.Stride.Index, (Expr)new[] { strideH, strideW }),
(Conv2D.Dilation.Index, (Expr)new[] { dilationH, dilationW }),
};
return ReplaceUtility.ReplaceCallParams(conv, conv.Arguments.ToArray(), pairs).InheritMetaData(btsCall);
var res = ReplaceUtility.ReplaceCallParams(conv.Target, conv.Arguments.ToArray(), pairs).InheritMetaData(btsCall);
return res;
}

private (int[] Begin, int[] End) GetBeginEnd(int[] btsBlockShape, int[,] crop, int[] btsInputShape)
Expand Down
1 change: 1 addition & 0 deletions src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ protected override Task<BaseFunction> RunCoreAsync(BaseFunction main, RunPassCon
var memo = EvaluatorUtil.GetMemo(body, ConcatDictionary(input, varValues));
var f = new FusionShapeUpdater(ConcatDictionary(memo, exprValues));
f.Visit(main);
GC.Collect();
return f.FusionShape;
}).SelectMany(x => x)
.ToLookup(x => x.Key, x => x.Value)
Expand Down
Loading

0 comments on commit 83d5055

Please sign in to comment.