Skip to content

Commit

Permalink
[onert/core] Add RmsNorm operation (#14161)
Browse files Browse the repository at this point in the history
This commit adds RmsNorm operation to onert core ir.

ONE-DCO-1.0-Signed-off-by: Seockho Kim seockho.kim@samsung.com
  • Loading branch information
seockho-kim authored Oct 7, 2024
1 parent c7bd3d2 commit b768236
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 0 deletions.
1 change: 1 addition & 0 deletions runtime/onert/core/include/ir/Operations.Include.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#include "ir/operation/ResizeBilinear.h"
#include "ir/operation/ResizeNearestNeighbor.h"
#include "ir/operation/Reverse.h"
#include "ir/operation/RmsNorm.h"
#include "ir/operation/RNN.h"
#include "ir/operation/Select.h"
#include "ir/operation/Shape.h"
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/core/include/ir/Operations.lst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ OP(Reshape)
OP(ResizeBilinear)
OP(ResizeNearestNeighbor)
OP(Reverse)
OP(RmsNorm)
OP(RNN)
OP(Select)
OP(Shape)
Expand Down
63 changes: 63 additions & 0 deletions runtime/onert/core/include/ir/operation/RmsNorm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __ONERT_IR_OPERATION_RMS_NORM_H__
#define __ONERT_IR_OPERATION_RMS_NORM_H__

#include "ir/Operation.h"
#include "ir/InternalType.h"

namespace onert
{
namespace ir
{
namespace operation
{

class RmsNorm : public Operation
{
public:
enum Input
{
INPUT = 0,
GAMMA
};

struct Param
{
float epsilon;
};

public:
RmsNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param);

public:
void accept(OperationVisitor &v) const override;
OpCode opcode() const final { return OpCode::RmsNorm; }

public:
const Param &param() const { return _param; }

private:
Param _param;
};

} // namespace operation
} // namespace ir
} // namespace onert

#endif // __ONERT_IR_OPERATION_RMS_NORM_H__
7 changes: 7 additions & 0 deletions runtime/onert/core/src/ir/OperationDumper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,13 @@ void OperationDumper::visit(const Reverse &node)
dumpUnaryInputOp(node, axis);
}

void OperationDumper::visit(const RmsNorm &node)
{
std::string inputs =
"Gamma(" + std::to_string(node.getInputs().at(RmsNorm::Input::GAMMA).value()) + ")";
dumpUnaryInputOp(node, inputs);
}

void OperationDumper::visit(const RNN &node)
{
VERBOSE(LIR) << "* RNN" << std::endl;
Expand Down
1 change: 1 addition & 0 deletions runtime/onert/core/src/ir/OperationDumper.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class OperationDumper : public OperationVisitor
void visit(const operation::ResizeBilinear &) override;
void visit(const operation::ResizeNearestNeighbor &) override;
void visit(const operation::Reverse &) override;
void visit(const operation::RmsNorm &) override;
void visit(const operation::RNN &) override;
void visit(const operation::Select &node) override;
void visit(const operation::Shape &node) override;
Expand Down
37 changes: 37 additions & 0 deletions runtime/onert/core/src/ir/operation/RmsNorm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ir/operation/RmsNorm.h"
#include "ir/OperationVisitor.h"

namespace onert
{
namespace ir
{
namespace operation
{

void RmsNorm::accept(OperationVisitor &v) const { v.visit(*this); }

RmsNorm::RmsNorm(const OperandIndexSequence &inputs, const OperandIndexSequence &outputs,
const Param &param)
: Operation{OperandConstraint::createExact(2u), inputs, outputs}, _param{param}
{
}

} // namespace operation
} // namespace ir
} // namespace onert
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,14 @@ operation::Reverse generateReverse()
return operation::Reverse{OperandIndexSequence{1, 2}, OperandIndexSequence{0}};
}

operation::RmsNorm generateRmsNorm()
{
operation::RmsNorm::Param param;
param.epsilon = 0.f;

return operation::RmsNorm{OperandIndexSequence{1, 2}, OperandIndexSequence{0}, param};
}

operation::RNN generateRNN()
{
operation::RNN::Param param;
Expand Down Expand Up @@ -750,6 +758,9 @@ TEST(UntrainableOperation, testAllOps)
const auto reverse = generateReverse();
verifyOp(reverse);

const auto rms_norm = generateRmsNorm();
verifyOp(rms_norm);

const auto rnn = generateRNN();
verifyOp(rnn);

Expand Down Expand Up @@ -1123,6 +1134,12 @@ TEST(UntrainableOperation, neg_TrainableOperationVisitor)
EXPECT_ANY_THROW(visitor.invoke(*untrainable));
}

{
const auto rms_norm = generateRmsNorm();
auto untrainable = generateUntrainableOperation(rms_norm);
EXPECT_ANY_THROW(visitor.invoke(*untrainable));
}

{
const auto rnn = generateRNN();
auto untrainable = generateUntrainableOperation(rnn);
Expand Down

0 comments on commit b768236

Please sign in to comment.