Skip to content

Commit

Permalink
use primitive to implement "conj_physical", "real" and "imag" op (#10281)
Browse files Browse the repository at this point in the history

### Original requirements
现有的 conj kernel, real kernel, imag kernel 使用 KernelUtil 来进行组织,在 kernel
层面复数数据类型和实数数据类型的调用 conj 无法实现统一。故移除原有 KernelUtil 的实现,使用 ElementwiseUnary
的 primitive 进行实现

- [x] 新增 `kConj` 这一 UnaryOp,使用 ElementwiseUnary 的 Primitive 实现
conj,并移除原有 conj 的 KernelUtil
- [x] 新增 `kReal` 和 `kRealGrad` UnaryOp,使用 ElementwiseUnary 的 Primitive
实现 real 和 real_grad,并移除原有 real 和 real_grad 的 KernelUtil
- [x] 新增 `kImag` 和 `kImagGrad` UnaryOp,使用 ElementwiseUnary 的 Primitive
实现 imag 和 imag_grad,并移除原有 imag 和 imag_grad 的 KernelUtil

## 注意:
复数基础设施建设系列 pr:
1. 使用 primitive 来实现 conj, real 等常见复数算子:
#10281
2. 将现有支持复数数据类型的算子测例迁移到 autotest
模块中,以使复数算子复用实数算子的测试用例:#10284
3. 继续拓展支持复数数据类型的算子,比如 matmul, sqrt, div
等:#10269
依赖关系:
本 pr 基于:[pr2](#10284) 和
[pr3](#10269) 的基础,请优先 merge 此
pr
  • Loading branch information
MarioLulab authored Jun 5, 2023
1 parent a9a339b commit e4d79e8
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 409 deletions.
10 changes: 10 additions & 0 deletions oneflow/core/ep/common/primitive/elementwise_unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ namespace primitive {
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFastGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu)

// Complex -> complex unary ops (input and output share the same complex dtype).
#define UNARY_COMPLEX_C2C_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kConj)

// Complex -> real unary ops (extract a real-valued component from a complex input).
#define UNARY_COMPLEX_C2R_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kReal) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kImag)

// Real -> complex unary ops (backward of real()/imag(): re-embed a real grad into a complex value).
#define UNARY_COMPLEX_R2C_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kRealGrad) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kImagGrad)

#define UNARY_INT_MATH_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kAbs)

#define UNARY_LOGICAL_OP_SEQ OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kLogicalNot)
Expand Down
34 changes: 34 additions & 0 deletions oneflow/core/ep/common/primitive/unary_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,40 @@ struct UnaryFunctor<device, UnaryOp::kBitwiseNot, Dst, bool> {
OF_DEVICE_FUNC Dst operator()(bool src) const { return static_cast<Dst>(!src); }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kConj, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src.real(), -src.imag()}; }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kReal, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src.real()); }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kImag, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src) const { return static_cast<Dst>(src.imag()); }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kRealGrad, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{src, 0.0}; }
};

template<DeviceType device, typename Dst, typename Src>
struct UnaryFunctor<device, UnaryOp::kImagGrad, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src) const { return Dst{0.0, src}; }
};

} // namespace primitive
} // namespace ep
} // namespace oneflow
Expand Down
13 changes: 13 additions & 0 deletions oneflow/core/ep/cpu/primitive/elementwise_unary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,19 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_FLOATING_MATH_OP_SEQ,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ CPU_PRIMITIVE_BFLOAT16_TYPE_SEQ)

// For Complex Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_COMPLEX_C2C_OP_SEQ,
CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2R_OP_SEQ,
CPU_PRIMITIVE_COMPLEX_TYPE_SEQ, CPU_PRIMITIVE_FLOATING_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_R2C_OP_SEQ,
CPU_PRIMITIVE_FLOATING_TYPE_SEQ, CPU_PRIMITIVE_COMPLEX_TYPE_SEQ)

// For Int Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_INT_MATH_OP_SEQ, CPU_PRIMITIVE_INT_TYPE_SEQ)
Expand Down
17 changes: 17 additions & 0 deletions oneflow/core/ep/cpu/primitive/unary_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,23 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kIsNan, bool, bfloat16> {
OF_DEVICE_FUNC bool operator()(bfloat16 src) const { return std::isnan(src); }
};

// Specialization: double grad -> complex<float>. An explicit cast keeps the
// braced initializer free of narrowing-conversion warnings.
template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kRealGrad, std::complex<float>, double> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  std::complex<float> operator()(double src) const {
    const float re = static_cast<float>(src);
    return {re, 0.0f};
  }
};

// Specialization: double grad -> complex<float>, grad goes into the imaginary
// slot. Explicit cast avoids a narrowing conversion in the braced initializer.
template<>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kImagGrad, std::complex<float>, double> {
  UnaryFunctor(Scalar attr0, Scalar attr1) {}
  std::complex<float> operator()(double src) const {
    const float im = static_cast<float>(src);
    return {0.0f, im};
  }
};

} // namespace primitive
} // namespace ep
} // namespace oneflow
13 changes: 13 additions & 0 deletions oneflow/core/ep/cuda/primitive/elementwise_unary.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,19 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
UNARY_FLOATING_MATH_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)

// For Complex Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_COMPLEX_C2C_OP_SEQ,
CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_C2R_OP_SEQ,
CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY, UNARY_COMPLEX_R2C_OP_SEQ,
CUDA_PRIMITIVE_FLOATING_TYPE_SEQ, CUDA_PRIMITIVE_COMPLEX_TYPE_SEQ)

// For Int Type OP
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
UNARY_INT_MATH_OP_SEQ, CUDA_PRIMITIVE_INT_TYPE_SEQ)
Expand Down
54 changes: 54 additions & 0 deletions oneflow/core/ep/cuda/primitive/unary_functor.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,60 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kTrunc, nv_bfloat16, nv_bfloat16
#endif // CUDA_VERSION >= 11000

/*********float complex dtype support*********/
template<typename To, typename From>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kConj, To, From> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

  // cuComplex/cuDoubleComplex store real/imag as members .x/.y;
  // conjugation flips the sign of .y.
  OF_DEVICE_FUNC To operator()(From src) const {
    return To{src.x, -src.y};
  }
};

template<typename To, typename From>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kReal, To, From> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

  // .x holds the real component of a CUDA complex value.
  OF_DEVICE_FUNC To operator()(From src) const {
    return static_cast<To>(src.x);
  }
};

template<typename To, typename From>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kImag, To, From> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

  // .y holds the imaginary component of a CUDA complex value.
  OF_DEVICE_FUNC To operator()(From src) const {
    return static_cast<To>(src.y);
  }
};

template<typename To, typename From>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kRealGrad, To, From> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of real(): real grad -> real slot (.x), zero imaginary slot.
  OF_DEVICE_FUNC To operator()(From src) const {
    return To{src, 0.0};
  }
};

template<typename To, typename From>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kImagGrad, To, From> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

  // Backward of imag(): real grad -> imaginary slot (.y), zero real slot.
  OF_DEVICE_FUNC To operator()(From src) const {
    return To{0.0, src};
  }
};

// Specialization: double grad -> cuComplex. An explicit cast keeps the
// braced aggregate initializer free of narrowing-conversion warnings.
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kRealGrad, cuComplex, double> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC cuComplex operator()(double src) const {
    const float re = static_cast<float>(src);
    return cuComplex{re, 0.0f};
  }
};

// Specialization: double grad -> cuComplex, grad goes into the imaginary
// slot. Explicit cast avoids a narrowing conversion in the aggregate init.
template<>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kImagGrad, cuComplex, double> {
  OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC cuComplex operator()(double src) const {
    const float im = static_cast<float>(src);
    return cuComplex{0.0f, im};
  }
};

template<typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kCast, cuComplex, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/ep/include/primitive/unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ enum class UnaryOp {

// bitwise op
kBitwiseNot,

// complex op
kConj,
kReal,
kImag,
kRealGrad,
kImagGrad
};

}
Expand Down
5 changes: 5 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5478,6 +5478,7 @@ class RealFunctor {
RealFunctor() { op_ = CHECK_JUST(one::OpBuilder("real").Input("x").Output("out").Build()); }

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
  // real() of a real-dtype tensor is the tensor itself; only complex
  // inputs need to go through the op.
  if (x->dtype()->is_complex()) { return OpInterpUtil::Dispatch<Tensor>(*op_, {x}); }
  return x;
}

Expand All @@ -5504,6 +5505,9 @@ class ImagFunctor {
ImagFunctor() { op_ = CHECK_JUST(one::OpBuilder("imag").Input("x").Output("out").Build()); }

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
  // Unlike real()/conj(), imag() is only defined for complex tensors.
  // Fix: the original message glued "gets" to the dtype name ("but
  // getscomplex64"); add the separator and use past tense.
  CHECK_OR_RETURN(x->dtype()->is_complex())
      << "RuntimeError: imag is implemented for tensors with complex dtypes, but got "
      << x->dtype()->name();
  return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
}

Expand Down Expand Up @@ -5532,6 +5536,7 @@ class ConjFunctor {
}

Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
  // Conjugating a real-dtype tensor is a no-op; dispatch only for complex.
  if (x->dtype()->is_complex()) { return OpInterpUtil::Dispatch<Tensor>(*op_, {x}); }
  return x;
}

Expand Down
Loading

0 comments on commit e4d79e8

Please sign in to comment.