[onert] Apply softmax to CategoricalCrossEntropy automatically (#14105)
This commit applies softmax automatically when using CategoricalCrossEntropy loss if the model to be trained does not already apply softmax.

ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
ragmani authored Oct 7, 2024
1 parent c82f8cb commit 190d26f
Showing 10 changed files with 68 additions and 20 deletions.
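To illustrate the behavior this commit guarantees, here is a minimal standalone sketch (not onert code) of categorical cross-entropy with the softmax normalization the runtime now applies when the model output is raw logits. With this change, the loss is the same whether or not the model ends in a Softmax op, because softmax is applied to y_pred exactly once.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Softmax over a single row of logits (numerically stabilized).
std::vector<float> softmax(const std::vector<float> &logits)
{
  const float max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<float> probs(logits.size());
  float sum = 0.f;
  for (std::size_t i = 0; i < logits.size(); ++i)
  {
    probs[i] = std::exp(logits[i] - max_logit);
    sum += probs[i];
  }
  for (float &p : probs)
    p /= sum;
  return probs;
}

// Categorical cross-entropy for one sample: -sum(y_true * log(y_pred)).
float categorical_cross_entropy(const std::vector<float> &y_pred_probs,
                                const std::vector<float> &y_true)
{
  float loss = 0.f;
  for (std::size_t i = 0; i < y_true.size(); ++i)
    loss -= y_true[i] * std::log(std::max(y_pred_probs[i], 1e-7f));
  return loss;
}

// After this commit, both of these produce the same loss:
//   categorical_cross_entropy(softmax(raw_logits), y_true); // model without Softmax
//   categorical_cross_entropy(model_probs, y_true);         // model ending in Softmax
```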
8 changes: 7 additions & 1 deletion runtime/onert/api/nnfw/include/nnfw_experimental.h
@@ -237,7 +237,13 @@ typedef struct nnfw_train_info
float learning_rate = 0.001f;
/** Batch size */
uint32_t batch_size = 1;
/** loss info */
/** loss info
 * Note that you don't need to worry about whether the model includes softmax
 * when you use NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY. This loss ensures that the
 * predicted input of the loss is the result of applying softmax exactly once,
 * regardless of whether the model output has already been passed through softmax.
*/
nnfw_loss_info loss_info{.loss = NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR,
.reduction_type = NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE};
/** optimizer type */
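A hedged usage sketch of the behavior documented above: the training info can simply select NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY, and the user no longer needs to append a Softmax layer to the model. The field names come from the struct in this diff; the nnfw_train_set_traininfo call is an assumption based on the experimental training API and may differ by version.

```cpp
#include "nnfw_experimental.h"

// Sketch only: the session call below is assumed from the experimental API.
void set_cce_training_info(nnfw_session *session)
{
  nnfw_train_info info;
  info.learning_rate = 0.001f;
  info.batch_size = 32;
  // Valid whether or not the model's last op is Softmax; the runtime applies
  // softmax to y_pred exactly once before computing the loss.
  info.loss_info.loss = NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
  info.loss_info.reduction_type = NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE;

  nnfw_train_set_traininfo(session, &info); // assumed API call
}
```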
5 changes: 4 additions & 1 deletion runtime/onert/backend/train/KernelGenerator.cc
@@ -421,9 +421,12 @@ void KernelGenerator::visit(const ir::train::operation::Loss &node)
}
case ir::train::LossCode::CategoricalCrossentropy:
{
const auto y_pred_op_code = node.y_pred_op_code();
bool is_normalization_required = (y_pred_op_code != ir::OpCode::Softmax);
auto fn = std::make_unique<ops::LossCategoricalCrossentropyLayer>();
fn->configure(y_pred_tensor, y_true_tensor, output_tensor, back_prop_y_pred_tensor,
reduction_type, loss_param.cce.axis, loss_param.cce.label_smoothing);
reduction_type, loss_param.cce.axis, loss_param.cce.label_smoothing,
is_normalization_required);
_return_fn = std::move(fn);
break;
}
30 changes: 20 additions & 10 deletions runtime/onert/backend/train/ops/LossCategoricalCrossentropyLayer.cc
@@ -28,17 +28,16 @@ namespace train
namespace ops
{

void LossCategoricalCrossentropyLayer::configure(const IPortableTensor *y_pred,
const IPortableTensor *y_true,
IPortableTensor *output,
IPortableTensor *back_prop_y_pred,
ir::train::LossReductionType reduction_type,
int32_t axis, float label_smoothing)
void LossCategoricalCrossentropyLayer::configure(
const IPortableTensor *y_pred, const IPortableTensor *y_true, IPortableTensor *output,
IPortableTensor *back_prop_y_pred, ir::train::LossReductionType reduction_type, int32_t axis,
float label_smoothing, bool is_normalization_required)
{
LossLayer::configure(y_pred, y_true, output, back_prop_y_pred, reduction_type);

_axis = axis;
_label_smoothing = label_smoothing;
_is_normalization_required = is_normalization_required;
}

void LossCategoricalCrossentropyLayer::forward(bool)
@@ -59,12 +58,23 @@ void LossCategoricalCrossentropyLayer::backward()
{
assert(_back_prop_y_pred != nullptr);

const auto reduction_type = convertLossReductionType(_reduction_type);
if (_y_pred->data_type() == OperandType::FLOAT32)
{
nnfw::cker::train::CategoricalCrossEntropyGrad(
getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true), getBuffer<float>(_y_true),
getShape(_back_prop_y_pred), getBuffer<float>(_back_prop_y_pred), reduction_type);
const auto reduction_type = convertLossReductionType(_reduction_type);
if (_is_normalization_required)
{
// TODO Eliminate duplicate calculations for output
nnfw::cker::train::CategoricalCrossEntropyWithLogits(
getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true), getBuffer<float>(_y_true),
getShape(_output), getBuffer<float>(_output), getShape(_back_prop_y_pred),
getBuffer<float>(_back_prop_y_pred), reduction_type);
}
else
{
nnfw::cker::train::CategoricalCrossEntropyGrad(
getShape(_y_pred), getBuffer<float>(_y_pred), getShape(_y_true), getBuffer<float>(_y_true),
getShape(_back_prop_y_pred), getBuffer<float>(_back_prop_y_pred), reduction_type);
}
}
else
{
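For reference, a minimal standalone sketch (not the cker implementation) of the two backward paths selected by _is_normalization_required: when normalization is required (y_pred is raw logits), the fused softmax-cross-entropy gradient is softmax(logits) - y_true; otherwise (y_pred is already a probability distribution) the plain cross-entropy gradient is -y_true / y_pred. Reduction scaling (e.g. division by the batch size) is omitted.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Gradient w.r.t. logits when softmax is fused into the loss: softmax(logits) - y_true.
void cce_grad_with_logits(const float *logits, const float *y_true, float *grad, std::size_t n)
{
  float max_logit = logits[0];
  for (std::size_t i = 1; i < n; ++i)
    max_logit = std::max(max_logit, logits[i]);
  float sum = 0.f;
  for (std::size_t i = 0; i < n; ++i)
  {
    grad[i] = std::exp(logits[i] - max_logit);
    sum += grad[i];
  }
  for (std::size_t i = 0; i < n; ++i)
    grad[i] = grad[i] / sum - y_true[i];
}

// Gradient w.r.t. probabilities when y_pred already sums to 1: -y_true / y_pred.
void cce_grad_on_probs(const float *y_pred, const float *y_true, float *grad, std::size_t n)
{
  for (std::size_t i = 0; i < n; ++i)
    grad[i] = -y_true[i] / (y_pred[i] + 1e-7f);
}
```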
@@ -36,13 +36,15 @@ class LossCategoricalCrossentropyLayer : public LossLayer

void configure(const IPortableTensor *y_pred, const IPortableTensor *y_true,
IPortableTensor *output, IPortableTensor *back_prop_y_pred,
ir::train::LossReductionType reduction_type, int32_t axis, float label_smoothing);
ir::train::LossReductionType reduction_type, int32_t axis, float label_smoothing,
bool is_normalization_required);
void forward(bool training) override;
void backward() override;

private:
int32_t _axis{-1};
float _label_smoothing{0.0f};
bool _is_normalization_required{false};
};

} // namespace ops
4 changes: 3 additions & 1 deletion runtime/onert/core/include/ir/train/operation/Loss.h
@@ -38,7 +38,7 @@ class Loss : public ir::operation::Loss, public TrainableOperation
using OperationType = ir::operation::Loss;

public:
Loss(const OperationType &operation, const LossInfo &info);
Loss(const OperationType &operation, const LossInfo &info, ir::OpCode y_pred_op_code);

public:
std::unique_ptr<ITrainableOperation> clone() const override;
@@ -49,9 +49,11 @@ class Loss : public ir::operation::Loss, public TrainableOperation

public:
const LossInfo &param() const { return _param; }
ir::OpCode y_pred_op_code() const { return _y_pred_op_code; }

private:
LossInfo _param;
ir::OpCode _y_pred_op_code; // The op code of the last node computing y_pred
};

} // namespace operation
@@ -68,7 +68,13 @@ void TrainableOperationConverter::visit(const ir::operation::FullyConnected &node)

void TrainableOperationConverter::visit(const ir::operation::Loss &node)
{
_return_op = std::make_unique<ir::train::operation::Loss>(node, _training_info->lossInfo());
const auto &y_pred_index = node.getInputs().at(ir::operation::Loss::Input::Y_PRED);
const auto &y_pred = _tgraph.operands().at(y_pred_index);
const auto &y_pred_node = _tgraph.operations().at(y_pred.getDef());
const auto y_pred_op_code = y_pred_node.opcode();

_return_op =
std::make_unique<ir::train::operation::Loss>(node, _training_info->lossInfo(), y_pred_op_code);
}

void TrainableOperationConverter::visit(const ir::operation::Pad &node)
@@ -64,8 +64,14 @@ void LossInsertionPass::run()
auto output_index = _trainable_graph.addOperand(output_shape, float_op);
ir::OperandIndexSequence outputs{output_index};

// The y_pred node information may be required in some loss layers (e.g.,
// CategoricalCrossEntropy(SoftmaxCrossEntropy));
const auto &y_pred_node = _trainable_graph.operations().at(y_pred.getDef());
const auto y_pred_op_code = y_pred_node.opcode();

auto loss_op = std::make_unique<ir::operation::Loss>(inputs, outputs);
auto trainable_loss_op = std::make_unique<ir::train::operation::Loss>(*loss_op, loss_info);
auto trainable_loss_op =
std::make_unique<ir::train::operation::Loss>(*loss_op, loss_info, y_pred_op_code);
trainable_loss_op->enableBackward();

_trainable_graph.addOperation(std::move(trainable_loss_op));
8 changes: 7 additions & 1 deletion runtime/onert/core/src/ir/train/TrainableGraph.test.cc
@@ -62,8 +62,14 @@ OperationIndex addLossOperation(train::TrainableGraph &tgraph, const OperandIndexSequence inputs,
const OperandIndexSequence outputs)
{
// Add "Loss" operation
const auto &y_pred_index = inputs.at(0);
const auto &y_pred = tgraph.operands().at(y_pred_index);
const auto &y_pred_node = tgraph.operations().at(y_pred.getDef());
const auto y_pred_op_code = y_pred_node.opcode();

auto loss_op = operation::Loss(inputs, outputs);
return tgraph.addOperation(std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}));
return tgraph.addOperation(
std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}, y_pred_op_code));
}

TEST(TrainableGraph, topological_sort_linear)
8 changes: 7 additions & 1 deletion runtime/onert/core/src/ir/train/UseDefGenerator.test.cc
@@ -70,8 +70,14 @@ OperationIndex addLossOperation(train::TrainableGraph &tgraph, const OperandIndexSequence inputs,
const OperandIndexSequence outputs)
{
// Add "Loss" operation
const auto &y_pred_index = inputs.at(0);
const auto &y_pred = tgraph.operands().at(y_pred_index);
const auto &y_pred_node = tgraph.operations().at(y_pred.getDef());
const auto y_pred_op_code = y_pred_node.opcode();

auto loss_op = operation::Loss(inputs, outputs);
return tgraph.addOperation(std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}));
return tgraph.addOperation(
std::make_unique<train::operation::Loss>(loss_op, train::LossInfo{}, y_pred_op_code));
}

train::UseDefChain createUseDefChain(const Operand &operand,
5 changes: 3 additions & 2 deletions runtime/onert/core/src/ir/train/operation/Loss.cc
@@ -36,8 +36,9 @@ void Loss::accept(OperationVisitor &v) const { v.visit(*this); }

void Loss::accept(TrainableOperationVisitor &v) const { v.visit(*this); }

Loss::Loss(const OperationType &operation, const LossInfo &param)
: OperationType{operation.getInputs(), operation.getOutputs()}, _param{param}
Loss::Loss(const OperationType &operation, const LossInfo &param, ir::OpCode y_pred_op_code)
: OperationType{operation.getInputs(), operation.getOutputs()}, _param{param},
_y_pred_op_code{y_pred_op_code}
{
// DO NOTHING
}
