From fb423fa80073e054cb8f5c33606457f5da62c0a2 Mon Sep 17 00:00:00 2001
From: Lu Qi <61354321+MarioLulab@users.noreply.github.com>
Date: Mon, 26 Jun 2023 19:44:23 +0800
Subject: [PATCH] Applying autotest module on existing complex operators (#10284)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Original requirements

**Autotest**: Our previous tests of operators that support complex tensors were incomplete, so we decided to reuse the real-tensor operator tests to ensure completeness. Complex tensor tests have been supported in the `autotest` module since PR https://github.com/Oneflow-Inc/oneflow/pull/10027, so in this PR we apply the autotest module to the tests of the complex tensor operators already available in OneFlow.

**Fix**: In addition, the autograd rules of some existing complex-number operators did not conform to the ["conjugate Wirtinger derivative"](https://en.wikipedia.org/wiki/Wirtinger_derivatives) convention. This PR fixes those bugs at the same time.

#### Main Works

**Applying the `autotest` module to existing operators that already support complex tensors:**

`Complex and Real Behave the Same Way`: no conjugate operation needs to be added in the op's grad, because whether the input data type is real or complex, the gradient given by the Wirtinger derivative is identical to the one given by the real derivative rule.

`Grad Not Supported in OF`: the grad of this op is not supported in OneFlow.
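For reference, a minimal PyTorch sketch of the convention (PyTorch is what the autotest module compares against): under the conjugate Wirtinger convention, the backward of `out = x * y` multiplies the upstream grad by the *conjugate* of the other operand. This is exactly why `BroadcastMul` in the diff below needs `functional::Conj`.

```python
import torch

# Conjugate Wirtinger convention: for out = x * y, the backward returns
# grad_x = grad_out * conj(y), not grad_out * y.
x = torch.tensor(1.0 + 2.0j, requires_grad=True)
y = torch.tensor(3.0 - 1.0j, requires_grad=True)
out = x * y
out.backward(torch.tensor(1.0 + 0.0j))
print(x.grad)  # tensor(3.+1.j) == conj(y)
print(y.grad)  # tensor(1.-2.j) == conj(x)
```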
- broadcast_elementwise_binary

| Op | complex type | Backend | Using autotest | conjugate Wirtinger derivative |
|:-----:|:-------------:|:------:|:------:|:------:|
| Add | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the Same Way |
| Mul | cp64, cp128 | CPU, CUDA | DONE | DONE |
| Sub | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the Same Way |
| Equal | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the Same Way |
| NotEqual | cp64, cp128 | CPU, CUDA | DONE | Grad Not Supported in OF |

- broadcast_elementwise_unary

| Op | complex type | Backend | Using autotest | conjugate Wirtinger derivative |
|:----------:|:-------------:|:------:|:------:|:------:|
| Cast | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the Same Way |

- other existing operations

| Op | complex type | Backend | Using autotest | conjugate Wirtinger derivative |
|:----------:|:-------------:|:------:|:------:|:------:|
| constant_pad | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the Same Way |
| reduce_sum | cp64, cp128 | CPU, CUDA | TO-DO | Complex and Real Behave the Same Way |

## Note

Complex-number infrastructure series of PRs:

1. Implement common complex operators such as conj and real with primitives: https://github.com/Oneflow-Inc/oneflow/pull/10281
2. Migrate the test cases of operators that already support complex dtypes to the autotest module, so that complex operators reuse the real-operator test cases: https://github.com/Oneflow-Inc/oneflow/pull/10284
3. Continue extending the set of operators that support complex dtypes, e.g. matmul, sqrt, div: https://github.com/Oneflow-Inc/oneflow/pull/10269

Dependencies: this PR is based on [pr1](https://github.com/Oneflow-Inc/oneflow/pull/10281); merge [pr1](https://github.com/Oneflow-Inc/oneflow/pull/10281) first, then merge this PR.
---
 .../gradient_funcs/broadcast_binary_ops.cpp   |  4 +--
 oneflow/core/framework/dtype.cpp              | 32 ++++++++---------
 oneflow/core/functional/impl/math_functor.cpp |  7 ++--
 oneflow/core/functional/tensor_processor.cpp  |  9 +++--
 oneflow/user/kernels/reduce_like_kernels.cpp  | 11 ++++++
 python/oneflow/test/modules/test_add.py       | 16 ++++-----
 python/oneflow/test/modules/test_cast.py      |  8 +++++
 .../oneflow/test/modules/test_constant_pad.py |  2 +-
 python/oneflow/test/modules/test_equal.py     | 36 ++++++++++++++++++-
 python/oneflow/test/modules/test_mul.py       |  5 +--
 python/oneflow/test/modules/test_sub.py       |  5 +--
 python/oneflow/test/modules/test_sum.py       | 23 +++++++++++-
 python/oneflow/test/tensor/test_complex.py    | 18 +++++++---
 .../automated_test_util/generators.py         |  6 ++++
 python/oneflow/test_utils/test_util.py        |  4 +++
 15 files changed, 144 insertions(+), 42 deletions(-)

diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
index 0b3c367b5a5..d278728cc37 100644
--- a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
+++ b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
@@ -142,7 +142,7 @@ class BroadcastMul : public BroadcastBinaryGrad {
     in_grads->resize(2);
     if (ctx->x_requires_grad) {
       const auto& y = ctx->SavedTensors().at(ctx->y_index);
-      const auto& x_grad = JUST(functional::Mul(out_grads.at(0), y));
+      const auto& x_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(y))));
       if (ctx->broadcast_x) {
         const auto& x = ctx->SavedTensors().at(ctx->x_index);
         in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
@@ -152,7 +152,7 @@
     }
     if (ctx->y_requires_grad) {
       const auto& x = ctx->SavedTensors().at(ctx->x_index);
-      const auto& y_grad = JUST(functional::Mul(out_grads.at(0), x));
+      const auto& y_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(x))));
       if (ctx->broadcast_y) {
         const auto& y = ctx->SavedTensors().at(ctx->y_index);
         in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y));
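A numpy illustration (not OneFlow code) of what the fixed backward computes when the multiply broadcasts: the grad of the broadcast input is `grad_out * conj(other)` reduce-summed back to that input's shape, which is what `BroadcastReduceSumLike` does after the fixed `Mul` grad.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((2, 3)) + 1j * rng.random((2, 3))  # full-shape input
y = rng.random((1, 3)) + 1j * rng.random((1, 3))  # broadcast over dim 0
grad_out = np.ones((2, 3), dtype=np.complex128)   # upstream grad of x * y

# Conjugate Wirtinger rule, then reduce back to y's shape:
grad_y = (grad_out * np.conj(x)).sum(axis=0, keepdims=True)
assert grad_y.shape == y.shape
```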
diff --git a/oneflow/core/framework/dtype.cpp b/oneflow/core/framework/dtype.cpp
index 33169d013c5..cd0aca6b751 100644
--- a/oneflow/core/framework/dtype.cpp
+++ b/oneflow/core/framework/dtype.cpp
@@ -226,24 +226,24 @@ Symbol<DType> promoteTypes(const Symbol<DType> a, const Symbol<DType> b) {
   static const Symbol<DType> _promoteTypesLookup[DataType_ARRAYSIZE][DataType_ARRAYSIZE] = {
       /* iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf b1 u2 u4 u8 u16 i2 i16 cp4 cp8 cp16 */
       /* iv */ {iv, c1, f4, f8, i1, i4, i8, u1, re, f2, bu, bf, b1, u2, u4, u8, u16, i2, i16, cp4, cp8, cp16},
-      /* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, iv, f2, iv, bf, c1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16},
-      /* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, iv, f4, iv, bf, f4, f4, f4, f4, f4, f4, f4, iv, cp4, cp16},
-      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, iv, f8, iv, bf, f8, f8, f8, f8, f8, f8, f8, iv, cp4, cp16},
-      /* i1 */ {i1, i1, f4, f8, i1, i4, i8, i2, iv, f2, iv, bf, i1, i4, i8, i16, iv, i2, i16, iv, cp4, cp16},
-      /* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, iv, f2, iv, bf, i4, i4, i8, i16, iv, i4, i16, iv, cp4, cp16},
-      /* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, iv, f2, iv, bf, i8, i8, i8, i16, iv, i8, i16, iv, cp4, cp16},
-      /* u1 */ {u1, c1, f4, f8, i2, i4, i8, u1, iv, f2, iv, bf, u1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16},
+      /* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, iv, f2, iv, bf, c1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16},
+      /* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, iv, f4, iv, bf, f4, f4, f4, f4, f4, f4, f4, iv, cp8, cp16},
+      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, iv, f8, iv, bf, f8, f8, f8, f8, f8, f8, f8, iv, cp8, cp16},
+      /* i1 */ {i1, i1, f4, f8, i1, i4, i8, i2, iv, f2, iv, bf, i1, i4, i8, i16, iv, i2, i16, iv, cp8, cp16},
+      /* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, iv, f2, iv, bf, i4, i4, i8, i16, iv, i4, i16, iv, cp8, cp16},
+      /* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, iv, f2, iv, bf, i8, i8, i8, i16, iv, i8, i16, iv, cp8, cp16},
+      /* u1 */ {u1, c1, f4, f8, i2, i4, i8, u1, iv, f2, iv, bf, u1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16},
       /* re */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv},
-      /* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, iv, f2, iv, bf, f2, f2, f2, f2, iv, f2, f2, iv, cp4, cp16},
+      /* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, iv, f2, iv, bf, f2, f2, f2, f2, iv, f2, f2, iv, cp8, cp16},
       /* bu */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, bu, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv},
-      /* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, iv, bf, iv, bf, bf, bf, bf, bf, iv, bf, bf, iv, cp4, cp16},
-      /* b1 */ {b1, c1, f4, f8, i1, i4, i8, u1, iv, f2, iv, bf, b1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16},
-      /* u2 */ {u2, u2, f4, f8, i4, i4, i8, u2, iv, f2, iv, bf, u2, u2, u4, u8, u16, i4, i16, iv, cp4, cp16},
-      /* u4 */ {u4, u4, f4, f8, i8, i8, i8, u4, iv, f2, iv, bf, u4, u4, u4, u8, u16, i8, i16, iv, cp4, cp16},
-      /* u8 */ {u8, u8, f4, f8, i16, i16, i16, u8, iv, f2, iv, bf, u8, u8, u8, u8, u16, i16, i16, iv, cp4, cp16},
-      /* u16 */ {u16, u16, f4, f8, iv, iv, iv, u16, iv, f2, iv, bf, u16, u16, u16, u16, u16, iv, iv, iv, cp4, cp16},
-      /* i2 */ {i2, i2, f4, f8, i2, i4, i8, i2, iv, f2, iv, bf, i2, i4, i8, i16, iv, i2, i16, iv, cp4, cp16},
-      /* i16 */ {i16, i16, f4, f8, i16, i16, i16, i16, iv, f2, iv, bf, i16, i16, i16, i16, iv, i16, i16, iv, cp4, cp16},
+      /* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, iv, bf, iv, bf, bf, bf, bf, bf, iv, bf, bf, iv, cp8, cp16},
+      /* b1 */ {b1, c1, f4, f8, i1, i4, i8, u1, iv, f2, iv, bf, b1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16},
+      /* u2 */ {u2, u2, f4, f8, i4, i4, i8, u2, iv, f2, iv, bf, u2, u2, u4, u8, u16, i4, i16, iv, cp8, cp16},
+      /* u4 */ {u4, u4, f4, f8, i8, i8, i8, u4, iv, f2, iv, bf, u4, u4, u4, u8, u16, i8, i16, iv, cp8, cp16},
+      /* u8 */ {u8, u8, f4, f8, i16, i16, i16, u8, iv, f2, iv, bf, u8, u8, u8, u8, u16, i16, i16, iv, cp8, cp16},
+      /* u16 */ {u16, u16, f4, f8, iv, iv, iv, u16, iv, f2, iv, bf, u16, u16, u16, u16, u16, iv, iv, iv, cp8, cp16},
+      /* i2 */ {i2, i2, f4, f8, i2, i4, i8, i2, iv, f2, iv, bf, i2, i4, i8, i16, iv, i2, i16, iv, cp8, cp16},
+      /* i16 */ {i16, i16, f4, f8, i16, i16, i16, i16, iv, f2, iv, bf, i16, i16, i16, i16, iv, i16, i16, iv, cp8, cp16},
       /* cp4 */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, cp4, cp8, cp16},
       /* cp8 */ {cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, iv, cp8, iv, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp16},
       /* cp16 */ {cp16,cp16,cp16,cp16,cp16,cp16,cp16,cp16,iv, cp16,iv, cp16,cp16,cp16,cp16,cp16,cp16, cp16,cp16, cp16, cp16, cp16}};
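The `cp8` column fix above makes promotion against complex64 yield complex64 rather than complex32 (reading `cp4`/`cp8`/`cp16` as complex32/complex64/complex128 by byte width). For the float32 case, PyTorch's value promotion, which autotest compares against, behaves the same way:

```python
import torch

# float32 combined with complex64 promotes to complex64, not complex32:
assert torch.promote_types(torch.float32, torch.complex64) == torch.complex64
a = torch.ones(2, dtype=torch.float32)
b = torch.ones(2, dtype=torch.complex64)
print((a + b).dtype)  # torch.complex64
```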
diff --git a/oneflow/core/functional/impl/math_functor.cpp b/oneflow/core/functional/impl/math_functor.cpp
index 291a68613e9..275757fbb27 100644
--- a/oneflow/core/functional/impl/math_functor.cpp
+++ b/oneflow/core/functional/impl/math_functor.cpp
@@ -85,7 +85,7 @@ class ScalarMathBaseFunctor {
                                  "int_operand", "has_int_operand");
     TensorProcessor tensor_processor;
     Symbol<DType> lowest_dtype;
-    if (scalar.IsFloatingPoint()) {
+    if (scalar.IsFloatingPoint() || scalar.IsComplex()) {
       attrs.SetAllAttrs(scalar.As<double>(), true, NullOpt, false);
       // Only promote type to Float32 when tensor is Int type but scalar is float type.
       if (DType::priority_order[x->dtype()->data_type()]
@@ -797,8 +797,9 @@ class ReduceMeanWholeFunctor {
   ReduceMeanWholeFunctor() {}
   Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
     // ReduceMean only calculate floating values.
-    CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type()))
-        << "RuntimeError: Can only calculate the mean of floating types.";
+    CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type())
+                    || IsComplexDataType(x->dtype()->data_type()))
+        << "RuntimeError: Can only calculate the mean of floating types or complex types.";
     size_t reduce_count = 1;
     reduce_count = x->shape()->Count(0);
     const auto& sum = JUST(functional::ReduceSumWhole(x, NullOpt));
diff --git a/oneflow/core/functional/tensor_processor.cpp b/oneflow/core/functional/tensor_processor.cpp
index e242003a40f..b13ebc57507 100644
--- a/oneflow/core/functional/tensor_processor.cpp
+++ b/oneflow/core/functional/tensor_processor.cpp
@@ -35,7 +35,10 @@ Symbol<DType> ComputeCommonDType(const TensorTuple& tensor_tuple) {
       [](const std::shared_ptr<Tensor>& tensor) { return tensor->shape()->NumAxes() == 0; });
   for (auto& tensor_ptr : tensor_tuple) {
     // skip scalar tensor
-    if (!all_scalar_tensors && tensor_ptr->shape()->NumAxes() == 0) { continue; }
+    if (!all_scalar_tensors && tensor_ptr->shape()->NumAxes() == 0
+        && !(tensor_ptr->dtype()->is_complex())) {
+      continue;
+    }
     common_dtype = promoteTypes(tensor_ptr->dtype(), common_dtype);
   }
   return common_dtype;
@@ -114,7 +117,9 @@ Maybe<void> TensorProcessor::Apply() {
     }
     JUST(CastToSameType(tensor_tuple_, common_dtype_));
   } else {
-    if (tensor_tuple_.size() == 1 && !tensor_tuple_[0]->dtype()->is_floating_point()) {
+    if (tensor_tuple_.size() == 1
+        && !((tensor_tuple_[0]->dtype()->is_floating_point())
+             || tensor_tuple_[0]->dtype()->is_complex())) {
       Symbol<DType> cast_dtype = (inputs_lowest_dtype_vec_[0] == DType::InvalidDataType())
                                      ? DType::Float()
                                      : inputs_lowest_dtype_vec_[0];
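For context on the `ReduceMeanWholeFunctor` change above: the mean of a complex tensor is well defined as the element sum divided by the element count, matching the PyTorch behavior autotest compares against:

```python
import torch

x = torch.tensor([1 + 2j, 3 - 4j])
print(x.mean())  # tensor(2.-1.j): complex mean = element sum / element count
```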
diff --git a/oneflow/user/kernels/reduce_like_kernels.cpp b/oneflow/user/kernels/reduce_like_kernels.cpp
index bb202a8fa96..b16cb2c3b73 100644
--- a/oneflow/user/kernels/reduce_like_kernels.cpp
+++ b/oneflow/user/kernels/reduce_like_kernels.cpp
@@ -124,6 +124,17 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel, public user_op::Cu
 OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, DEVICE_TYPE_SEQ,
                                  ARITHMETIC_DATA_TYPE_SEQ)
 
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), COMPLEX_DATA_TYPE_SEQ);
+#if defined(WITH_CUDA)
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),
+                                 OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64));
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,
+                                 OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),
+                                 OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128));
+#endif  // WITH_CUDA
+
 #if defined(WITH_CUDA)
 
 namespace {
diff --git a/python/oneflow/test/modules/test_add.py b/python/oneflow/test/modules/test_add.py
index 931149556dd..3788f893489 100644
--- a/python/oneflow/test/modules/test_add.py
+++ b/python/oneflow/test/modules/test_add.py
@@ -16,7 +16,7 @@
 import unittest
 from collections import OrderedDict
-
+import torch as torch_original
 import numpy as np
 
 from oneflow.test_utils.test_util import GenArgList
@@ -190,7 +190,7 @@ def test_add(test_case):
         for arg in GenArgList(arg_dict):
             arg[0](test_case, *arg[1:])
 
-    @autotest(n=5)
+    @autotest(n=10)
     def test_0_size_add(test_case):
         device = random_device()
         x = random_tensor(2, 0, 3).to(device)
@@ -198,7 +198,7 @@ def test_0_size_add(test_case):
         out = x + y
         return out
 
-    @autotest(n=3, auto_backward=False)
+    @autotest(n=6, auto_backward=False)
     def test_0dim_inplace_add(test_case):
         device = random_device()
         x = random_tensor(2, 2, 3, requires_grad=False).to(device)
@@ -206,7 +206,7 @@ def test_0dim_inplace_add(test_case):
         x += y.mean()
         return x
 
-    @autotest(n=5)
+    @autotest(n=10)
     def test_0dim_two_inplace_add(test_case):
         device = random_device()
         x = random_tensor(2, 2, 3).to(device).mean()
@@ -214,7 +214,7 @@ def test_0dim_two_inplace_add(test_case):
         x += y.mean()
         return x
 
-    @autotest(n=3)
+    @autotest(n=6)
     def test_add_with_alpha(test_case):
         device = random_device()
         x1 = random_tensor(2, 2, 3).to(device).mean()
@@ -260,7 +260,7 @@ def test_0dim_two_inplace_add(test_case):
         return x
         x += y.mean().to(torch.bool)
 
-    @autotest(n=3)
+    @autotest(n=6)
     def test_add_with_alpha_0dim(test_case):
         device = random_device()
         x1 = random_tensor(ndim=0).to(device).mean()
@@ -279,7 +279,7 @@ def profile_add(test_case):
         torch.add(torch.ones(100), 20)
         torch.add(torch.ones(100), torch.ones(100, 1), alpha=10)
 
-    @autotest(n=3)
+    @autotest(n=6)
     def test_non_contiguous_inplace_add(test_case):
         device = random_device()
         x = random_tensor(2, 2, 4).to(device)
@@ -288,7 +288,7 @@ def test_non_contiguous_inplace_add(test_case):
         y += random_tensor(2, 2, 2).to(device)
         return y
 
-    @autotest(n=5)
+    @autotest(n=10)
     def test_scalar_add_with_random_devices(test_case):
         x1_device = random_device()
         x2_device = random_device()
diff --git a/python/oneflow/test/modules/test_cast.py b/python/oneflow/test/modules/test_cast.py
index 48c1134ef1f..0023d482c2f 100644
--- a/python/oneflow/test/modules/test_cast.py
+++ b/python/oneflow/test/modules/test_cast.py
@@ -19,6 +19,7 @@
 from collections import OrderedDict
 
 import numpy as np
+import torch as torch_original
 
 import oneflow as flow
 import oneflow.unittest
@@ -174,6 +175,13 @@ def test_cast_with_scalar_input(test_case):
         z = y.to(dtype=torch.int8, device=device)
         return z
 
+    @autotest(n=5, auto_backward=True, include_complex=False, atol=1e-5, rtol=1e-5)
+    def test_cast_with_complex_float2complex(test_case):
+        device = random_device()
+        x = random_tensor().to(dtype=torch.float32, device=device)
+        y = x.to(torch.complex64)
+        return y
+
 
 if __name__ == "__main__":
     unittest.main()
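The new `test_cast_with_complex_float2complex` exercises the float-to-complex cast. A PyTorch sketch of the semantics it is checked against (my reading, not OneFlow code): the forward fills a zero imaginary part, and the backward passes only the real part of the upstream grad to the float input.

```python
import torch

x = torch.randn(3, requires_grad=True)
y = x.to(torch.complex64)          # imaginary part is zero
loss = (y * (1 + 1j)).sum().real   # a real-valued loss
loss.backward()
print(x.grad)  # tensor([1., 1., 1.]): only the real part of the
               # upstream complex grad reaches the float input
```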
diff --git a/python/oneflow/test/modules/test_constant_pad.py b/python/oneflow/test/modules/test_constant_pad.py
index d713196c7d7..593ea1dd090 100644
--- a/python/oneflow/test/modules/test_constant_pad.py
+++ b/python/oneflow/test/modules/test_constant_pad.py
@@ -112,7 +112,7 @@ def test_constantpad3d_with_random_data(test_case):
         return y
 
     @autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False)
-    def test_constantpad3d_with_random_data(test_case):
+    def test_constantpad3d_with_random_int_data(test_case):
         dtype = choice([bool, int])
         value = random(0, 2).to(bool) if dtype is bool else random().to(int)
         m = torch.nn.ConstantPad3d(padding=random(1, 6).to(_size_6_t), value=value,)
diff --git a/python/oneflow/test/modules/test_equal.py b/python/oneflow/test/modules/test_equal.py
index c5e6ee0c42d..b6daf7dc5b1 100644
--- a/python/oneflow/test/modules/test_equal.py
+++ b/python/oneflow/test/modules/test_equal.py
@@ -18,6 +18,7 @@
 from collections import OrderedDict
 
 import numpy as np
+import torch as torch_original
 
 from oneflow.test_utils.test_util import GenArgList
 import oneflow as flow
@@ -28,7 +29,7 @@
 
 @flow.unittest.skip_unless_1n1d()
 class TestEqual(flow.unittest.TestCase):
-    @autotest(n=5, auto_backward=False, check_graph=False)
+    @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)
     def test_eq_with_0_size_data(test_case):
         device = random_device()
         x = random_tensor(3, 2, 0, 3).to(device)
@@ -75,6 +76,15 @@ def test_flow_equal_with_same_random_data(test_case):
         x = random_tensor(len(shape), *shape, requires_grad=False).to(device)
         return torch.equal(x, x)
 
+    @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)
+    def test_flow_equal_complex_with_same_random_data(test_case):
+        device = random_device()
+        shape = random_tensor().oneflow.shape
+        x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
+            device
+        )
+        return torch.equal(x, x)
+
     @autotest(n=5, auto_backward=False, check_graph=False)
     def test_flow_equal_bool_with_random_data(test_case):
         device = random_device()
@@ -87,6 +97,30 @@ def test_flow_equal_bool_with_random_data(test_case):
         )
         return torch.equal(x, y)
 
+    @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)
+    def test_flow_equal_complex_with_random_data(test_case):
+        device = random_device()
+        shape = random_tensor().oneflow.shape
+        x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
+            device=device
+        )
+        y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
+            device=device
+        )
+        return torch.equal(x, y)
+
+    @autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)
+    def test_flow_not_equal_complex_with_random_data(test_case):
+        device = random_device()
+        shape = random_tensor().oneflow.shape
+        x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
+            device=device
+        )
+        y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
+            device=device
+        )
+        return torch.not_equal(x, y)
+
     @autotest(n=5, auto_backward=False, check_graph=False)
     def test_flow_equal_with_same_random_0d_data(test_case):
         device = random_device()
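The complex `equal`/`not_equal` tests above rely on elementwise comparison of both the real and the imaginary parts; the PyTorch reference behavior:

```python
import torch

a = torch.tensor([1 + 1j, 2 + 0j])
b = torch.tensor([1 + 1j, 2 + 3j])
print(torch.eq(a, b))         # tensor([ True, False]): real AND imag must match
print(torch.not_equal(a, b))  # tensor([False,  True])
print(torch.equal(a, a))      # True (same shape, same values)
```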
diff --git a/python/oneflow/test/modules/test_mul.py b/python/oneflow/test/modules/test_mul.py
index 3aa65fd5698..3dcaee390e3 100644
--- a/python/oneflow/test/modules/test_mul.py
+++ b/python/oneflow/test/modules/test_mul.py
@@ -18,6 +18,7 @@
 from collections import OrderedDict
 
 import numpy as np
+import torch as torch_original
 
 from oneflow.test_utils.automated_test_util import *
 from oneflow.test_utils.test_util import GenArgList
@@ -208,7 +209,7 @@ def test_broadcast_mul(test_case):
         x.mul_(y)
         return x
 
-    @autotest(n=3)
+    @autotest(n=6)
    def test_non_contiguous_inplace_mul(test_case):
         device = random_device()
         x = random_tensor(2, 2, 4).to(device)
@@ -217,7 +218,7 @@ def test_non_contiguous_inplace_mul(test_case):
         y *= random_tensor(2, 2, 2).to(device)
         return y
 
-    @autotest(n=5)
+    @autotest(n=10)
     def test_scalar_mul_with_random_devices(test_case):
         x1_device = random_device()
         x2_device = random_device()
diff --git a/python/oneflow/test/modules/test_sub.py b/python/oneflow/test/modules/test_sub.py
index 3c70ce725c7..2650a162778 100644
--- a/python/oneflow/test/modules/test_sub.py
+++ b/python/oneflow/test/modules/test_sub.py
@@ -18,6 +18,7 @@
 from collections import OrderedDict
 
 import numpy as np
+import torch as torch_original
 
 from oneflow.test_utils.automated_test_util import *
 from oneflow.test_utils.test_util import GenArgList
@@ -156,7 +157,7 @@ def test_sub_with_alpha(test_case):
         z3 = torch.sub(s, x3, alpha=alpha)
         return z1, z2, z3
 
-    @autotest(n=3)
+    @autotest(n=5)
     def test_non_contiguous_inplace_sub(test_case):
         device = random_device()
         x = random_tensor(2, 2, 4).to(device)
@@ -166,7 +167,7 @@ def test_non_contiguous_inplace_sub(test_case):
         return y
 
     @unittest.skip("skip for now, becase it failed 2 times in past week")
-    @autotest(n=5)
+    @autotest(n=5, include_complex=True)
     def test_scalar_sub_with_random_devices(test_case):
         x1_device = random_device()
         x2_device = random_device()
diff --git a/python/oneflow/test/modules/test_sum.py b/python/oneflow/test/modules/test_sum.py
index 3d4cf278c33..975dfda75f7 100644
--- a/python/oneflow/test/modules/test_sum.py
+++ b/python/oneflow/test/modules/test_sum.py
@@ -97,8 +97,29 @@ def test_sum_dtype(test_case):
         )
         return y
 
+    @autotest(
+        n=10,
+        check_graph=False,
+        auto_backward=True,
+        include_complex=True,
+        atol=1e-2,
+        rtol=1e-5,
+    )
+    def test_sum_complex_dtype(test_case):
+        device = random_device()
+        x = random_tensor(4, dtype=complex, requires_grad=True).to(
+            device=device, dtype=random_dtype(["complex"])
+        )
+        y = torch.sum(
+            x,
+            dim=np.random.randint(0, 3),
+            keepdim=random_bool(),
+            dtype=random_dtype(["complex"]),
+        )
+        return y
+
     @autotest(check_graph=True, auto_backward=False)
-    def test_sum_whole_dtype(test_case):
+    def test_sum_arithmetic_dtype(test_case):
         device = random_device()
         x = random_tensor(4, requires_grad=False).to(device)
         y = torch.sum(x, dtype=random_dtype(["arithmetic"]))
diff --git a/python/oneflow/test/tensor/test_complex.py b/python/oneflow/test/tensor/test_complex.py
index b5f3b29eaa4..6ea6401af62 100644
--- a/python/oneflow/test/tensor/test_complex.py
+++ b/python/oneflow/test/tensor/test_complex.py
@@ -523,8 +523,12 @@ def test_mul_cpu(self):
 
         # backward
         flow_ret.sum().backward()
-        compare_result(flow_x.grad.numpy(), flow_y.numpy(), self.rtol, self.atol)
-        compare_result(flow_y.grad.numpy(), flow_x.numpy(), self.rtol, self.atol)
+        compare_result(
+            flow_x.grad.numpy(), flow_y.numpy().conjugate(), self.rtol, self.atol
+        )
+        compare_result(
+            flow_y.grad.numpy(), flow_x.numpy().conjugate(), self.rtol, self.atol
+        )
 
     @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
     def test_mul_cuda(self):
@@ -549,10 +553,16 @@ def test_mul_cuda(self):
         # backward
         flow_ret.sum().backward()
         compare_result(
-            flow_x.grad.cpu().detach().numpy(), flow_y.numpy(), self.rtol, self.atol
+            flow_x.grad.cpu().detach().numpy(),
+            flow_y.numpy().conjugate(),
+            self.rtol,
+            self.atol,
         )
         compare_result(
-            flow_y.grad.cpu().detach().numpy(), flow_x.numpy(), self.rtol, self.atol
+            flow_y.grad.cpu().detach().numpy(),
+            flow_x.numpy().conjugate(),
+            self.rtol,
+            self.atol,
         )
 
     def test_sum_cpu(self):
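`test_sum_complex_dtype` combines `dim`/`keepdim` with a `dtype` cast; the PyTorch behavior it is checked against looks like:

```python
import torch

x = torch.rand(2, 3, dtype=torch.complex64)
y = torch.sum(x, dim=1, keepdim=True, dtype=torch.complex128)
print(y.dtype, y.shape)  # torch.complex128 torch.Size([2, 1])
```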
diff --git a/python/oneflow/test_utils/automated_test_util/generators.py b/python/oneflow/test_utils/automated_test_util/generators.py
index 24c22159a6e..db90e09622e 100644
--- a/python/oneflow/test_utils/automated_test_util/generators.py
+++ b/python/oneflow/test_utils/automated_test_util/generators.py
@@ -261,6 +261,10 @@ def _generate(self, annotation):
             val = float(rng.random() * (high - low) + low)
         elif annotation == bool:
             val = random_util.choice([True, False])
+        elif annotation == complex:
+            val_real = float(rng.random() * (high - low) + low)
+            val_imag = float(rng.random() * (high - low) + low)
+            val = val_real + 1.0j * val_imag
         elif annotation is None:
             val = None
         elif annotation is NoneType:
@@ -425,6 +429,7 @@ class random_pytorch_dtype(generator):
     floating_dtype_seq = [torch.float, torch.double]
     half_dtype_seq = [torch.half]
     bfloat16_dtype_seq = [torch.bfloat16]
+    complex_dtype_seq = [torch.complex64, torch.complex128]
     signed_int_dtype_seq = [torch.int8, torch.int32, torch.int64]
     unsigned_int_dtype_seq = [torch.uint8]
     int_dtype_seq = [torch.int8, torch.int32, torch.int64]
@@ -440,6 +445,7 @@
         "float": floating_dtype_seq,
         "half": half_dtype_seq,
         "bfloat16": bfloat16_dtype_seq,
+        "complex": complex_dtype_seq,
         "signed": signed_int_dtype_seq,
         "unsigned": unsigned_int_dtype_seq,
         "int": int_dtype_seq,
diff --git a/python/oneflow/test_utils/test_util.py b/python/oneflow/test_utils/test_util.py
index 8474517b737..a817f42c910 100644
--- a/python/oneflow/test_utils/test_util.py
+++ b/python/oneflow/test_utils/test_util.py
@@ -73,6 +73,8 @@ def __repr__(self):
     "uint8": flow.uint8,
     "half": flow.half,
    "bfloat16": flow.bfloat16,
+    "complex64": flow.complex64,
+    "complex128": flow.complex128,
 }
 type_name_to_np_type = {
     "float16": np.float16,
@@ -82,6 +84,8 @@ def __repr__(self):
     "int32": np.int32,
     "int64": np.int64,
     "uint8": np.uint8,
+    "complex64": np.complex64,
+    "complex128": np.complex128,
 }
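A standalone sketch of the new `annotation == complex` branch in `_generate` (the `low`/`high` bounds here are hypothetical; the real generator takes them from the test spec): the real and imaginary parts are drawn independently from the same range.

```python
import numpy as np

rng = np.random.default_rng(0)
low, high = -2.0, 2.0  # hypothetical bounds for illustration

# Mirror of the generator branch: draw real and imaginary parts
# independently from [low, high) and combine them into a Python complex.
val = float(rng.random() * (high - low) + low) + 1.0j * float(
    rng.random() * (high - low) + low
)
print(val, type(val))  # e.g. (0.55...+0.09...j) <class 'complex'>
```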