diff --git a/oneflow/core/autograd/gradient_funcs/upsample.cpp b/oneflow/core/autograd/gradient_funcs/upsample.cpp index 4a4b7de3d0c..54617c8b113 100644 --- a/oneflow/core/autograd/gradient_funcs/upsample.cpp +++ b/oneflow/core/autograd/gradient_funcs/upsample.cpp @@ -81,6 +81,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("upsample", Upsample); struct UpsampleNearest2DCaptureState : public AutoGradCaptureState { bool requires_grad = false; + bool has_like_input = false; double height_scale = 0.0; double width_scale = 0.0; std::vector output_size; @@ -105,6 +106,8 @@ class UpsampleNearest2D : public OpExprGradFunctiondata_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); + ctx->has_like_input = inputs.size() == 2; + if (ctx->has_like_input) { ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } @@ -114,9 +117,15 @@ class UpsampleNearest2D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad( - JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, - ctx->output_size, ctx->data_format)); + if (ctx->has_like_input) { + const std::shared_ptr& like = ctx->SavedTensors().at(1); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, like, ctx->data_format)); + } else { + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest2DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->height_scale, ctx->width_scale, + ctx->output_size, ctx->data_format)); + } return Maybe::Ok(); } @@ -227,6 +236,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_linear_1d", UpsampleLinear1D); struct UpsampleNearest1DCaptureState : public AutoGradCaptureState { bool requires_grad = false; + bool has_like_input = false; double scale_factor = 0.0; std::vector output_size; std::string data_format; @@ -238,7 +248,7 @@ class 
UpsampleNearest1D : public OpExprGradFunction Capture(UpsampleNearest1DCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, const AttrMap& attrs) const override { - CHECK_EQ_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) + CHECK_GE_OR_RETURN(inputs.size(), 1); // NOLINT(maybe-need-error-msg) CHECK_EQ_OR_RETURN(outputs.size(), 1); // NOLINT(maybe-need-error-msg) ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } @@ -249,6 +259,8 @@ class UpsampleNearest1D : public OpExprGradFunctiondata_format = JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); + ctx->has_like_input = inputs.size() == 2; + if (ctx->has_like_input) { ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } @@ -258,9 +270,15 @@ class UpsampleNearest1D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - JUST(oneflow::VectorAt(*in_grads, 0)) = JUST( - functional::UpsampleNearest1DGrad(JUST(oneflow::VectorAt(out_grads, 0)), x, - ctx->scale_factor, ctx->output_size, ctx->data_format)); + if (ctx->has_like_input) { + const std::shared_ptr& like = ctx->SavedTensors().at(1); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest1DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, like, ctx->data_format)); + } else { + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST( + functional::UpsampleNearest1DGrad(JUST(oneflow::VectorAt(out_grads, 0)), x, + ctx->scale_factor, ctx->output_size, ctx->data_format)); + } return Maybe::Ok(); } @@ -322,6 +340,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("upsample_bicubic_2d", UpsampleBicubic2D); struct UpsampleNearest3DCaptureState : public AutoGradCaptureState { bool requires_grad = false; + bool has_like_input = false; double depth_scale = 0.0; double height_scale = 0.0; double width_scale = 0.0; @@ -348,6 +367,10 @@ class UpsampleNearest3D : public OpExprGradFunctiondata_format = 
JUST(composed_attrs.GetAttr("data_format")); ctx->SaveTensorForBackward(inputs.at(0)); + ctx->has_like_input = inputs.size() == 2; + if (ctx->has_like_input) { ctx->SaveTensorForBackward(inputs.at(1)); } return Maybe::Ok(); } @@ -357,9 +380,15 @@ class UpsampleNearest3D : public OpExprGradFunction& x = ctx->SavedTensors().at(0); in_grads->resize(1); - JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad( - JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale, - ctx->width_scale, ctx->output_size, ctx->data_format)); + if (ctx->has_like_input) { + const std::shared_ptr& like = ctx->SavedTensors().at(1); + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, like, ctx->data_format)); + } else { + JUST(oneflow::VectorAt(*in_grads, 0)) = JUST(functional::UpsampleNearest3DGrad( + JUST(oneflow::VectorAt(out_grads, 0)), x, ctx->depth_scale, ctx->height_scale, + ctx->width_scale, ctx->output_size, ctx->data_format)); + } return Maybe::Ok(); } diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 3323db0e93c..3a27be341ef 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -1777,27 +1777,35 @@ bind_python: False - name: "upsample_nearest_1d" - signature: - 'Tensor (Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None, - String data_format="channels_first") => UpsampleNearest1D' + signature: [ + 'Tensor (Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None, + String data_format="channels_first") => UpsampleNearest1D', + 'Tensor (Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest1D' + ] bind_python: True - name: "upsample_nearest_1d_grad" - signature: - 'Tensor (Tensor dy, Tensor 
x, Double scale_factor=0.0, Int64List[1] output_size=None, - String data_format="channels_first") => UpsampleNearest1DGrad' + signature: [ + 'Tensor (Tensor dy, Tensor x, Double scale_factor=0.0, Int64List[1] output_size=None, + String data_format="channels_first") => UpsampleNearest1DGrad', + 'Tensor (Tensor dy, Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest1DGrad' + ] bind_python: False - name: "upsample_nearest_2d" - signature: - 'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None, - String data_format="channels_first") => UpsampleNearest2D' + signature: [ + 'Tensor (Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None, + String data_format="channels_first") => UpsampleNearest2D', + 'Tensor (Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest2D' + ] bind_python: True - name: "upsample_nearest_2d_grad" - signature: - 'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None, - String data_format="channels_first") => UpsampleNearest2DGrad' + signature: [ + 'Tensor (Tensor dy, Tensor x, Double height_scale=0.0, Double width_scale=0.0, Int64List[2] output_size=None, + String data_format="channels_first") => UpsampleNearest2DGrad', + 'Tensor (Tensor dy, Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest2DGrad' + ] bind_python: False - name: "upsample_bilinear_2d" @@ -1825,15 +1833,19 @@ bind_python: False - name: "upsample_nearest_3d" - signature: - 'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None, - String data_format="channels_first") => UpsampleNearest3D' + signature: [ + 'Tensor (Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None, + String data_format="channels_first") => UpsampleNearest3D', + 'Tensor (Tensor x, 
Tensor like, String data_format="channels_first") => UpsampleNearest3D' + ] bind_python: True - name: "upsample_nearest_3d_grad" - signature: - 'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None, - String data_format="channels_first") => UpsampleNearest3DGrad' + signature: [ + 'Tensor (Tensor dy, Tensor x, Double depth_scale=0.0, Double height_scale=0.0, Double width_scale=0.0, Int64List[3] output_size=None, + String data_format="channels_first") => UpsampleNearest3DGrad', + 'Tensor (Tensor dy, Tensor x, Tensor like, String data_format="channels_first") => UpsampleNearest3DGrad' + ] bind_python: False - name: "upsample_trilinear_3d" @@ -3494,3 +3506,8 @@ - name: "fused_clip_grad" signature: "Tensor (TensorTuple model_diff, Float max_norm, Float norm_type) => FusedClipGrad" bind_python: True + +- name: "fused_group_norm_quantization" + signature: 'TensorTuple[y, y_scale, y_zero_point] (Tensor x, Tensor gamma=None, Tensor beta=None, Bool affine, Int32 num_groups, Double epsilon=1e-5, String data_format="channels_first", String activation="none", String quantization_scheme="affine", Int32 quantization_bit=8, String quantization_formula="oneflow", String quantization_mode="tensorwise") => FusedGroupNormQuantization' + bind_python: True + diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index aef7ef62a3b..2f1f2db11f6 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -1558,7 +1558,7 @@ class ReshapeFunctor { } } auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("shape"); - attrs.SetAllAttrs(infered_shape); + attrs.SetAllAttrs(infered_shape); return OpInterpUtil::Dispatch(*op_, {x}, attrs); } @@ -1945,6 +1945,24 @@ class UpsampleNearest1DFunctor { std::shared_ptr op_; }; +class UpsampleNearestLike1DFunctor { + public: + UpsampleNearestLike1DFunctor() { + op_ = CHECK_JUST( + 
one::OpBuilder("upsample_nearest_1d").Input("x").Input("like").Output("y").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& like, + const std::string& data_format) const { + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format"); + attrs.SetAllAttrs(data_format); + return OpInterpUtil::Dispatch(*op_, {x, like}, attrs); + } + + private: + std::shared_ptr op_; +}; + class UpsampleNearest1DGradFunctor { public: UpsampleNearest1DGradFunctor() { @@ -1968,6 +1986,29 @@ class UpsampleNearest1DGradFunctor { std::shared_ptr op_; }; +class UpsampleNearestLike1DGradFunctor { + public: + UpsampleNearestLike1DGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_1d_grad") + .Input("dy") + .Input("x") + .Input("like") + .Output("dx") + .Build()); + } + Maybe operator()(const std::shared_ptr& dy, + const std::shared_ptr& x, + const std::shared_ptr& like, + const std::string& data_format) const { + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format"); + attrs.SetAllAttrs(data_format); + return OpInterpUtil::Dispatch(*op_, {dy, x, like}, attrs); + } + + private: + std::shared_ptr op_; +}; + class UpsampleNearest2DFunctor { public: UpsampleNearest2DFunctor() { @@ -1991,6 +2032,24 @@ class UpsampleNearest2DFunctor { std::shared_ptr op_; }; +class UpsampleNearestLike2DFunctor { + public: + UpsampleNearestLike2DFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("upsample_nearest_2d").Input("x").Input("like").Output("y").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& like, + const std::string& data_format) const { + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format"); + attrs.SetAllAttrs(data_format); + return OpInterpUtil::Dispatch(*op_, {x, like}, attrs); + } + + private: + std::shared_ptr op_; +}; + class UpsampleNearest2DGradFunctor { public: UpsampleNearest2DGradFunctor() { @@ -2016,6 +2075,29 @@ class UpsampleNearest2DGradFunctor { std::shared_ptr op_; }; +class 
UpsampleNearestLike2DGradFunctor { + public: + UpsampleNearestLike2DGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_2d_grad") + .Input("dy") + .Input("x") + .Input("like") + .Output("dx") + .Build()); + } + Maybe operator()(const std::shared_ptr& dy, + const std::shared_ptr& x, + const std::shared_ptr& like, + const std::string& data_format) const { + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format"); + attrs.SetAllAttrs(data_format); + return OpInterpUtil::Dispatch(*op_, {dy, x, like}, attrs); + } + + private: + std::shared_ptr op_; +}; + class UpsampleBilinear2DFunctor { public: UpsampleBilinear2DFunctor() { @@ -2135,6 +2217,24 @@ class UpsampleNearest3DFunctor { std::shared_ptr op_; }; +class UpsampleNearestLike3DFunctor { + public: + UpsampleNearestLike3DFunctor() { + op_ = CHECK_JUST( + one::OpBuilder("upsample_nearest_3d").Input("x").Input("like").Output("y").Build()); + } + Maybe operator()(const std::shared_ptr& x, + const std::shared_ptr& like, + const std::string& data_format) const { + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format"); + attrs.SetAllAttrs(data_format); + return OpInterpUtil::Dispatch(*op_, {x, like}, attrs); + } + + private: + std::shared_ptr op_; +}; + class UpsampleNearest3DGradFunctor { public: UpsampleNearest3DGradFunctor() { @@ -2160,6 +2260,29 @@ class UpsampleNearest3DGradFunctor { std::shared_ptr op_; }; +class UpsampleNearestLike3DGradFunctor { + public: + UpsampleNearestLike3DGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("upsample_nearest_3d_grad") + .Input("dy") + .Input("x") + .Input("like") + .Output("dx") + .Build()); + } + Maybe operator()(const std::shared_ptr& dy, + const std::shared_ptr& x, + const std::shared_ptr& like, + const std::string& data_format) const { + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("data_format"); + attrs.SetAllAttrs(data_format); + return OpInterpUtil::Dispatch(*op_, {dy, x, like}, attrs); + } + + private: + std::shared_ptr op_; +}; + class 
UpsampleTrilinear3DFunctor { public: UpsampleTrilinear3DFunctor() { @@ -4118,18 +4241,24 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("UnfoldTensor"); m.add_functor("UnfoldTensorGrad"); m.add_functor("UpsampleGrad"); - m.add_functor("UpsampleNearest2D"); - m.add_functor("UpsampleNearest2DGrad"); + m.add_functor( + "UpsampleNearest2D"); + m.add_functor( + "UpsampleNearest2DGrad"); m.add_functor("UpsampleBilinear2D"); m.add_functor("UpsampleBilinear2DGrad"); m.add_functor("UpsampleLinear1D"); m.add_functor("UpsampleLinear1DGrad"); - m.add_functor("UpsampleNearest1D"); - m.add_functor("UpsampleNearest1DGrad"); + m.add_functor( + "UpsampleNearest1D"); + m.add_functor( + "UpsampleNearest1DGrad"); m.add_functor("UpsampleBicubic2D"); m.add_functor("UpsampleBicubic2DGrad"); - m.add_functor("UpsampleNearest3D"); - m.add_functor("UpsampleNearest3DGrad"); + m.add_functor( + "UpsampleNearest3D"); + m.add_functor( + "UpsampleNearest3DGrad"); m.add_functor("UpsampleTrilinear3D"); m.add_functor("UpsampleTrilinear3DGrad"); m.add_functor("UnsortedSegmentSumLike"); @@ -4211,4 +4340,4 @@ ONEFLOW_FUNCTION_LIBRARY(m) { } // namespace functional } // namespace one -} // namespace oneflow +} // namespace oneflow \ No newline at end of file diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 05f5ea56bc3..1d1f9de2001 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -11295,7 +11295,8 @@ def OneFlow_UpsampleLinear1DGradOp : OneFlow_BaseOp<"upsample_linear_1d_grad", [ def OneFlow_UpsampleNearest1DOp : OneFlow_BaseOp<"upsample_nearest_1d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins - OneFlow_Tensor:$x + OneFlow_Tensor:$x, + Optional:$like ); let output = (outs OneFlow_Tensor:$y @@ -11314,7 +11315,8 @@ def OneFlow_UpsampleNearest1DOp : OneFlow_BaseOp<"upsample_nearest_1d", [NoMemor def OneFlow_UpsampleNearest1DGradOp : 
OneFlow_BaseOp<"upsample_nearest_1d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, - OneFlow_Tensor:$x + OneFlow_Tensor:$x, + Optional:$like ); let output = (outs OneFlow_Tensor:$dx @@ -11332,7 +11334,8 @@ def OneFlow_UpsampleNearest1DGradOp : OneFlow_BaseOp<"upsample_nearest_1d_grad", def OneFlow_UpsampleNearest2DOp : OneFlow_BaseOp<"upsample_nearest_2d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins - OneFlow_Tensor:$x + OneFlow_Tensor:$x, + Optional:$like ); let output = (outs OneFlow_Tensor:$y @@ -11352,7 +11355,8 @@ def OneFlow_UpsampleNearest2DOp : OneFlow_BaseOp<"upsample_nearest_2d", [NoMemor def OneFlow_UpsampleNearest2DGradOp : OneFlow_BaseOp<"upsample_nearest_2d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, - OneFlow_Tensor:$x + OneFlow_Tensor:$x, + Optional:$like ); let output = (outs OneFlow_Tensor:$dx @@ -11371,7 +11375,8 @@ def OneFlow_UpsampleNearest2DGradOp : OneFlow_BaseOp<"upsample_nearest_2d_grad", def OneFlow_UpsampleNearest3DOp : OneFlow_BaseOp<"upsample_nearest_3d", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins - OneFlow_Tensor:$x + OneFlow_Tensor:$x, + Optional:$like ); let output = (outs OneFlow_Tensor:$y @@ -11392,7 +11397,8 @@ def OneFlow_UpsampleNearest3DOp : OneFlow_BaseOp<"upsample_nearest_3d", [NoMemor def OneFlow_UpsampleNearest3DGradOp : OneFlow_BaseOp<"upsample_nearest_3d_grad", [NoMemoryEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$dy, - OneFlow_Tensor:$x + OneFlow_Tensor:$x, + Optional:$like ); let output = (outs OneFlow_Tensor:$dx diff --git a/oneflow/user/kernels/upsample_nearest_kernel.cpp b/oneflow/user/kernels/upsample_nearest_kernel.cpp index 70d0d3041bd..20699fe113e 100644 --- a/oneflow/user/kernels/upsample_nearest_kernel.cpp +++ b/oneflow/user/kernels/upsample_nearest_kernel.cpp @@ -133,7 +133,7 @@ class UpsampleNearest1DCPUKernel final : public user_op::OpKernel { const 
int64_t channels = x_tensor->shape_view().At(1); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(2); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { height_scale = static_cast(out_height) / static_cast(in_height); } @@ -173,7 +173,7 @@ class UpsampleNearestGrad1DCPUKernel final : public user_op::OpKernel { const int64_t channels = dx_tensor->shape_view().At(1); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(2); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { @@ -227,7 +227,7 @@ class UpsampleNearest2DCPUKernel final : public user_op::OpKernel { const int64_t out_height = y_tensor->shape_view().At(2); const int64_t out_width = y_tensor->shape_view().At(3); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } @@ -273,7 +273,7 @@ class UpsampleNearest2DGradCPUKernel final : public user_op::OpKernel { const int64_t out_height = dy_tensor->shape_view().At(2); const int64_t out_width = dy_tensor->shape_view().At(3); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } @@ -330,7 +330,7 @@ class UpsampleNearest3DCPUKernel final : public user_op::OpKernel { const int64_t out_height = y_blob->shape_view().At(3); const int64_t out_width = y_blob->shape_view().At(4); const int64_t 
elem_cnt = y_blob->shape_view().elem_cnt(); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); @@ -373,7 +373,7 @@ class UpsampleNearestGrad3DCPUKernel final : public user_op::OpKernel { const int64_t out_height = dy_blob->shape_view().At(3); const int64_t out_width = dy_blob->shape_view().At(4); const int64_t elem_cnt = dy_blob->shape_view().elem_cnt(); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); diff --git a/oneflow/user/kernels/upsample_nearest_kernel.cu b/oneflow/user/kernels/upsample_nearest_kernel.cu index c44ff9e3b03..b7199423cc3 100644 --- a/oneflow/user/kernels/upsample_nearest_kernel.cu +++ b/oneflow/user/kernels/upsample_nearest_kernel.cu @@ -185,7 +185,7 @@ class UpsampleNearest1DGPUKernel final : public user_op::OpKernel { const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); const int64_t in_height = x_tensor->shape_view().At(2); const int64_t out_height = y_tensor->shape_view().At(2); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { height_scale = static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { @@ -222,7 +222,7 @@ class UpsampleNearestGrad1DGPUKernel final : public user_op::OpKernel { const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); const int64_t in_height = dx_tensor->shape_view().At(2); const int64_t out_height = dy_tensor->shape_view().At(2); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { height_scale = 
static_cast(out_height) / static_cast(in_height); } if (in_height == out_height) { @@ -280,7 +280,7 @@ class UpsampleNearest2DGPUKernel final : public user_op::OpKernel, const int64_t in_width = x_tensor->shape_view().At(3); const int64_t out_height = y_tensor->shape_view().At(2); const int64_t out_width = y_tensor->shape_view().At(3); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } @@ -328,7 +328,7 @@ class UpsampleNearest2DGradGPUKernel final : public user_op::OpKernel { const int64_t in_width = dx_tensor->shape_view().At(3); const int64_t out_height = dy_tensor->shape_view().At(2); const int64_t out_width = dy_tensor->shape_view().At(3); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); } @@ -396,7 +396,7 @@ class UpsampleNearest3DGPUKernel final : public user_op::OpKernel { const int64_t out_height = y_tensor->shape_view().At(3); const int64_t out_width = y_tensor->shape_view().At(4); const int64_t elem_cnt = y_tensor->shape_view().elem_cnt(); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { depth_scale = static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); @@ -440,7 +440,7 @@ class UpsampleNearestGrad3DGPUKernel final : public user_op::OpKernel { const int64_t out_height = dy_tensor->shape_view().At(3); const int64_t out_width = dy_tensor->shape_view().At(4); const int64_t elem_cnt = dy_tensor->shape_view().elem_cnt(); - if (!output_size.empty()) { + if (!output_size.empty() || ctx->Tensor4ArgNameAndIndex("like", 0)) { depth_scale = 
static_cast(out_depth) / static_cast(in_depth); height_scale = static_cast(out_height) / static_cast(in_height); width_scale = static_cast(out_width) / static_cast(in_width); diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index c533ddc3e27..7ef1303f9af 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -23,6 +23,18 @@ typename std::enable_if<(N <= 3), Maybe>::type UpsamplingInferLogicalDesc( user_op::InferContext* ctx, const std::string& func_name) { const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); user_op::TensorDesc* y_desc = ctx->MutOutputTensorDesc("y", 0); + if (ctx->has_input("like", 0)) { + const user_op::TensorDesc& like_desc = ctx->InputTensorDesc("like", 0); + int64_t like_num_axes = like_desc.shape().NumAxes(); + CHECK_GT_OR_RETURN(like_num_axes, N) + << "like shape size should > " << N << ", but got " << like_desc.shape().ToString(); + Shape output_shape = x_desc.shape(); + for (int i = 0; i < N; ++i) { + output_shape[i + 2] = like_desc.shape().At(like_num_axes - N + i); + } + y_desc->set_shape(output_shape); + return Maybe::Ok(); + } if (N == 1) { CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" && x_desc.shape().NumAxes() == (N + 2)) @@ -135,7 +147,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest2DOp::GetSbp(user_op::SbpContext* ctx) { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(user_op::OpArg("y", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { @@ -234,11 +246,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest1DGradOp::GetSbp(user_op::SbpContext* ctx) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); + ctx->NewBuilder().Split(ctx->inputs(), 
0).Split(user_op::OpArg("dx", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { @@ -259,11 +267,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest2DGradOp::GetSbp(user_op::SbpContext* ctx) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(user_op::OpArg("dx", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { @@ -335,11 +339,7 @@ namespace oneflow { } /*static*/ Maybe UpsampleNearest3DGradOp::GetSbp(user_op::SbpContext* ctx) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(user_op::OpArg("dx", 0), 0).Build(); return Maybe::Ok(); } /*static*/ Maybe UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { diff --git a/python/oneflow/nn/functional/__init__.py b/python/oneflow/nn/functional/__init__.py index f731f143037..38a3df9b67c 100644 --- a/python/oneflow/nn/functional/__init__.py +++ b/python/oneflow/nn/functional/__init__.py @@ -14,6 +14,7 @@ limitations under the License. """ from oneflow.nn.modules.interpolate import interpolate +from oneflow.nn.modules.interpolate_like import interpolate_like from oneflow.nn.modules.affine_grid import affine_grid from oneflow.nn.modules.grid_sample import grid_sample from oneflow.nn.modules.sparse_softmax_cross_entropy import sparse_softmax_cross_entropy diff --git a/python/oneflow/nn/modules/interpolate_like.py b/python/oneflow/nn/modules/interpolate_like.py new file mode 100644 index 00000000000..5089e7153de --- /dev/null +++ b/python/oneflow/nn/modules/interpolate_like.py @@ -0,0 +1,153 @@ +""" +Copyright 2020 The OneFlow Authors. 
All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import math +import warnings +from typing import Optional, Tuple, Union + +import oneflow as flow +from oneflow.framework.tensor import register_tensor_op +from oneflow.nn.modules.module import Module + + +class InterpolateLike: + def __init__( + self, mode: str = "nearest", align_corners: Optional[bool] = None, + ): + if mode in ("nearest", "area") and align_corners is not None: + raise ValueError( + "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear" + ) + self.mode = mode + if align_corners == None: + align_corners = False + self.align_corners = align_corners + if self.mode not in ( + "nearest", + "bilinear", + "linear", + "area", + "bicubic", + "trilinear", + ): + raise ValueError( + 'interpolation must be "nearest" or "bilinear" or "linear" or "area" or "bicubic" or "trilinear".' 
+ ) + if self.mode == "nearest" and self.align_corners: + raise ValueError('interpolation "nearest" does not support align_corners.') + + def forward(self, x, like): + if len(x.shape) == 3 and self.mode == "bilinear": + raise NotImplementedError("Got 3D input, but bilinear mode needs 4D input") + if len(x.shape) == 3 and self.mode == "trilinear": + raise NotImplementedError("Got 3D input, but trilinear mode needs 5D input") + if len(x.shape) == 4 and self.mode == "linear": + raise NotImplementedError("Got 4D input, but linear mode needs 3D input") + if len(x.shape) == 4 and self.mode == "trilinear": + raise NotImplementedError("Got 4D input, but trilinear mode needs 5D input") + if len(x.shape) == 5 and self.mode == "linear": + raise NotImplementedError("Got 5D input, but linear mode needs 3D input") + if len(x.shape) == 5 and self.mode == "bilinear": + raise NotImplementedError("Got 5D input, but bilinear mode needs 4D input") + + dim = len(x.shape) - 2 + if len(x.shape) == 3 and self.mode == "nearest": + return flow._C.upsample_nearest_1d(x, like, data_format="channels_first",) + if len(x.shape) == 4 and self.mode == "nearest": + return flow._C.upsample_nearest_2d(x, like, data_format="channels_first",) + if len(x.shape) == 5 and self.mode == "nearest": + return flow._C.upsample_nearest_3d(x, like, data_format="channels_first",) + + raise NotImplementedError( + "Input Error: Only 3D, 4D and 5D input Tensors supported" + " (got {}D) for the modes: nearest" + " (got {})".format(len(x.shape), self.mode) + ) + + +def interpolate_like( + input, like, mode="nearest", align_corners=None, +): + """The interface is consistent with PyTorch. + + The documentation is referenced from: https://pytorch.org/docs/1.10/_modules/torch/nn/functional.html#interpolate. + + + Down/up samples the input to :Tensor:`like` shape. + + The algorithm used for interpolation is determined by :attr:`mode`. + + Currently temporal, spatial and volumetric sampling are supported, i.e. 
+ expected inputs are 3-D, 4-D or 5-D in shape. + + The input dimensions are interpreted in the form: + `mini-batch x channels x [optional depth] x [optional height] x width`. + + The modes available for resizing are: `nearest`, `linear` (3D-only), + `bilinear`, `bicubic` (4D-only), `trilinear` (5D-only), `area` + + Args: + input (Tensor): the input tensor + like (Tensor): the like tensor + mode (str): algorithm used for upsampling: + ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` | + ``'trilinear'`` | ``'area'``. Default: ``'nearest'`` + align_corners (bool, optional): Geometrically, we consider the pixels of the + input and output as squares rather than points. + If set to ``True``, the input and output tensors are aligned by the + center points of their corner pixels, preserving the values at the corner pixels. + If set to ``False``, the input and output tensors are aligned by the corner + points of their corner pixels, and the interpolation uses edge value padding + for out-of-boundary values. This only has an effect when :attr:`mode` + is ``'linear'``, ``'bilinear'``, ``'bicubic'`` or ``'trilinear'``. + Default: ``False`` + + .. note:: + With ``mode='bicubic'``, it's possible to cause overshoot, in other words it can produce + negative values or values greater than 255 for images. + Explicitly call ``result.clamp(min=0, max=255)`` if you want to reduce the overshoot + when displaying the image. + + .. warning:: + With ``align_corners = True``, the linearly interpolating modes + (`linear`, `bilinear`, and `trilinear`) don't proportionally align the + output and input pixels, and thus the output values can depend on the + input size. This was the default behavior for these modes up to version + 0.3.1. Since then, the default behavior is ``align_corners = False``. + See :class:`~torch.nn.Upsample` for concrete examples on how this + affects the outputs. + + For example: + + .. 
code-block:: python + + >>> import oneflow as flow + >>> import numpy as np + + >>> input = flow.tensor(np.arange(1, 5).reshape((1, 1, 4)), dtype=flow.float32) + >>> like = flow.randn(1, 1, 8) + >>> output = flow.nn.functional.interpolate_like(input, like, mode="nearest") + >>> output + tensor([[[1., 1., 2., 2., 3., 3., 4., 4.]]], dtype=oneflow.float32) + + """ + return InterpolateLike(mode=mode, align_corners=align_corners,).forward(input, like) + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True)