Skip to content

Commit

Permalink
Applying autotest module on existing complex operators (#10284)
Browse files Browse the repository at this point in the history
### Original requirements
**Autotest**: We found that the previous testing of operators supporting
complex tensor was not complete. We decided to reuse the real tensor
operator tests to ensure completeness. Since complex tensor tests are
supported in the `autotest` module from pr
(#10027) , in this pr we
applied the autotest module to the tests of complex tensor operators
already available in Oneflow

**Fix**: In addition, the autograd rules for some previous operators of
complex numbers might not conform to the convention of ["Conjugate
Wirtinger
Derivative"](https://en.wikipedia.org/wiki/Wirtinger_derivatives). We
have fixed these bugs in this pr at the same time.

#### Main Works
**Applying `autotest` module on existing operators that have already
support complex tensor:**

`Complex and Real Behave the Same Way`: means we don't need to add
conjugate operation in op grad. Because regardless of whether the input
data type involved in the operation is real or complex, the gradient
result using the winterger derivative is the same as the real derivative
rule,

`Grad Not Supported in OF`: means the grad of this op is not supported
in oneflow

- broadcast_elementwise_binary
| Op | complex type | Backend | Using autotest | conjugate Wirtinger
derivative |
  |:-----:|:-------------:|:------:|:------:|:------:|
| Add | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the
Same Way |
  | Mul |  cp64, cp128   |  CPU, CUDA | DONE | DONE |
| Sub | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the
Same Way |
| Equal | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the
Same Way |
| NotEqual | cp64, cp128 | CPU, CUDA | DONE | Grad Not Supported in OF |


- broadcast_elementwise_unary
| Op | complex type | Backend | Using autotest | conjugate Wirtinger
derivative |
  |:----------:|:-------------:|:------:|:------:|:------:|
| Cast | cp64, cp128 | CPU, CUDA | DONE | Complex and Real Behave the
Same Way |


- other exisiting operations
| Op | complex type | Backend | Using autotest | conjugate Wirtinger
derivative |
  |:----------:|:-------------:|:------:|:------:|:------:|
| constant_pad | cp64, cp128 | CPU, CUDA | Done | Complex and Real
Behave the Same Way |
| reduce_sum | cp64, cp128 | CPU, CUDA | TO-DO | Complex and Real Behave
the Same Way |


## 注意:
复数基础设施建设系列 pr:
1. 使用 primitive 来实现 conj, real 等常见复数算子:
#10281
2. 将现有支持复数数据类型的算子测例迁移到 autotest
模块中,以使复数算子复用实数算子的测试用例:#10284
3. 继续拓展支持复数数据类型的算子,比如 matmul, sqrt, div
等:#10269
依赖关系:
本 pr 基于:[pr1](https://github.com/Oneflow-Inc/oneflow/pull/10281),需要在
merge [pr1](#10281) 后,再 Merge
本 pr
  • Loading branch information
MarioLulab authored Jun 26, 2023
1 parent 6f5e0f6 commit fb423fa
Show file tree
Hide file tree
Showing 15 changed files with 144 additions and 42 deletions.
4 changes: 2 additions & 2 deletions oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class BroadcastMul : public BroadcastBinaryGrad {
in_grads->resize(2);
if (ctx->x_requires_grad) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
const auto& x_grad = JUST(functional::Mul(out_grads.at(0), y));
const auto& x_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(y))));
if (ctx->broadcast_x) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
in_grads->at(0) = JUST(functional::BroadcastReduceSumLike(x_grad, x));
Expand All @@ -152,7 +152,7 @@ class BroadcastMul : public BroadcastBinaryGrad {
}
if (ctx->y_requires_grad) {
const auto& x = ctx->SavedTensors().at(ctx->x_index);
const auto& y_grad = JUST(functional::Mul(out_grads.at(0), x));
const auto& y_grad = JUST(functional::Mul(out_grads.at(0), JUST(functional::Conj(x))));
if (ctx->broadcast_y) {
const auto& y = ctx->SavedTensors().at(ctx->y_index);
in_grads->at(1) = JUST(functional::BroadcastReduceSumLike(y_grad, y));
Expand Down
32 changes: 16 additions & 16 deletions oneflow/core/framework/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,24 +226,24 @@ Symbol<DType> promoteTypes(const Symbol<DType> a, const Symbol<DType> b) {
static const Symbol<DType> _promoteTypesLookup[DataType_ARRAYSIZE][DataType_ARRAYSIZE] = {
/* iv c1 f4 f8 i1 i4 i8 u1 re f2 bu bf b1 u2 u4 u8 u16 i2 i16 cp4 cp8 cp16 */
/* iv */ {iv, c1, f4, f8, i1, i4, i8, u1, re, f2, bu, bf, b1, u2, u4, u8, u16, i2, i16, cp4, cp8, cp16},
/* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, iv, f2, iv, bf, c1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16},
/* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, iv, f4, iv, bf, f4, f4, f4, f4, f4, f4, f4, iv, cp4, cp16},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, iv, f8, iv, bf, f8, f8, f8, f8, f8, f8, f8, iv, cp4, cp16},
/* i1 */ {i1, i1, f4, f8, i1, i4, i8, i2, iv, f2, iv, bf, i1, i4, i8, i16, iv, i2, i16, iv, cp4, cp16},
/* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, iv, f2, iv, bf, i4, i4, i8, i16, iv, i4, i16, iv, cp4, cp16},
/* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, iv, f2, iv, bf, i8, i8, i8, i16, iv, i8, i16, iv, cp4, cp16},
/* u1 */ {u1, c1, f4, f8, i2, i4, i8, u1, iv, f2, iv, bf, u1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16},
/* c1 */ {c1, c1, f4, f8, i1, i4, i8, c1, iv, f2, iv, bf, c1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16},
/* f4 */ {f4, f4, f4, f8, f4, f4, f4, f4, iv, f4, iv, bf, f4, f4, f4, f4, f4, f4, f4, iv, cp8, cp16},
/* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, iv, f8, iv, bf, f8, f8, f8, f8, f8, f8, f8, iv, cp8, cp16},
/* i1 */ {i1, i1, f4, f8, i1, i4, i8, i2, iv, f2, iv, bf, i1, i4, i8, i16, iv, i2, i16, iv, cp8, cp16},
/* i4 */ {i4, i4, f4, f8, i4, i4, i8, i4, iv, f2, iv, bf, i4, i4, i8, i16, iv, i4, i16, iv, cp8, cp16},
/* i8 */ {i8, i8, f4, f8, i8, i8, i8, i8, iv, f2, iv, bf, i8, i8, i8, i16, iv, i8, i16, iv, cp8, cp16},
/* u1 */ {u1, c1, f4, f8, i2, i4, i8, u1, iv, f2, iv, bf, u1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16},
/* re */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv},
/* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, iv, f2, iv, bf, f2, f2, f2, f2, iv, f2, f2, iv, cp4, cp16},
/* f2 */ {f2, f2, f4, f8, f2, f2, f2, f2, iv, f2, iv, bf, f2, f2, f2, f2, iv, f2, f2, iv, cp8, cp16},
/* bu */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, bu, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv},
/* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, iv, bf, iv, bf, bf, bf, bf, bf, iv, bf, bf, iv, cp4, cp16},
/* b1 */ {b1, c1, f4, f8, i1, i4, i8, u1, iv, f2, iv, bf, b1, u2, u4, u8, u16, i2, i16, iv, cp4, cp16},
/* u2 */ {u2, u2, f4, f8, i4, i4, i8, u2, iv, f2, iv, bf, u2, u2, u4, u8, u16, i4, i16, iv, cp4, cp16},
/* u4 */ {u4, u4, f4, f8, i8, i8, i8, u4, iv, f2, iv, bf, u4, u4, u4, u8, u16, i8, i16, iv, cp4, cp16},
/* u8 */ {u8, u8, f4, f8, i16, i16, i16, u8, iv, f2, iv, bf, u8, u8, u8, u8, u16, i16, i16, iv, cp4, cp16},
/* u16 */ {u16, u16, f4, f8, iv, iv, iv, u16, iv, f2, iv, bf, u16, u16, u16, u16, u16, iv, iv, iv, cp4, cp16},
/* i2 */ {i2, i2, f4, f8, i2, i4, i8, i2, iv, f2, iv, bf, i2, i4, i8, i16, iv, i2, i16, iv, cp4, cp16},
/* i16 */ {i16, i16, f4, f8, i16, i16, i16, i16, iv, f2, iv, bf, i16, i16, i16, i16, iv, i16, i16, iv, cp4, cp16},
/* bf */ {bf, bf, bf, bf, bf, bf, bf, bf, iv, bf, iv, bf, bf, bf, bf, bf, iv, bf, bf, iv, cp8, cp16},
/* b1 */ {b1, c1, f4, f8, i1, i4, i8, u1, iv, f2, iv, bf, b1, u2, u4, u8, u16, i2, i16, iv, cp8, cp16},
/* u2 */ {u2, u2, f4, f8, i4, i4, i8, u2, iv, f2, iv, bf, u2, u2, u4, u8, u16, i4, i16, iv, cp8, cp16},
/* u4 */ {u4, u4, f4, f8, i8, i8, i8, u4, iv, f2, iv, bf, u4, u4, u4, u8, u16, i8, i16, iv, cp8, cp16},
/* u8 */ {u8, u8, f4, f8, i16, i16, i16, u8, iv, f2, iv, bf, u8, u8, u8, u8, u16, i16, i16, iv, cp8, cp16},
/* u16 */ {u16, u16, f4, f8, iv, iv, iv, u16, iv, f2, iv, bf, u16, u16, u16, u16, u16, iv, iv, iv, cp8, cp16},
/* i2 */ {i2, i2, f4, f8, i2, i4, i8, i2, iv, f2, iv, bf, i2, i4, i8, i16, iv, i2, i16, iv, cp8, cp16},
/* i16 */ {i16, i16, f4, f8, i16, i16, i16, i16, iv, f2, iv, bf, i16, i16, i16, i16, iv, i16, i16, iv, cp8, cp16},
/* cp4 */ {iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, iv, cp4, cp8, cp16},
/* cp8 */ {cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, iv, cp8, iv, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp8, cp16},
/* cp16 */ {cp16,cp16,cp16,cp16,cp16,cp16,cp16,cp16,iv, cp16,iv, cp16,cp16,cp16,cp16,cp16,cp16, cp16,cp16, cp16, cp16, cp16}};
Expand Down
7 changes: 4 additions & 3 deletions oneflow/core/functional/impl/math_functor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ScalarMathBaseFunctor {
"int_operand", "has_int_operand");
TensorProcessor tensor_processor;
Symbol<DType> lowest_dtype;
if (scalar.IsFloatingPoint()) {
if (scalar.IsFloatingPoint() || scalar.IsComplex()) {
attrs.SetAllAttrs(scalar.As<double>(), true, NullOpt, false);
// Only promote type to Float32 when tensor is Int type but scalar is float type.
if (DType::priority_order[x->dtype()->data_type()]
Expand Down Expand Up @@ -797,8 +797,9 @@ class ReduceMeanWholeFunctor {
ReduceMeanWholeFunctor() {}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x) const {
// ReduceMean only calculate floating values.
CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type()))
<< "RuntimeError: Can only calculate the mean of floating types.";
CHECK_OR_RETURN(IsFloatingDataType(x->dtype()->data_type())
|| IsComplexDataType(x->dtype()->data_type()))
<< "RuntimeError: Can only calculate the mean of floating types or complex types.";
size_t reduce_count = 1;
reduce_count = x->shape()->Count(0);
const auto& sum = JUST(functional::ReduceSumWhole(x, NullOpt));
Expand Down
9 changes: 7 additions & 2 deletions oneflow/core/functional/tensor_processor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ Symbol<DType> ComputeCommonDType(const TensorTuple& tensor_tuple) {
[](const std::shared_ptr<Tensor>& tensor) { return tensor->shape()->NumAxes() == 0; });
for (auto& tensor_ptr : tensor_tuple) {
// skip scalar tensor
if (!all_scalar_tensors && tensor_ptr->shape()->NumAxes() == 0) { continue; }
if (!all_scalar_tensors && tensor_ptr->shape()->NumAxes() == 0
&& !(tensor_ptr->dtype()->is_complex())) {
continue;
}
common_dtype = promoteTypes(tensor_ptr->dtype(), common_dtype);
}
return common_dtype;
Expand Down Expand Up @@ -114,7 +117,9 @@ Maybe<void> TensorProcessor::Apply() {
}
JUST(CastToSameType(tensor_tuple_, common_dtype_));
} else {
if (tensor_tuple_.size() == 1 && !tensor_tuple_[0]->dtype()->is_floating_point()) {
if (tensor_tuple_.size() == 1
&& !((tensor_tuple_[0]->dtype()->is_floating_point())
|| tensor_tuple_[0]->dtype()->is_complex())) {
Symbol<DType> cast_dtype = (inputs_lowest_dtype_vec_[0] == DType::InvalidDataType())
? DType::Float()
: inputs_lowest_dtype_vec_[0];
Expand Down
11 changes: 11 additions & 0 deletions oneflow/user/kernels/reduce_like_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ class ReduceSumLikeOpKernel final : public user_op::OpKernel, public user_op::Cu
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL, DEVICE_TYPE_SEQ,
ARITHMETIC_DATA_TYPE_SEQ)

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCPU), COMPLEX_DATA_TYPE_SEQ);
#if defined(WITH_CUDA)
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),
OF_PP_MAKE_TUPLE_SEQ(cuComplex, DataType::kComplex64));
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_REDUCE_SUM_LIKE_KERNEL,
OF_PP_MAKE_TUPLE_SEQ(DeviceType::kCUDA),
OF_PP_MAKE_TUPLE_SEQ(cuDoubleComplex, DataType::kComplex128));
#endif // WITH_CUDA

#if defined(WITH_CUDA)

namespace {
Expand Down
16 changes: 8 additions & 8 deletions python/oneflow/test/modules/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import unittest
from collections import OrderedDict

import torch as torch_original
import numpy as np
from oneflow.test_utils.test_util import GenArgList

Expand Down Expand Up @@ -190,31 +190,31 @@ def test_add(test_case):
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])

@autotest(n=5)
@autotest(n=10)
def test_0_size_add(test_case):
device = random_device()
x = random_tensor(2, 0, 3).to(device)
y = random_tensor(2, 1, 3).to(device)
out = x + y
return out

@autotest(n=3, auto_backward=False)
@autotest(n=6, auto_backward=False)
def test_0dim_inplace_add(test_case):
device = random_device()
x = random_tensor(2, 2, 3, requires_grad=False).to(device)
y = random_tensor(1, 10).to(device)
x += y.mean()
return x

@autotest(n=5)
@autotest(n=10)
def test_0dim_two_inplace_add(test_case):
device = random_device()
x = random_tensor(2, 2, 3).to(device).mean()
y = random_tensor(2, 2, 3).to(device)
x += y.mean()
return x

@autotest(n=3)
@autotest(n=6)
def test_add_with_alpha(test_case):
device = random_device()
x1 = random_tensor(2, 2, 3).to(device).mean()
Expand Down Expand Up @@ -260,7 +260,7 @@ def test_0dim_two_inplace_add(test_case):
return x
x += y.mean().to(torch.bool)

@autotest(n=3)
@autotest(n=6)
def test_add_with_alpha_0dim(test_case):
device = random_device()
x1 = random_tensor(ndim=0).to(device).mean()
Expand All @@ -279,7 +279,7 @@ def profile_add(test_case):
torch.add(torch.ones(100), 20)
torch.add(torch.ones(100), torch.ones(100, 1), alpha=10)

@autotest(n=3)
@autotest(n=6)
def test_non_contiguous_inplace_add(test_case):
device = random_device()
x = random_tensor(2, 2, 4).to(device)
Expand All @@ -288,7 +288,7 @@ def test_non_contiguous_inplace_add(test_case):
y += random_tensor(2, 2, 2).to(device)
return y

@autotest(n=5)
@autotest(n=10)
def test_scalar_add_with_random_devices(test_case):
x1_device = random_device()
x2_device = random_device()
Expand Down
8 changes: 8 additions & 0 deletions python/oneflow/test/modules/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import OrderedDict

import numpy as np
import torch as torch_original

import oneflow as flow
import oneflow.unittest
Expand Down Expand Up @@ -174,6 +175,13 @@ def test_cast_with_scalar_input(test_case):
z = y.to(dtype=torch.int8, device=device)
return z

@autotest(n=5, auto_backward=True, include_complex=False, atol=1e-5, rtol=1e-5)
def test_cast_with_complex_float2complex(test_case):
device = random_device()
x = random_tensor().to(dtype=torch.float32, device=device)
y = x.to(torch.complex64)
return y


if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion python/oneflow/test/modules/test_constant_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_constantpad3d_with_random_data(test_case):
return y

@autotest(n=10, rtol=0.001, atol=0.001, auto_backward=False)
def test_constantpad3d_with_random_data(test_case):
def test_constantpad3d_with_random_int_data(test_case):
dtype = choice([bool, int])
value = random(0, 2).to(bool) if dtype is bool else random().to(int)
m = torch.nn.ConstantPad3d(padding=random(1, 6).to(_size_6_t), value=value,)
Expand Down
36 changes: 35 additions & 1 deletion python/oneflow/test/modules/test_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections import OrderedDict

import numpy as np
import torch as torch_original
from oneflow.test_utils.test_util import GenArgList

import oneflow as flow
Expand All @@ -28,7 +29,7 @@

@flow.unittest.skip_unless_1n1d()
class TestEqual(flow.unittest.TestCase):
@autotest(n=5, auto_backward=False, check_graph=False)
@autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)
def test_eq_with_0_size_data(test_case):
device = random_device()
x = random_tensor(3, 2, 0, 3).to(device)
Expand Down Expand Up @@ -75,6 +76,15 @@ def test_flow_equal_with_same_random_data(test_case):
x = random_tensor(len(shape), *shape, requires_grad=False).to(device)
return torch.equal(x, x)

@autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)
def test_flow_equal_complex_with_same_random_data(test_case):
device = random_device()
shape = random_tensor().oneflow.shape
x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
device
)
return torch.equal(x, x)

@autotest(n=5, auto_backward=False, check_graph=False)
def test_flow_equal_bool_with_random_data(test_case):
device = random_device()
Expand All @@ -87,6 +97,30 @@ def test_flow_equal_bool_with_random_data(test_case):
)
return torch.equal(x, y)

@autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)
def test_flow_equal_complex_with_random_data(test_case):
device = random_device()
shape = random_tensor().oneflow.shape
x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
device=device
)
y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
device=device
)
return torch.equal(x, y)

@autotest(n=5, auto_backward=False, check_graph=False, include_complex=True)
def test_flow_not_equal_complex_with_random_data(test_case):
device = random_device()
shape = random_tensor().oneflow.shape
x = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
device=device
)
y = random_tensor(len(shape), *shape, requires_grad=False, dtype=complex).to(
device=device
)
return torch.not_equal(x, y)

@autotest(n=5, auto_backward=False, check_graph=False)
def test_flow_equal_with_same_random_0d_data(test_case):
device = random_device()
Expand Down
5 changes: 3 additions & 2 deletions python/oneflow/test/modules/test_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections import OrderedDict

import numpy as np
import torch as torch_original

from oneflow.test_utils.automated_test_util import *
from oneflow.test_utils.test_util import GenArgList
Expand Down Expand Up @@ -208,7 +209,7 @@ def test_broadcast_mul(test_case):
x.mul_(y)
return x

@autotest(n=3)
@autotest(n=6)
def test_non_contiguous_inplace_mul(test_case):
device = random_device()
x = random_tensor(2, 2, 4).to(device)
Expand All @@ -217,7 +218,7 @@ def test_non_contiguous_inplace_mul(test_case):
y *= random_tensor(2, 2, 2).to(device)
return y

@autotest(n=5)
@autotest(n=10)
def test_scalar_mul_with_random_devices(test_case):
x1_device = random_device()
x2_device = random_device()
Expand Down
5 changes: 3 additions & 2 deletions python/oneflow/test/modules/test_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections import OrderedDict

import numpy as np
import torch as torch_original

from oneflow.test_utils.automated_test_util import *
from oneflow.test_utils.test_util import GenArgList
Expand Down Expand Up @@ -156,7 +157,7 @@ def test_sub_with_alpha(test_case):
z3 = torch.sub(s, x3, alpha=alpha)
return z1, z2, z3

@autotest(n=3)
@autotest(n=5)
def test_non_contiguous_inplace_sub(test_case):
device = random_device()
x = random_tensor(2, 2, 4).to(device)
Expand All @@ -166,7 +167,7 @@ def test_non_contiguous_inplace_sub(test_case):
return y

@unittest.skip("skip for now, becase it failed 2 times in past week")
@autotest(n=5)
@autotest(n=5, include_complex=True)
def test_scalar_sub_with_random_devices(test_case):
x1_device = random_device()
x2_device = random_device()
Expand Down
23 changes: 22 additions & 1 deletion python/oneflow/test/modules/test_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,29 @@ def test_sum_dtype(test_case):
)
return y

@autotest(
n=10,
check_graph=False,
auto_backward=True,
include_complex=True,
atol=1e-2,
rtol=1e-5,
)
def test_sum_complex_dtype(test_case):
device = random_device()
x = random_tensor(4, dtype=complex, requires_grad=True).to(
device=device, dtype=random_dtype(["complex"])
)
y = torch.sum(
x,
dim=np.random.randint(0, 3),
keepdim=random_bool(),
dtype=random_dtype(["complex"]),
)
return y

@autotest(check_graph=True, auto_backward=False)
def test_sum_whole_dtype(test_case):
def test_sum_arithmetic_dtype(test_case):
device = random_device()
x = random_tensor(4, requires_grad=False).to(device)
y = torch.sum(x, dtype=random_dtype(["arithmetic"]))
Expand Down
Loading

0 comments on commit fb423fa

Please sign in to comment.