port & update pr16744 numpy gcd (#19547)
* numpy-compatible gcd operator

* use BinaryScalarRTCCompute

* Update _op.py

* Update np_elemwise_broadcast_op_extended.cc

* fix

* Update operator_tune.cc

* fix kernel

* add large tensor test

* add gcd interoperability workload

* Update test_numpy_interoperability.py

* Update np_elemwise_broadcast_op_extended.cc

* Update np_elemwise_broadcast_op_extended.cc

* avoid ci linspace issue

Co-authored-by: Hao Jin <hjjn.amzn@gmail.com>
Zha0q1 and haojin2 authored Feb 3, 2021
1 parent c723ae2 commit 25c25da
Showing 15 changed files with 278 additions and 5 deletions.
2 changes: 1 addition & 1 deletion ci/docker/install/requirements
@@ -19,7 +19,7 @@
# the whole docker cache for the image

# Required dependencies
-numpy<1.20.0
+numpy>=1.17,<1.20.0
requests>=2.20.0,<3
graphviz<0.9.0,>=0.8.1
contextvars;python_version<"3.7"
2 changes: 2 additions & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -245,6 +245,8 @@
'_npi_logistic',
'_npi_lcm',
'_npi_lcm_scalar',
'_npi_gcd',
'_npi_gcd_scalar',
'_npi_linspace',
'_npi_logical_not',
'_npi_logical_and_scalar',
42 changes: 41 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -44,7 +44,7 @@
'max', 'min', 'amax', 'amin', 'logical_and', 'logical_or', 'logical_xor',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
-'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
+'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'gcd',
'tril', 'triu', 'tri', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp',
@@ -2081,6 +2081,46 @@ def expand_dims(a, axis):
return _api_internal.expand_dims(a, axis)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def gcd(x1, x2, out=None, **kwargs):
"""
Returns the greatest common divisor of ``|x1|`` and ``|x2|``
Parameters
----------
x1, x2 : ndarrays or scalar values
The arrays for computing greatest common divisor. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which may be the shape of
one or the other).
out : ndarray or None, optional
A location into which the result is stored. If provided, it must have a shape
that the inputs broadcast to. If not provided or None, a freshly-allocated array
is returned.
Returns
-------
y : ndarray or scalar
The greatest common divisor of the absolute value of the inputs
This is a scalar if both `x1` and `x2` are scalars.
See Also
--------
lcm : The lowest common multiple
Examples
--------
>>> np.gcd(12, 20)
4
>>> np.gcd(np.arange(6, dtype=int), 20)
array([20, 1, 2, 1, 4, 5], dtype=int64)
"""
if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
return _np.gcd(x1, x2, out=out)
return _api_internal.gcd(x1, x2, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def lcm(x1, x2, out=None, **kwargs):
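For reference, a minimal usage sketch of the new frontend function added above (assuming a build that contains this change); the `out=` form mirrors the docstring:

from mxnet import np as mx_np

a = mx_np.arange(6, dtype='int64')
b = mx_np.full((6,), 20, dtype='int64')

# Array-array and array-scalar calls route to _npi_gcd and _npi_gcd_scalar.
print(mx_np.gcd(a, b))    # expected: [20  1  2  1  4  5]
print(mx_np.gcd(a, 20))   # same result via the scalar path

# Writing into a preallocated output, as described in the docstring.
out = mx_np.empty((6,), dtype='int64')
mx_np.gcd(a, b, out=out)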
40 changes: 39 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -74,7 +74,7 @@
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
'triu_indices_from', 'triu_indices', 'tri',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
-'unique', 'lcm', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
+'unique', 'lcm', 'gcd', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'cross', 'kron', 'equal', 'not_equal', 'interp',
'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
@@ -3620,6 +3620,44 @@ def power(x1, x2, out=None, **kwargs):
return _mx_nd_np.power(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def gcd(x1, x2, out=None, **kwargs):
"""
Returns the greatest common divisor of ``|x1|`` and ``|x2|``
Parameters
----------
x1, x2 : ndarrays or scalar values
The arrays for computing greatest common divisor. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which may be the shape of
one or the other).
out : ndarray or None, optional
A location into which the result is stored. If provided, it must have a shape
that the inputs broadcast to. If not provided or None, a freshly-allocated array
is returned.
Returns
-------
y : ndarray or scalar
The greatest common divisor of the absolute value of the inputs
This is a scalar if both `x1` and `x2` are scalars.
See Also
--------
lcm : The lowest common multiple
Examples
--------
>>> np.gcd(12, 20)
4
>>> np.gcd(np.arange(6, dtype=int), 20)
array([20, 1, 2, 1, 4, 5], dtype=int64)
"""
return _mx_nd_np.gcd(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def lcm(x1, x2, out=None, **kwargs):
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
@@ -249,6 +249,7 @@ def _register_array_function():
'degrees',
'hypot',
'lcm',
'gcd',
# 'ldexp',
'subtract',
'multiply',
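Adding 'gcd' to this dispatch list is what lets the official NumPy ufunc hand calls on MXNet ndarrays back to MXNet's implementation. A hedged sketch of the resulting interoperability (assuming NumPy >= 1.17 and a build with this change):

import numpy as onp            # official NumPy
import mxnet as mx
from mxnet import np as mx_np

mx.npx.set_np()                # enable NumPy-compatible semantics

a = mx_np.arange(6, dtype='int64')
# With 'gcd' registered, onp.gcd should dispatch back to mxnet.numpy.gcd
# instead of coercing `a` to a plain NumPy array.
result = onp.gcd(a, 20)
print(type(result))            # expected: mxnet.numpy.ndarray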
33 changes: 32 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -49,7 +49,7 @@
'flatnonzero', 'tril_indices', 'amax', 'amin', 'max', 'min', 'logical_and', 'logical_or', 'logical_xor',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
-'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'interp',
+'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'gcd', 'interp',
'tril', 'triu', 'tri', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
@@ -1678,6 +1678,37 @@ def power(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def gcd(x1, x2, out=None, **kwargs):
"""
Returns the greatest common divisor of ``|x1|`` and ``|x2|``
Parameters
----------
x1, x2 : ndarrays or scalar values
The arrays for computing greatest common divisor. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which may be the shape of
one or the other).
out : ndarray or None, optional
A location into which the result is stored. If provided, it must have a shape
that the inputs broadcast to. If not provided or None, a freshly-allocated array
is returned.
Returns
-------
y : ndarray or scalar
The greatest common divisor of the absolute value of the inputs
This is a scalar if both `x1` and `x2` are scalars.
See Also
--------
lcm : The lowest common multiple
"""
return _ufunc_helper(x1, x2, _npi.gcd, _np.gcd, _npi.gcd_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def matmul(a, b, out=None, **kwargs):
8 changes: 8 additions & 0 deletions src/api/operator/numpy/np_elemwise_broadcast_op.cc
@@ -88,6 +88,14 @@ MXNET_REGISTER_API("_npi.lcm")
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.gcd")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_gcd");
const nnvm::Op* op_scalar = Op::Get("_npi_gcd_scalar");
UFuncHelper(args, ret, op, op_scalar, nullptr);
});

MXNET_REGISTER_API("_npi.logical_and")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
43 changes: 43 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -541,6 +541,49 @@ lcm(const DType a, const DType2 b) {
}
}
template <typename DType, typename DType2>
__device__ inline typename type_util::mixed_type<DType, DType2>::type
gcd(const DType a, const DType2 b) {
if (type_util::is_integral<DType>::value &&
type_util::is_integral<DType2>::value) {
DType A = a;
DType2 B = b;
// minus cases.
if (a < 0) {
A = -a;
}
if (b < 0) {
B = -b;
}
// handle zero-valued cases.
DType c;
if (a == 0 && b != 0) {
c = B;
} else if (b == 0 && a != 0) {
c = A;
} else if (a == 0 && b == 0) {
c = 0;
} else {
DType tmp;
if (A < B) {
tmp = A;
A = B;
B = tmp;
}
while (A % B != 0) {
A = A % B;
tmp = A;
A = B;
B = tmp;
}
c = B;
}
return c;
} else {
return 0;
}
}
template <typename DType, typename DType2>
__device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_xor(const DType a,
const DType2 b) {
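Both the RTC device function above and the mshadow_op::gcd CPU kernel in the next file implement the same iterative Euclidean algorithm on absolute values, with explicit zero handling. A plain-Python transcription of the loop, for illustration only (gcd_reference is not part of the change):

def gcd_reference(a, b):
    # Work on absolute values, as the kernel does.
    a, b = abs(a), abs(b)
    # Zero handling mirrors the kernel: gcd(0, b) = b, gcd(a, 0) = a, gcd(0, 0) = 0.
    if a == 0:
        return b
    if b == 0:
        return a
    # Keep the larger value in `a`, then iterate the Euclidean step.
    if a < b:
        a, b = b, a
    while a % b != 0:
        a = a % b
        a, b = b, a
    return b

assert gcd_reference(12, 20) == 4
assert gcd_reference(-12, 20) == 4
assert gcd_reference(0, 7) == 7
assert gcd_reference(0, 0) == 0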
46 changes: 46 additions & 0 deletions src/operator/mshadow_op.h
@@ -1704,6 +1704,52 @@ struct nanprod_grad : public mxnet_op::tunable {
#pragma GCC diagnostic ignored "-Wint-in-bool-context"
#pragma GCC diagnostic ignored "-Wbool-compare"
#endif

/*! \brief used for computing binary greatest common divisor */
struct gcd : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static typename enable_if<is_integral<DType>::value, DType>::type
Map(DType a, DType b) {
// minus cases.
if (a < 0) {
a = -a;
}
if (b < 0) {
b = -b;
}
// handle zero-valued cases.
DType c;
if (a == 0 && b != 0) {
c = b;
} else if (b == 0 && a != 0) {
c = a;
} else if (a == 0 && b == 0) {
c = 0;
} else {
DType tmp;
if (a < b) {
tmp = a;
a = b;
b = tmp;
}
while (a % b != 0) {
a = a % b;
tmp = a;
a = b;
b = tmp;
}
c = b;
}
return c;
}

template<typename DType>
MSHADOW_XINLINE static typename enable_if<!is_integral<DType>::value, DType>::type
Map(DType a, DType b) {
return DType(0.0f);
}
};

/*! \brief used for computing binary lowest common multiple */
struct lcm : public mxnet_op::tunable {
template<typename DType>
35 changes: 34 additions & 1 deletion src/operator/numpy/np_elemwise_broadcast_op_extended.cc
@@ -63,6 +63,39 @@ NNVM_REGISTER_OP(_backward_npi_copysign)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::copysign_grad,
mshadow_op::copysign_rgrad>);

NNVM_REGISTER_OP(_npi_gcd)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"lhs", "rhs"};
})
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastIntCompute<cpu, mshadow_op::gcd>)
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");

NNVM_REGISTER_OP(_npi_gcd_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::gcd>);

NNVM_REGISTER_OP(_npi_lcm)
.set_num_inputs(2)
.set_num_outputs(1)
@@ -94,7 +127,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::lcm>);

NNVM_REGISTER_OP(_npi_bitwise_and)
.set_num_inputs(2)
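Two details of the CPU registration above are worth noting: type inference uses ElemwiseIntType, so only integer dtypes are accepted (gcd has no useful gradient, hence MakeZeroGradNodes), and the existing _npi_lcm_scalar is switched from Compute to ComputeInt alongside it. A hedged sketch of the visible behavior (the exact exception type may vary across MXNet versions):

from mxnet import np as mx_np
from mxnet.base import MXNetError

ints = mx_np.array([12, 18, 27], dtype='int64')
print(mx_np.gcd(ints, 6))       # integer inputs are accepted: [6 6 3]

floats = mx_np.array([12.0, 18.0, 27.0])
try:
    mx_np.gcd(floats, 6)        # ElemwiseIntType should reject non-integer inputs
except (MXNetError, TypeError) as err:
    print('rejected as expected:', type(err).__name__)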
6 changes: 6 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cu
@@ -31,6 +31,9 @@ namespace op {
NNVM_REGISTER_OP(_npi_copysign)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"copysign"});

NNVM_REGISTER_OP(_npi_gcd)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"gcd"});

NNVM_REGISTER_OP(_npi_lcm)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"lcm"});

@@ -82,6 +85,9 @@ NNVM_REGISTER_OP(_npi_rarctan2_scalar)
NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCBackward{"rarctan2_grad"});

NNVM_REGISTER_OP(_npi_gcd_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCCompute{"gcd"});

NNVM_REGISTER_OP(_npi_lcm_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCCompute{"lcm"});

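On GPU, both the broadcast and scalar variants reuse the RTC device function named "gcd" from forward_functions-inl.h. A minimal sketch on a GPU context (assuming a CUDA build; at the time of this commit the context keyword is still `ctx`):

import mxnet as mx
from mxnet import np as mx_np

a = mx_np.arange(6, dtype='int64', ctx=mx.gpu(0))
b = mx_np.full((6,), 20, dtype='int64', ctx=mx.gpu(0))

# Broadcast and scalar forms JIT-compile the "gcd" RTC kernel on first call.
print(mx_np.gcd(a, b))
print(mx_np.gcd(a, 20))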
1 change: 1 addition & 0 deletions src/operator/operator_tune.cc
@@ -417,6 +417,7 @@ IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); //
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::gcd); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT()