Skip to content

Commit

Permalink
Add FoldConv2DBiasWithMarker
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Jan 9, 2025
1 parent 1f0e822 commit 69c16aa
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.IR.NN;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.IR.F.NN;
using static Nncase.IR.F.Tensors;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.F.NN;
using static Nncase.PatternMatch.F.Tensors;
using static Nncase.PatternMatch.Utility;
using static Nncase.Utilities.MetadataUtility;
using Shape = Nncase.IR.Shape;

namespace Nncase.Passes.Rules.Neutral;

// rules in this file are used for ShapeBucket

/// <summary>
/// Transform <see cref="IR.NN.Conv2D"/> to <see cref="IR.Math.Binary"/>.
/// </summary>
[RuleGenerator]
public sealed partial class FoldConv2DBiasWithMarker : IRewriteRule
{
private static int _counter;

Check warning on line 33 in src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The field 'FoldConv2DBiasWithMarker._counter' is never used

Check warning on line 33 in src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The field 'FoldConv2DBiasWithMarker._counter' is never used

Check warning on line 33 in src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-windows

The field 'FoldConv2DBiasWithMarker._counter' is never used

Check warning on line 33 in src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-linux

The field 'FoldConv2DBiasWithMarker._counter' is never used

Check warning on line 33 in src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-windows

The field 'FoldConv2DBiasWithMarker._counter' is never used

Check warning on line 33 in src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs

View workflow job for this annotation

GitHub Actions / build-x86_64-windows

The field 'FoldConv2DBiasWithMarker._counter' is never used

Check warning on line 33 in src/Nncase.Passes/Rules/WithMarker/FoldConv2DBiasWithMarker.cs

View workflow job for this annotation

GitHub Actions / build-aarch64-macos

The field 'FoldConv2DBiasWithMarker._counter' is never used

/// <inheritdoc/>
public IPattern Pattern { get; } = IsRangeOfMarker(
"binarym",
IsBinary(
"binary",
"binaryCall",
p => p.BinaryOp is BinaryOp.Add,
IsReshape(
IsRangeOfMarker(
"convm",
IsConv2D(
"conv2d",
_ => true,
IsWildcard("input"),
IsWildcard("weights"),
IsTensorConst("bias") with { TypePattern = HasRank(1) },
IsWildcard("stride"),
IsWildcard("padding"),
IsWildcard("dilation"),
IsWildcard("groups")),
IsWildcard()),
IsWildcard("shape")),
IsRangeOfMarker("bm", IsTensorConst("b") with { TypePattern = HasRank(1) }, IsWildcard())),
IsWildcard());

private Expr? GetReplace(Conv2D conv2d, Call binaryCall, Expr input, Expr weights, Tensor bias, Tensor b, Expr shape, Expr stride, Expr padding, Expr dilation, Expr groups, Marker binarym)
{
var newBias = IR.F.Math.Add(bias, b).Evaluate().AsTensor();
var newConv2d = Conv2D(
input,
weights,
newBias,
stride,
padding,
dilation,
conv2d.PadMode,
groups).InheritMetaData(binaryCall);
var m = Reshape(binarym.With(target: newConv2d), shape).InheritMetaData(binaryCall);
return m;
}
}

0 comments on commit 69c16aa

Please sign in to comment.