Commit
change batch_norm scheme from aten::batch_norm to ipex::batch_norm to disable TE fusion path (#404)

* change batch_norm scheme from aten::batch_norm to ipex::batch_norm to disable TE fusion path

* remove jira link
XiaobingSuper authored and EikanWang committed Dec 13, 2021
1 parent d8cd254 commit b19a9c5
Showing 6 changed files with 76 additions and 17 deletions.
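In practice the change can be verified from Python in the same way the new test_ipex_batch_norm test does: trace a model containing BatchNorm with IPEX imported and check which batch_norm node kind ends up in the optimized graph. A minimal sketch (the module name intel_extension_for_pytorch and the need for a warm-up run are assumptions that may vary with the PyTorch/IPEX version):

import torch
import intel_extension_for_pytorch as ipex  # importing IPEX registers its JIT fusion pass

model = torch.nn.BatchNorm2d(10).eval()
x = torch.rand(10, 10, 4, 4)
with torch.no_grad():
    traced = torch.jit.trace(model, x)
    traced(x)  # warm-up so the profiling executor runs the IPEX fusion pass
    graph = traced.graph_for(x)
# After this commit the graph should contain ipex::batch_norm and no aten::batch_norm,
# which keeps batch_norm off the TensorExpr fusion path.
print(any(n.kind() == "ipex::batch_norm" for n in graph.nodes()))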
2 changes: 1 addition & 1 deletion tests/cpu/test_ipex_optimize.py
@@ -26,7 +26,7 @@ def test_optimize_parameters_behavior(self):
x = torch.randn(1, 3, 224, 224)
traced_model = torch.jit.trace(opt_M, x)
trace_graph = traced_model.graph_for(x)
self.assertTrue(any(n.kind() == "aten::batch_norm" for n in trace_graph.nodes()))
self.assertTrue(any(n.kind() == "ipex::batch_norm" for n in trace_graph.nodes()))
# TODO check weight_prepack.

def test_optimize_bf16_model(self):
39 changes: 29 additions & 10 deletions tests/cpu/test_jit.py
@@ -509,6 +509,14 @@ def __init__(self, dim=-1):
def forward(self, x):
return self.softmax(x)

class AtenBatchNormReplace(nn.Module):
def __init__(self):
super(AtenBatchNormReplace, self).__init__()
self.bn = torch.nn.BatchNorm2d(10)

def forward(self, x):
return self.bn(x)

class AddLayerNorm(torch.nn.Module):
def __init__(self, dim=32):
super(AddLayerNorm, self).__init__()
@@ -925,35 +933,35 @@ def test_output_conv_bn_2d(self):
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
torch.randn(32, 3, 64, 64),
kind_in_graph="ipex_prepack::convolution_run",
kind_not_in_graph="aten::batch_norm",
kind_not_in_graph="ipex::batch_norm",
levels=['O1'])
self._test_output_bf16(
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
torch.randn(32, 3, 64, 64),
kind_in_graph="ipex_prepack::convolution_run",
kind_not_in_graph="aten::batch_norm",
kind_not_in_graph="ipex::batch_norm",
prec=0.02,
levels=['O1'])

def test_output_bn_conv_2d(self):
self._test_output(
BatchNormConv_Fixed(2, 3, 32, kernel_size=3, stride=1),
torch.randn(32, 3, 64, 64),
kind_in_graph="aten::batch_norm",
kind_in_graph="ipex::batch_norm",
kind_not_in_graph=None)

def test_output_bn_conv_bn(self):
self._test_output(
BatchNorm_Conv_BatchNorm(2, 3, 32, kernel_size=3, stride=1),
torch.randn(32, 3, 64, 64),
kind_in_graph="aten::batch_norm",
kind_in_graph="ipex::batch_norm",
kind_not_in_graph=None)

def test_output_conv_reshape_bn_2d(self):
self._test_output(
ConvReshapeBatchNorm(2, 3, 32, (64, 16, 62, 62), kernel_size=3, stride=1),
torch.randn(32, 3, 64, 64),
kind_in_graph="aten::batch_norm",
kind_in_graph="ipex::batch_norm",
kind_not_in_graph=None)

def test_output_conv_conv_concate(self):
@@ -994,7 +1002,7 @@ def test_output_conv_bn_3d(self):
ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
torch.randn(32, 3, 32, 32, 32),
kind_in_graph="aten::conv3d",
kind_not_in_graph="aten::batch_norm")
kind_not_in_graph="ipex::batch_norm")

def test_output_conv_relu_2d(self):
self._test_output(
@@ -1061,25 +1069,25 @@ def test_output_cascaded_conv_bn_sum_relu_2d(self):
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
torch.rand(32, 3, 64, 64),
kind_in_graph="ipex_prepack::convolution_add_relu_run",
kind_not_in_graph="aten::batch_norm")
kind_not_in_graph="ipex::batch_norm")
self._test_output_bf16(
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
torch.rand(32, 3, 64, 64),
kind_in_graph="ipex_prepack::convolution_add_relu_run",
kind_not_in_graph="aten::batch_norm",
kind_not_in_graph="ipex::batch_norm",
prec=0.02)

def test_output_cascaded_conv_bn_sum_relu_3d(self):
self._test_output(
CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
torch.rand(32, 3, 32, 32, 32),
kind_in_graph="ipex::conv3d_sum_relu",
kind_not_in_graph="aten::batch_norm")
kind_not_in_graph="ipex::batch_norm")
self._test_output_bf16(
CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
torch.rand(32, 3, 32, 32, 32),
kind_in_graph="ipex::conv3d_sum_relu",
kind_not_in_graph="aten::batch_norm",
kind_not_in_graph="ipex::batch_norm",
prec=0.02)

def test_output_conv_transpose2d(self):
@@ -1346,6 +1354,17 @@ def test_ipex_softmax(self):
kind_in_graph="ipex::softmax",
prec=5e-3)

def test_ipex_batch_norm(self):
self._test_output(
AtenBatchNormReplace(),
torch.rand(10, 10, 4, 4),
kind_in_graph="ipex::batch_norm")
self._test_output_bf16(
AtenBatchNormReplace(),
torch.rand(10, 10, 4, 4, dtype=torch.bfloat16),
kind_in_graph="ipex::batch_norm",
prec=5e-3)

def test_restore_inplace(self):
class M(nn.Module):
def __init__(self, eltwise_fn, params_dict={}):
3 changes: 3 additions & 0 deletions torch_ipex/csrc/jit/fusion_pass.cpp
@@ -342,6 +342,9 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
// replace aten::softmax with ipex::softmax
graph_rewrite::replaceAtenSoftmaxWithIpexSoftmax(graph);

// replace aten::batch_norm with ipex::batch_norm; this will be removed
// once TensorExpr fixes the performance issue (IPB-808).
graph_rewrite::replaceAtenBatchNormWithIpexBatchNorm(graph);
// TODO: Some post processing?? ECS/EDC/Peephole???
ConstantPropagation(graph);
}
21 changes: 18 additions & 3 deletions torch_ipex/csrc/jit/graph_rewrite.cpp
@@ -320,7 +320,7 @@ void FuseAddLayerNorm(std::shared_ptr<Graph>& graph) {
graph(%add_a, %add_b, %alpha, %shape:int[], %w, %b, %eps:float, %cudnn_enable:bool):
%r = ipex::add_layernorm(%add_a, %add_b, %alpha, %shape, %w, %b, %eps, %cudnn_enable)
return (%r) )";
SubgraphRewriter rewriter_aten;
IpexSubgraphRewriter rewriter_aten;
rewriter_aten.RegisterRewritePattern(aten_add_layernorm, fused_add_layernorm);
rewriter_aten.runOnGraph(graph);
}
@@ -346,7 +346,7 @@ void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph) {
%scores = ipex::mha_scores_calc(%q, %k, %relative_qk, %alpha, %dim_per_head, %softmax_dim, %dtype)
return (%scores) )";

SubgraphRewriter mha_fusion;
IpexSubgraphRewriter mha_fusion;
mha_fusion.RegisterRewritePattern(
div_matmul_add_softmax, div_matmul_add_softmax_fusion);
mha_fusion.RegisterRewritePattern(
@@ -384,7 +384,22 @@ void replaceAtenSoftmaxWithIpexSoftmax(std::shared_ptr<Graph>& graph) {
rewriter_aten.runOnGraph(graph);
}

void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph> &graph) {
void replaceAtenBatchNormWithIpexBatchNorm(std::shared_ptr<Graph>& graph) {
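// Rewrite pattern: any aten::batch_norm call with the full 9-argument schema
// is replaced by an ipex::batch_norm call that takes the same inputs.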
std::string batch_norm = R"(
graph(%a, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled):
%r = aten::batch_norm(%a, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
return (%r) )";
std::string ipex_batch_norm = R"(
graph(%a, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled):
%r = ipex::batch_norm(%a, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
return (%r) )";

IpexSubgraphRewriter rewriter_batch_norm;
rewriter_batch_norm.RegisterRewritePattern(batch_norm, ipex_batch_norm);
rewriter_batch_norm.runOnGraph(graph);
}

void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph>& graph) {
std::string qembedingbag = R"(
graph(%weight, %input, %offsets, %sparse, %include_last_offset, %o_scale, %o_zp, %o_dtype):
%r = ipex::qembedding_bag(%weight, %input, %offsets, %sparse, %include_last_offset, %o_scale, %o_zp, %o_dtype)
7 changes: 4 additions & 3 deletions torch_ipex/csrc/jit/graph_rewrite.h
@@ -29,9 +29,10 @@ void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph);
void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph);

void replaceAtenSoftmaxWithIpexSoftmax(std::shared_ptr<Graph>& graph);
void replaceAtenLayerNormWithIpexLayerNorm(std::shared_ptr<Graph> &graph);
void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph> &graph);
void replaceInteractionWithQInteraction(std::shared_ptr<Graph> &graph);
void replaceAtenBatchNormWithIpexBatchNorm(std::shared_ptr<Graph>& graph);
void replaceAtenLayerNormWithIpexLayerNorm(std::shared_ptr<Graph>& graph);
void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph>& graph);
void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph);

void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph);
void fuseConvWithEltwise(std::shared_ptr<Graph>& graph);
21 changes: 21 additions & 0 deletions torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp
@@ -411,6 +411,27 @@ RegisterOperators op({
},
aliasAnalysisFromSchema()),

Operator(
"ipex::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
[](const Node* node) -> Operation {
return [](Stack* stack) {
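// All 9 inputs are forwarded unchanged to at::batch_norm; the ipex:: node kind
// only serves to keep this op out of the TensorExpr fusion path.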
auto result = at::batch_norm(
(std::move(peek(stack, 0, 9))).toTensor(),
toOptionalTensor(std::move(peek(stack, 1, 9))),
toOptionalTensor(std::move(peek(stack, 2, 9))),
toOptionalTensor(std::move(peek(stack, 3, 9))),
toOptionalTensor(std::move(peek(stack, 4, 9))),
(std::move(peek(stack, 5, 9))).toBool(),
(std::move(peek(stack, 6, 9))).toDouble(),
(std::move(peek(stack, 7, 9))).toDouble(),
(std::move(peek(stack, 8, 9))).toBool());
drop(stack, 9);
pack(stack, std::move(result));
return 0;
};
},
aliasAnalysisFromSchema()),

Operator(
"ipex::qembedding_bag(Tensor weight, Tensor indices, Tensor offsets, "
"bool sparse, bool include_last_offset, "
