Add squared relu (#10316)
xiezipeng-ML authored Aug 22, 2023
1 parent 2d24fe0 commit 6e1fa45
Showing 23 changed files with 369 additions and 2 deletions.
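The op being added is squared ReLU, y = relu(x)² — i.e. x² where x > 0 and 0 elsewhere — popularized as a Transformer activation by the Primer paper (So et al., 2021). Its derivative is 2x where x > 0 and 0 elsewhere, which is exactly what the forward and backward functors below compute. A reference sketch in NumPy (illustrative only, not part of the commit):

import numpy as np

def square_relu(x):
    # Forward: y = x^2 where x > 0, else 0 (UnaryOp::kSquareReLU below).
    return np.where(x > 0, x * x, 0.0)

def square_relu_grad(dy, x):
    # Backward: dx = 2 * x * dy where x > 0, else 0 (kSquareReLUBackwardWithDyX below).
    return np.where(x > 0, 2.0 * x * dy, 0.0)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(square_relu(x))                        # [0.   0.   0.   0.25 4.  ]
print(square_relu_grad(np.ones_like(x), x))  # [0. 0. 0. 1. 4.]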
1 change: 1 addition & 0 deletions docs/source/nn.functional.rst
@@ -75,6 +75,7 @@ Non-linear activation functions
selu
celu
leaky_relu
square_relu
prelu
glu
gelu
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -160,6 +160,7 @@ Non-linear Activations (weighted sum, nonlinearity)
nn.CELU
nn.GELU
nn.QuickGELU
nn.SquareReLU
nn.SiLU
nn.Sigmoid
nn.Mish
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
@@ -270,6 +270,7 @@ Pointwise Ops
fmod
gelu
quick_gelu
square_relu
log
log1p
log2
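The three doc entries above register the new op in the functional, module, and top-level namespaces. A hedged usage sketch — the names oneflow.nn.functional.square_relu and oneflow.nn.SquareReLU are assumed from these doc listings, since the Python-side files are not shown in this excerpt:

import oneflow as flow

x = flow.tensor([-1.0, 0.5, 2.0], requires_grad=True)
y = flow.nn.functional.square_relu(x)   # functional form (nn.functional.rst)
m = flow.nn.SquareReLU()                # module form (nn.rst)
z = m(x)

y.sum().backward()
print(y)       # tensor([0.0000, 0.2500, 4.0000], ...)
print(x.grad)  # tensor([0., 1., 4.]) -- 2*x where x > 0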
31 changes: 31 additions & 0 deletions oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -152,6 +152,36 @@ class QuickGeLU : public OpExprGradFunction<QuickGeluCaptureState> {
}
};

struct SquareReLUCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};

class SquareReLU : public OpExprGradFunction<SquareReLUCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(SquareReLUCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg)
CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg)
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}

Maybe<void> Apply(const SquareReLUCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::SquareReLUGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};

class HardSigmoid : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
@@ -638,6 +668,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink);
REGISTER_OP_EXPR_GRAD_FUNCTION("fast_gelu", FastGeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("quick_gelu", QuickGeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("square_relu", SquareReLU);

} // namespace one
} // namespace oneflow
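Note that Capture saves the forward input x (not the output), because the backward formula 2·x·dy needs it; Apply then hands dy and the saved x to functional::SquareReLUGrad. The same contract in a minimal Python analogue (reusing the NumPy helpers from the first sketch; not OneFlow's actual API):

class SquareReLUGradFn:
    # Plays the role of SquareReLUCaptureState plus Capture/Apply above.
    def capture(self, x, requires_grad):
        self.requires_grad = requires_grad
        if requires_grad:
            self.saved_x = x              # ctx->SaveTensorForBackward(inputs.at(0))
        return square_relu(x)

    def apply(self, dy):
        if not self.requires_grad:        # skip gradient work when not needed
            return None
        return square_relu_grad(dy, self.saved_x)   # functional::SquareReLUGrad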
@@ -126,7 +126,8 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanhBackwardWithDyY) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kThresholdBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFastGeluBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kQuickGeluBackwardWithDyX)
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kQuickGeluBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSquareReLUBackwardWithDyX)

#define BINARY_ACTIVATION_BACKWARD_OP_SEQ \
BINARY_ACTIVATION_BACKWARD_OP_SEQ_0 \
3 changes: 2 additions & 1 deletion oneflow/core/ep/common/primitive/elementwise_unary.h
@@ -86,7 +86,8 @@ namespace primitive {
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNotEqualZero) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNanAssign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFastGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu)
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSquareReLU)

#define UNARY_COMPLEX_C2C_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kConj) \
10 changes: 10 additions & 0 deletions oneflow/core/ep/cpu/primitive/binary_functor.h
@@ -309,6 +309,16 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kQuickGeluBackwardWithDyX, Src,
static constexpr Src alpha = static_cast<Src>(1.702);
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kSquareReLUBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

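// dx = dy * d/dx(x^2) = 2 * x * dy on the positive side; zero gradient for x <= 0.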
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return static_cast<Dst>((x > static_cast<Src>(0.0)) ? static_cast<Src>(2.0) * x * dy
: static_cast<Src>(0.0));
}
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTanhBackwardWithDyY, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
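A quick finite-difference check of the new backward formula, reusing the NumPy helpers from the first sketch (points are kept away from the kink at x = 0):

import numpy as np

x = np.array([-1.5, 0.3, 1.2])
eps = 1e-6
numeric = (square_relu(x + eps) - square_relu(x - eps)) / (2 * eps)  # central difference
analytic = square_relu_grad(np.ones_like(x), x)                      # dy = 1
assert np.allclose(numeric, analytic, atol=1e-5)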
10 changes: 10 additions & 0 deletions oneflow/core/ep/cpu/primitive/unary_functor.h
@@ -64,6 +64,15 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kQuickGelu, Dst, Src> {
static constexpr Src alpha = static_cast<Src>(1.702);
};

template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kSquareReLU, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src) const {
// y = x * x where x > 0, else 0; cast the zero to Src so both ternary arms have the same type.
return static_cast<Dst>((src > static_cast<Src>(0.0)) ? src * src : static_cast<Src>(0.0));
}
};

template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTanh, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -371,6 +380,7 @@ SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquareReLU);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma);

12 changes: 12 additions & 0 deletions oneflow/core/ep/cuda/primitive/binary_functor.cuh
@@ -150,6 +150,16 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kQuickGeluBackwardWithDyX, Src
const Src alpha = static_cast<Src>(1.702);
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kSquareReLUBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

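// Same backward as the CPU functor: dx = 2 * x * dy where x > 0, else 0.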
OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return static_cast<Dst>((x > static_cast<Src>(0.0)) ? static_cast<Src>(2.0) * x * dy
: static_cast<Src>(0.0));
}
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyY, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -405,6 +415,7 @@ SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyY);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSquareReLUBackwardWithDyX);

SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX);
@@ -479,6 +490,7 @@ SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSquareReLUBackwardWithDyX);

SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX);
11 changes: 11 additions & 0 deletions oneflow/core/ep/cuda/primitive/unary_functor.cuh
@@ -70,6 +70,15 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kQuickGelu, Dst, Src> {
static constexpr Src alpha = static_cast<Src>(1.702);
};

template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kSquareReLU, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src) const {
// y = x * x where x > 0, else 0; cast the zero to Src so both ternary arms have the same type.
return static_cast<Dst>((src > static_cast<Src>(0.0)) ? src * src : static_cast<Src>(0.0));
}
};

namespace unary_functor_internal {

namespace {
@@ -491,6 +500,7 @@ SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kNanAssign);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSquareReLU);

/*********nv_bfloat16_kernel*******/

@@ -558,6 +568,7 @@ SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNanAssign);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquareReLU);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma);

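The CUDA functors mirror the CPU ones, and the SPECIALIZATION_PSEUDO_HALF/BFLOAT16 macros above route half and bfloat16 through a float compute type. A hedged cross-device sanity check (assumes a CUDA build of oneflow and the square_relu binding added by this commit):

import oneflow as flow

x = flow.randn(1024)
y_cpu = flow.nn.functional.square_relu(x)
y_gpu = flow.nn.functional.square_relu(x.to("cuda"))
assert flow.allclose(y_cpu, y_gpu.cpu())

# half precision goes through the pseudo-half specialization registered above
y_half = flow.nn.functional.square_relu(x.to("cuda").half())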
1 change: 1 addition & 0 deletions oneflow/core/ep/include/primitive/binary_op.h
@@ -109,6 +109,7 @@ enum class BinaryOp {
kTanBackwardWithDyX,
kFastGeluBackwardWithDyX,
kQuickGeluBackwardWithDyX,
kSquareReLUBackwardWithDyX,
};

}
1 change: 1 addition & 0 deletions oneflow/core/ep/include/primitive/unary_op.h
@@ -43,6 +43,7 @@ enum class UnaryOp {
kThreshold,
kFastGelu,
kQuickGelu,
kSquareReLU,
// math op
kAbs,
kAcos,
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -755,6 +755,14 @@
signature: "Tensor (Tensor dy, Tensor x) => QuickGeluGrad"
bind_python: False

- name: "square_relu"
signature: "Tensor (Tensor x) => SquareReLU"
bind_python: True

- name: "square_relu_grad"
signature: "Tensor (Tensor dy, Tensor x) => SquareReLUGrad"
bind_python: False

- name: "gelu_with_approximate"
signature: 'Tensor (Tensor x, String approximate="none") => GeluWithApproximate'
bind_python: True
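In functional_api.yaml, bind_python: True generates a Python-visible entry point for the forward op, while square_relu_grad stays C++-only and is reached through the SquareReLU gradient class registered earlier. Assuming the generated binding follows the same pattern as quick_gelu, it lands on the internal _C namespace:

import oneflow as flow

x = flow.tensor([-1.0, 2.0])
y = flow._C.square_relu(x)   # generated from the yaml signature (assumed name)
print(y)                     # tensor([0., 4.], ...)

# No flow._C.square_relu_grad exists: bind_python is False for the grad op.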
17 changes: 17 additions & 0 deletions oneflow/core/functional/impl/activation_functor.cpp
@@ -247,6 +247,21 @@ class QuickGeluGradFunctor : public BinaryFunctor {
}
};

class SquareReLUFunctor : public UnaryFunctor {
public:
SquareReLUFunctor() {
op_ = CHECK_JUST(one::OpBuilder("square_relu").Input("x").Output("y").Build());
}
};

class SquareReLUGradFunctor : public BinaryFunctor {
public:
SquareReLUGradFunctor() {
op_ =
CHECK_JUST(one::OpBuilder("square_relu_grad").Input("dy").Input("x").Output("dx").Build());
}
};

class GluFunctor {
public:
GluFunctor() {}
@@ -779,6 +794,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::FastGeluGradFunctor>("FastGeluGrad");
m.add_functor<impl::QuickGeluFunctor>("QuickGelu");
m.add_functor<impl::QuickGeluGradFunctor>("QuickGeluGrad");
m.add_functor<impl::SquareReLUFunctor>("SquareReLU");
m.add_functor<impl::SquareReLUGradFunctor>("SquareReLUGrad");
m.add_functor<impl::GluFunctor>("Glu");
m.add_functor<impl::HardSigmoidFunctor>("HardSigmoid");
m.add_functor<impl::HardSigmoidGradFunctor>("HardSigmoidGrad");
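Both functors inherit shared bases (UnaryFunctor/BinaryFunctor) that cache one built op expression and forward the call; ONEFLOW_FUNCTION_LIBRARY then registers them under string names. A toy Python version of that registry pattern, reusing the NumPy helper (and np import) from the first sketch — illustrative only:

FUNCTIONAL_REGISTRY = {}

def add_functor(name):
    def register(cls):
        FUNCTIONAL_REGISTRY[name] = cls()   # one cached instance, like op_ in the C++ functor
        return cls
    return register

@add_functor("SquareReLU")
class SquareReLUFunctor:
    def __call__(self, x):
        return square_relu(x)               # stands in for dispatching the built op

y = FUNCTIONAL_REGISTRY["SquareReLU"](np.array([-1.0, 3.0]))   # -> [0., 9.]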
27 changes: 27 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -344,6 +344,20 @@ def OneFlow_QuickGeluGradOp : OneFlow_BaseOp<"quick_gelu_grad", [NoMemoryEffect,
let has_data_type_infer_fn = 1;
}

def OneFlow_SquareReLUGradOp : OneFlow_BaseOp<"square_relu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
OneFlow_Tensor:$dy
);
let output = (outs
OneFlow_Tensor:$dx
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_GridSampleOp : OneFlow_BaseOp<"grid_sample", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$input,
@@ -10414,6 +10428,19 @@ def OneFlow_QuickGeluOp : OneFlow_BaseOp<"quick_gelu", [NoMemoryEffect, DeclareO
let has_data_type_infer_fn = 1;
}

def OneFlow_SquareReLUOp : OneFlow_BaseOp<"square_relu", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x
);
let output = (outs
OneFlow_Tensor:$y
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_HardsigmoidOp : OneFlow_BaseOp<"hardsigmoid", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in
26 changes: 26 additions & 0 deletions oneflow/user/kernels/activation_kernels.cpp
@@ -282,6 +282,32 @@ REGISTER_USER_KERNEL("quick_gelu_grad")
})
.SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kQuickGeluBackwardWithDyX, "dx",
"dy"));
REGISTER_USER_KERNEL("square_relu")
.SetCreateFn([]() {
return user_op::NewOpKernel<UnaryPrimitiveKernel>(
"y", "x", [](user_op::KernelComputeContext* ctx) {
const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0);
const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0);
return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
ctx->device_type(), ep::primitive::UnaryOp::kSquareReLU, src->data_type(),
dst->data_type());
});
})
.SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSquareReLU, "y", "x"));

REGISTER_USER_KERNEL("square_relu_grad")
.SetCreateFn([]() {
return user_op::NewOpKernel<BinaryPrimitiveKernel>(
"dx", "dy", "x", [](user_op::KernelComputeContext* ctx) {
const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0);
const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0);
return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(
ctx->device_type(), ep::primitive::BinaryOp::kSquareReLUBackwardWithDyX,
src->data_type(), dst->data_type(), 1 /*max_num_dims*/);
});
})
.SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSquareReLUBackwardWithDyX,
"dx", "dy"));

REGISTER_USER_KERNEL("leaky_relu")
.SetCreateFn([]() {
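With the square_relu kernels above registered for both devices, forward and backward can be exercised end to end. A sketch of the kind of check the commit's test file (one of the 23 changed files, not shown in this excerpt) presumably performs:

import numpy as np
import oneflow as flow

def check_square_relu(device="cpu"):
    arr = np.random.randn(8, 16).astype(np.float32)
    x = flow.tensor(arr, device=device, requires_grad=True)
    y = flow.nn.functional.square_relu(x)
    y.sum().backward()
    np.testing.assert_allclose(y.numpy(), np.where(arr > 0, arr * arr, 0), rtol=1e-5)
    np.testing.assert_allclose(x.grad.numpy(), np.where(arr > 0, 2 * arr, 0), rtol=1e-5)

check_square_relu("cpu")   # and "cuda", given a GPU build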