diff --git a/oneflow/core/ep/common/primitive/elementwise_unary.h b/oneflow/core/ep/common/primitive/elementwise_unary.h index b6e5c741973..a3a01e69799 100644 --- a/oneflow/core/ep/common/primitive/elementwise_unary.h +++ b/oneflow/core/ep/common/primitive/elementwise_unary.h @@ -88,6 +88,16 @@ namespace primitive { OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFastGelu) \ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu) +#define UNARY_COMPLEX_C2C_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kConj) + +#define UNARY_COMPLEX_C2R_OP_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReal) \ + OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kImag) + +#define UNARY_COMPLEX_R2C_OP_SEQ \ + OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRealGrad) \ + OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kImagGrad) + #define UNARY_INT_MATH_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs) #define UNARY_LOGICAL_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogicalNot) diff --git a/oneflow/core/ep/common/primitive/unary_functor.h b/oneflow/core/ep/common/primitive/unary_functor.h index c973f2acd60..7ca89872673 100644 --- a/oneflow/core/ep/common/primitive/unary_functor.h +++ b/oneflow/core/ep/common/primitive/unary_functor.h @@ -572,6 +572,40 @@ struct UnaryFunctor { OF_DEVICE_FUNC Dst operator()(bool src) const { return static_cast(!src); } }; +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src.real(), -src.imag()}; } +}; + +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src.real()); } +}; + +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src.imag()); } +}; + +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src, 0.0}; } +}; + +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{0.0, src}; } +}; + } // namespace primitive } // namespace ep } // namespace oneflow diff --git a/oneflow/core/ep/cpu/primitive/elementwise_unary.cpp b/oneflow/core/ep/cpu/primitive/elementwise_unary.cpp index b3455d673ec..4a35389d485 100644 --- a/oneflow/core/ep/cpu/primitive/elementwise_unary.cpp +++ b/oneflow/core/ep/cpu/primitive/elementwise_unary.cpp @@ -92,6 +92,19 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory { MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_FLOATING_MATH_OP_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ) + // For Complex Type OP + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, + UNARY_COMPLEX_C2C_OP_SEQ, + CPU_PRIMITIVE_COMPLEX_TYPE_SEQ) + + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2R_OP_SEQ, + CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ) + + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_R2C_OP_SEQ, + CPU_PRIMITIVE_FLOATING_TYPE_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ) + // For Int Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_INT_MATH_OP_SEQ, CPU_PRIMITIVE_INT_TYPE_SEQ) diff --git a/oneflow/core/ep/cpu/primitive/unary_functor.h b/oneflow/core/ep/cpu/primitive/unary_functor.h index 70adcae0a36..5c5a236df07 100644 --- a/oneflow/core/ep/cpu/primitive/unary_functor.h +++ b/oneflow/core/ep/cpu/primitive/unary_functor.h @@ -388,6 +388,23 @@ struct UnaryFunctor { OF_DEVICE_FUNC bool operator()(bfloat16 src) const { return std::isnan(src); } }; +// avoid warning: narrowing conversion +template<> +struct UnaryFunctor, double> { + UnaryFunctor(Scalar attr0, Scalar attr1) {} + std::complex operator()(double src) const { + return std::complex{static_cast(src), 0.0f}; + } +}; + +template<> +struct UnaryFunctor, double> { + UnaryFunctor(Scalar attr0, Scalar attr1) {} + std::complex operator()(double src) const { + return std::complex{0.0f, static_cast(src)}; + } +}; + } // namespace primitive } // namespace ep } // namespace oneflow diff --git a/oneflow/core/ep/cuda/primitive/elementwise_unary.cu b/oneflow/core/ep/cuda/primitive/elementwise_unary.cu index a79c7155f05..d5ea092224c 100644 --- a/oneflow/core/ep/cuda/primitive/elementwise_unary.cu +++ b/oneflow/core/ep/cuda/primitive/elementwise_unary.cu @@ -86,6 +86,19 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory { UNARY_FLOATING_MATH_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ) + // For Complex Type OP + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, + UNARY_COMPLEX_C2C_OP_SEQ, + CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) + + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2R_OP_SEQ, + CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ) + + OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE( + MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_R2C_OP_SEQ, + CUDA_PRIMITIVE_FLOATING_TYPE_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ) + // For Int Type OP OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_INT_MATH_OP_SEQ, CUDA_PRIMITIVE_INT_TYPE_SEQ) diff --git a/oneflow/core/ep/cuda/primitive/unary_functor.cuh b/oneflow/core/ep/cuda/primitive/unary_functor.cuh index 5f1a9fd17b7..06e6b730598 100644 --- a/oneflow/core/ep/cuda/primitive/unary_functor.cuh +++ b/oneflow/core/ep/cuda/primitive/unary_functor.cuh @@ -596,6 +596,60 @@ struct UnaryFunctor= 11000 /*********float complex dtype support*********/ +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src.x, -src.y}; } +}; + +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src.x); } +}; + +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast(src.y); } +}; + +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src, 0.0}; } +}; + +template +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{0.0, src}; } +}; + +// avoid warning: narrowing conversion +template<> +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC cuComplex operator()(double src) const { + return cuComplex{static_cast(src), 0.0f}; + } +}; + +template<> +struct UnaryFunctor { + OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} + + OF_DEVICE_FUNC cuComplex operator()(double src) const { + return cuComplex{0.0f, static_cast(src)}; + } +}; + template struct UnaryFunctor { OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {} diff --git a/oneflow/core/ep/include/primitive/unary_op.h b/oneflow/core/ep/include/primitive/unary_op.h index cf37d5fd49a..cab540adb4a 100644 --- a/oneflow/core/ep/include/primitive/unary_op.h +++ b/oneflow/core/ep/include/primitive/unary_op.h @@ -97,6 +97,13 @@ enum class UnaryOp { // bitwise op kBitwiseNot, + + // complex op + kConj, + kReal, + kImag, + kRealGrad, + kImagGrad }; } diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp index 693ba02db13..291a68613e9 100644 --- a/oneflow/core/functional/impl/math_functor.cpp +++ b/oneflow/core/functional/impl/math_functor.cpp @@ -5478,6 +5478,7 @@ class RealFunctor { RealFunctor() { op_ = CHECK_JUST(one::OpBuilder("real").Input("x").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x) const { + if (!x->dtype()->is_complex()) { return x; } return OpInterpUtil::Dispatch(*op_, {x}); } @@ -5504,6 +5505,9 @@ class ImagFunctor { ImagFunctor() { op_ = CHECK_JUST(one::OpBuilder("imag").Input("x").Output("out").Build()); } Maybe operator()(const std::shared_ptr& x) const { + CHECK_OR_RETURN(x->dtype()->is_complex()) + << "RuntimeError: imag is implemented for tensors with complex dtypes, but gets" + << x->dtype()->name(); return OpInterpUtil::Dispatch(*op_, {x}); } @@ -5532,6 +5536,7 @@ class ConjFunctor { } Maybe operator()(const std::shared_ptr& x) const { + if (!x->dtype()->is_complex()) { return x; } return OpInterpUtil::Dispatch(*op_, {x}); } diff --git a/oneflow/user/kernels/complex_kernels.cpp b/oneflow/user/kernels/complex_kernels.cpp index 3bf78629a71..2c141c5f503 100644 --- a/oneflow/user/kernels/complex_kernels.cpp +++ b/oneflow/user/kernels/complex_kernels.cpp @@ -15,7 +15,10 @@ limitations under the License. */ #include "oneflow/core/common/shape_view.h" #include "oneflow/core/framework/framework.h" -#include "oneflow/user/kernels/complex_kernels_util.h" +#include "oneflow/core/ep/include/primitive/elementwise_unary.h" +#include "oneflow/core/ep/include/primitive/primitive.h" +#include "oneflow/core/ep/include/primitive/unary_op.h" +#include "oneflow/user/kernels/elementwise_primitive_kernel.h" #include #ifdef WITH_CUDA #include @@ -24,170 +27,43 @@ limitations under the License. namespace oneflow { namespace user_op { -template -class RealKernel final : public user_op::OpKernel { - public: - RealKernel() = default; - ~RealKernel() = default; - - private: - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); - user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); - if (out_tensor->shape_view().elem_cnt() == 0) { return; } - const dtype_x* x = x_tensor->dptr(); - dtype_out* out = out_tensor->mut_dptr(); - RealFunctor()(ctx->stream(), x, out, - out_tensor->shape_view().elem_cnt()); - } -}; - -#define REGISTER_REAL_KERNEL(device, dtype_x, dtype_out) \ - REGISTER_USER_KERNEL("real") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == device) \ - && (user_op::HobDataType("x", 0) == GetDataType::value)); - -REGISTER_REAL_KERNEL(DeviceType::kCPU, std::complex, float) -REGISTER_REAL_KERNEL(DeviceType::kCPU, std::complex, double) -#ifdef WITH_CUDA -REGISTER_REAL_KERNEL(DeviceType::kCUDA, cuComplex, float) -REGISTER_REAL_KERNEL(DeviceType::kCUDA, cuDoubleComplex, double) -#endif // WITH_CUDA - -template -class RealGradKernel final : public user_op::OpKernel { - public: - RealGradKernel() = default; - ~RealGradKernel() = default; - - private: - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* dout_tensor = ctx->Tensor4ArgNameAndIndex("dout", 0); - user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); - if (dx_tensor->shape_view().elem_cnt() == 0) { return; } - const dtype_dout* dout = dout_tensor->dptr(); - dtype_dx* dx = dx_tensor->mut_dptr(); - RealGradFunctor()(ctx->stream(), dout, dx, - dx_tensor->shape_view().elem_cnt()); - } -}; - -#define REGISTER_REAL_GRAD_KERNEL(device, dtype_dout, dtype_dx) \ - REGISTER_USER_KERNEL("real_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == device) \ - && (user_op::HobDataType("dx", 0) == GetDataType::value)); - -REGISTER_REAL_GRAD_KERNEL(DeviceType::kCPU, float, std::complex) -REGISTER_REAL_GRAD_KERNEL(DeviceType::kCPU, double, std::complex) -#ifdef WITH_CUDA -REGISTER_REAL_GRAD_KERNEL(DeviceType::kCUDA, float, cuComplex) -REGISTER_REAL_GRAD_KERNEL(DeviceType::kCUDA, double, cuDoubleComplex) -#endif // WITH_CUDA - -template -class ImagKernel final : public user_op::OpKernel { - public: - ImagKernel() = default; - ~ImagKernel() = default; - - private: - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); - user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); - if (out_tensor->shape_view().elem_cnt() == 0) { return; } - const dtype_x* x = x_tensor->dptr(); - dtype_out* out = out_tensor->mut_dptr(); - ImagFunctor()(ctx->stream(), x, out, - out_tensor->shape_view().elem_cnt()); - } -}; - -#define REGISTER_IMAG_KERNEL(device, dtype_x, dtype_out) \ - REGISTER_USER_KERNEL("imag") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == device) \ - && (user_op::HobDataType("x", 0) == GetDataType::value)); - -REGISTER_IMAG_KERNEL(DeviceType::kCPU, std::complex, float) -REGISTER_IMAG_KERNEL(DeviceType::kCPU, std::complex, double) -#ifdef WITH_CUDA -REGISTER_IMAG_KERNEL(DeviceType::kCUDA, cuComplex, float) -REGISTER_IMAG_KERNEL(DeviceType::kCUDA, cuDoubleComplex, double) -#endif // WITH_CUDA - -template -class ImagGradKernel final : public user_op::OpKernel { - public: - ImagGradKernel() = default; - ~ImagGradKernel() = default; - - private: - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* dout_tensor = ctx->Tensor4ArgNameAndIndex("dout", 0); - user_op::Tensor* dx_tensor = ctx->Tensor4ArgNameAndIndex("dx", 0); - if (dx_tensor->shape_view().elem_cnt() == 0) { return; } - const dtype_dout* dout = dout_tensor->dptr(); - dtype_dx* dx = dx_tensor->mut_dptr(); - ImagGradFunctor()(ctx->stream(), dout, dx, - dx_tensor->shape_view().elem_cnt()); - } -}; - -#define REGISTER_IMAG_GRAD_KERNEL(device, dtype_dout, dtype_dx) \ - REGISTER_USER_KERNEL("imag_grad") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == device) \ - && (user_op::HobDataType("dx", 0) == GetDataType::value)); - -REGISTER_IMAG_GRAD_KERNEL(DeviceType::kCPU, float, std::complex) -REGISTER_IMAG_GRAD_KERNEL(DeviceType::kCPU, double, std::complex) -#ifdef WITH_CUDA -REGISTER_IMAG_GRAD_KERNEL(DeviceType::kCUDA, float, cuComplex) -REGISTER_IMAG_GRAD_KERNEL(DeviceType::kCUDA, double, cuDoubleComplex) -#endif // WITH_CUDA - -template -class ConjPhysicalKernel final : public user_op::OpKernel { - public: - ConjPhysicalKernel() = default; - ~ConjPhysicalKernel() = default; - - private: - bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } - - void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* x_tensor = ctx->Tensor4ArgNameAndIndex("x", 0); - user_op::Tensor* out_tensor = ctx->Tensor4ArgNameAndIndex("out", 0); - if (out_tensor->shape_view().elem_cnt() == 0) { return; } - const dtype* x = x_tensor->dptr(); - dtype* out = out_tensor->mut_dptr(); - ConjPhysicalFunctor()(ctx->stream(), x, out, - out_tensor->shape_view().elem_cnt()); - } -}; - -#define REGISTER_CONJ_PHYSICAL_KERNEL(device, dtype) \ - REGISTER_USER_KERNEL("conj_physical") \ - .SetCreateFn>() \ - .SetIsMatchedHob((user_op::HobDeviceType() == device) \ - && (user_op::HobDataType("x", 0) == GetDataType::value)); - -REGISTER_CONJ_PHYSICAL_KERNEL(DeviceType::kCPU, std::complex) -REGISTER_CONJ_PHYSICAL_KERNEL(DeviceType::kCPU, std::complex) -#ifdef WITH_CUDA -REGISTER_CONJ_PHYSICAL_KERNEL(DeviceType::kCUDA, cuComplex) -REGISTER_CONJ_PHYSICAL_KERNEL(DeviceType::kCUDA, cuDoubleComplex) -#endif // WITH_CUDA +#define COMPLEX_UNARY_ELEMENTWISE_PRIMITIVE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ("conj_physical", ep::primitive::UnaryOp::kConj) \ + OF_PP_MAKE_TUPLE_SEQ("real", ep::primitive::UnaryOp::kReal) \ + OF_PP_MAKE_TUPLE_SEQ("imag", ep::primitive::UnaryOp::kImag) + +#define COMPLEX_UNARY_GRAD_ELEMENTWISE_PRIMITIVE_SEQ \ + OF_PP_MAKE_TUPLE_SEQ("real_grad", ep::primitive::UnaryOp::kRealGrad) \ + OF_PP_MAKE_TUPLE_SEQ("imag_grad", ep::primitive::UnaryOp::kImagGrad) + +#define REGISTER_COMPLEX_KERNEL(name, UnaryOp) \ + REGISTER_USER_KERNEL(name) \ + .SetCreateFn([]() { \ + return user_op::NewOpKernel( \ + "out", "x", [](user_op::KernelComputeContext* ctx) { \ + const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); \ + const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); \ + return ep::primitive::NewPrimitive( \ + ctx->device_type(), UnaryOp, src->data_type(), dst->data_type()); \ + }); \ + }) \ + .SetIsMatchedHob(UnaryPrimitiveExists(UnaryOp, "out", "x")); +OF_PP_FOR_EACH_TUPLE(REGISTER_COMPLEX_KERNEL, COMPLEX_UNARY_ELEMENTWISE_PRIMITIVE_SEQ) + +#define REGISTER_COMPLEX_GRAD_KERNEL(name, UnaryOp) \ + REGISTER_USER_KERNEL(name) \ + .SetCreateFn([]() { \ + return user_op::NewOpKernel( \ + "dx", "dout", [](user_op::KernelComputeContext* ctx) { \ + const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0); \ + const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dout", 0); \ + return ep::primitive::NewPrimitive( \ + ctx->device_type(), UnaryOp, src->data_type(), dst->data_type()); \ + }); \ + }) \ + .SetIsMatchedHob(UnaryPrimitiveExists(UnaryOp, "dx", "dout")); + +OF_PP_FOR_EACH_TUPLE(REGISTER_COMPLEX_GRAD_KERNEL, COMPLEX_UNARY_GRAD_ELEMENTWISE_PRIMITIVE_SEQ) } // namespace user_op } // namespace oneflow diff --git a/oneflow/user/kernels/complex_kernels_util.cpp b/oneflow/user/kernels/complex_kernels_util.cpp deleted file mode 100644 index 2deeeca5470..00000000000 --- a/oneflow/user/kernels/complex_kernels_util.cpp +++ /dev/null @@ -1,75 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#include "oneflow/core/framework/framework.h" -#include "oneflow/user/kernels/complex_kernels_util.h" -#include - -namespace oneflow { - -namespace user_op { - -template -struct RealFunctor final { - void operator()(ep::Stream* stream, const dtype_x* x, dtype_out* out, int64_t cnt) { - FOR_RANGE(int64_t, i, 0, cnt) { out[i] = x[i].real(); } - } -}; - -INSTANTIATE_REAL_FUNCTOR(DeviceType::kCPU, std::complex, float) -INSTANTIATE_REAL_FUNCTOR(DeviceType::kCPU, std::complex, double) - -template -struct RealGradFunctor final { - void operator()(ep::Stream* stream, const dtype_dout* dout, dtype_dx* dx, int64_t cnt) { - FOR_RANGE(int64_t, i, 0, cnt) { dx[i] = dtype_dx{dout[i], 0.0}; } - } -}; - -INSTANTIATE_REAL_GRAD_FUNCTOR(DeviceType::kCPU, float, std::complex) -INSTANTIATE_REAL_GRAD_FUNCTOR(DeviceType::kCPU, double, std::complex) - -template -struct ImagFunctor final { - void operator()(ep::Stream* stream, const dtype_x* x, dtype_out* out, int64_t cnt) { - FOR_RANGE(int64_t, i, 0, cnt) { out[i] = x[i].imag(); } - } -}; - -INSTANTIATE_IMAG_FUNCTOR(DeviceType::kCPU, std::complex, float) -INSTANTIATE_IMAG_FUNCTOR(DeviceType::kCPU, std::complex, double) - -template -struct ImagGradFunctor final { - void operator()(ep::Stream* stream, const dtype_dout* dout, dtype_dx* dx, int64_t cnt) { - FOR_RANGE(int64_t, i, 0, cnt) { dx[i] = dtype_dx{0.0, dout[i]}; } - } -}; - -INSTANTIATE_IMAG_GRAD_FUNCTOR(DeviceType::kCPU, float, std::complex) -INSTANTIATE_IMAG_GRAD_FUNCTOR(DeviceType::kCPU, double, std::complex) - -template -struct ConjPhysicalFunctor final { - void operator()(ep::Stream* stream, const dtype* x, dtype* out, int64_t cnt) { - FOR_RANGE(int64_t, i, 0, cnt) { out[i] = dtype{x[i].real(), -x[i].imag()}; } - } -}; - -INSTANTIATE_CONJ_PHYSICAL_FUNCTOR(DeviceType::kCPU, std::complex) -INSTANTIATE_CONJ_PHYSICAL_FUNCTOR(DeviceType::kCPU, std::complex) - -} // namespace user_op -} // namespace oneflow diff --git a/oneflow/user/kernels/complex_kernels_util.cu b/oneflow/user/kernels/complex_kernels_util.cu deleted file mode 100644 index fb3182fee9a..00000000000 --- a/oneflow/user/kernels/complex_kernels_util.cu +++ /dev/null @@ -1,104 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifdef WITH_CUDA -#include "oneflow/core/device/cuda_util.h" -#include "oneflow/core/framework/framework.h" -#include "oneflow/user/kernels/complex_kernels_util.h" -#include - -namespace oneflow { - -namespace user_op { - -template -__global__ void RealCUDA(const dtype_x* x, dtype_out* out, int64_t cnt) { - CUDA_1D_KERNEL_LOOP(i, cnt) { out[i] = x[i].x; } -} - -template -__global__ void RealGradCUDA(const dtype_dout* dout, dtype_dx* dx, int64_t cnt) { - CUDA_1D_KERNEL_LOOP(i, cnt) { dx[i] = dtype_dx{dout[i], 0.0}; } -} - -template -__global__ void ImagCUDA(const dtype_x* x, dtype_out* out, int64_t cnt) { - CUDA_1D_KERNEL_LOOP(i, cnt) { out[i] = x[i].y; } -} - -template -__global__ void ImagGradCUDA(const dtype_dout* dout, dtype_dx* dx, int64_t cnt) { - CUDA_1D_KERNEL_LOOP(i, cnt) { dx[i] = dtype_dx{0.0, dout[i]}; } -} - -template -__global__ void ConjPhysicalCUDA(const dtype* x, dtype* out, int64_t cnt) { - CUDA_1D_KERNEL_LOOP(i, cnt) { out[i] = dtype{x[i].x, -x[i].y}; } -} - -template -struct RealFunctor final { - void operator()(ep::Stream* stream, const dtype_x* x, dtype_out* out, int64_t cnt) { - RUN_CUDA_KERNEL((RealCUDA), stream, cnt, x, out, cnt); - } -}; - -INSTANTIATE_REAL_FUNCTOR(DeviceType::kCUDA, cuComplex, float) -INSTANTIATE_REAL_FUNCTOR(DeviceType::kCUDA, cuDoubleComplex, double) - -template -struct RealGradFunctor final { - void operator()(ep::Stream* stream, const dtype_dout* dout, dtype_dx* dx, int64_t cnt) { - RUN_CUDA_KERNEL((RealGradCUDA), stream, cnt, dout, dx, cnt); - } -}; - -INSTANTIATE_REAL_GRAD_FUNCTOR(DeviceType::kCUDA, float, cuComplex) -INSTANTIATE_REAL_GRAD_FUNCTOR(DeviceType::kCUDA, double, cuDoubleComplex) - -template -struct ImagFunctor final { - void operator()(ep::Stream* stream, const dtype_x* x, dtype_out* out, int64_t cnt) { - RUN_CUDA_KERNEL((ImagCUDA), stream, cnt, x, out, cnt); - } -}; - -INSTANTIATE_IMAG_FUNCTOR(DeviceType::kCUDA, cuComplex, float) -INSTANTIATE_IMAG_FUNCTOR(DeviceType::kCUDA, cuDoubleComplex, double) - -template -struct ImagGradFunctor final { - void operator()(ep::Stream* stream, const dtype_dout* dout, dtype_dx* dx, int64_t cnt) { - RUN_CUDA_KERNEL((ImagGradCUDA), stream, cnt, dout, dx, cnt); - } -}; - -INSTANTIATE_IMAG_GRAD_FUNCTOR(DeviceType::kCUDA, float, cuComplex) -INSTANTIATE_IMAG_GRAD_FUNCTOR(DeviceType::kCUDA, double, cuDoubleComplex) - -template -struct ConjPhysicalFunctor final { - void operator()(ep::Stream* stream, const dtype* x, dtype* out, int64_t cnt) { - RUN_CUDA_KERNEL((ConjPhysicalCUDA), stream, cnt, x, out, cnt); - } -}; - -INSTANTIATE_CONJ_PHYSICAL_FUNCTOR(DeviceType::kCUDA, cuComplex) -INSTANTIATE_CONJ_PHYSICAL_FUNCTOR(DeviceType::kCUDA, cuDoubleComplex) - -} // namespace user_op -} // namespace oneflow - -#endif // WITH_CUDA diff --git a/oneflow/user/kernels/complex_kernels_util.h b/oneflow/user/kernels/complex_kernels_util.h deleted file mode 100644 index d01d037900f..00000000000 --- a/oneflow/user/kernels/complex_kernels_util.h +++ /dev/null @@ -1,65 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_USER_KERNELS_COMPLEX_KERNELS_UTIL_H_ -#define ONEFLOW_USER_KERNELS_COMPLEX_KERNELS_UTIL_H_ - -namespace oneflow { -namespace user_op { - -template -struct RealFunctor final { - void operator()(ep::Stream* stream, const dtype_x* x, dtype_out* out, int64_t cnt); -}; - -#define INSTANTIATE_REAL_FUNCTOR(device, dtype_x, dtype_out) \ - template struct RealFunctor; - -template -struct RealGradFunctor final { - void operator()(ep::Stream* stream, const dtype_dout* dout, dtype_dx* dx, int64_t cnt); -}; - -#define INSTANTIATE_REAL_GRAD_FUNCTOR(device, dtype_dout, dtype_dx) \ - template struct RealGradFunctor; - -template -struct ImagFunctor final { - void operator()(ep::Stream* stream, const dtype_x* x, dtype_out* out, int64_t cnt); -}; - -#define INSTANTIATE_IMAG_FUNCTOR(device, dtype_x, dtype_out) \ - template struct ImagFunctor; - -template -struct ImagGradFunctor final { - void operator()(ep::Stream* stream, const dtype_dout* dout, dtype_dx* dx, int64_t cnt); -}; - -#define INSTANTIATE_IMAG_GRAD_FUNCTOR(device, dtype_dout, dtype_dx) \ - template struct ImagGradFunctor; - -template -struct ConjPhysicalFunctor final { - void operator()(ep::Stream* stream, const dtype* x, dtype* out, int64_t cnt); -}; - -#define INSTANTIATE_CONJ_PHYSICAL_FUNCTOR(device, dtype) \ - template struct ConjPhysicalFunctor; - -} // namespace user_op -} // namespace oneflow - -#endif // ONEFLOW_USER_KERNELS_COMPLEX_KERNELS_UTIL_H_