From 4fdd7a3ca404d70180d860feec985e35791b9f6b Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 8 Nov 2022 12:04:56 +0800 Subject: [PATCH 01/26] add op, cpu kernel, functor --- oneflow/core/functional/functional_api.yaml | 4 + oneflow/core/functional/impl/nn_functor.cpp | 81 +++++++++ oneflow/ir/include/OneFlow/OneFlowUserOps.td | 17 ++ oneflow/user/kernels/cdist_kernel.cpp | 174 +++++++++++++++++++ oneflow/user/ops/cdist_op.cpp | 63 +++++++ 5 files changed, 339 insertions(+) create mode 100644 oneflow/user/kernels/cdist_kernel.cpp create mode 100644 oneflow/user/ops/cdist_op.cpp diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index cee5a94e085..0c026eb504d 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -2315,6 +2315,10 @@ signature: "Tensor (Tensor x, Tensor y, Int32 dim=1, Double eps=1e-8) => CosineSimilarity" bind_python: True +- name: "cdist" + signature: 'Tensor (Tensor x1, Tensor x2, Double p=2.0, String compute_mode="use_mm_for_euclid_dist_if_necessary") => CDist' + bind_python: True + - name: "normalize" signature: "Tensor (Tensor input, Float p=2.0, Int32 dim=1, Float eps=1e-12, Bool use_l2_norm_kernel=True) => Normalize" bind_python: True diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 9980c03ed3b..6ef0b02b1e6 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -2968,6 +2968,86 @@ class CosineSimilarityFunctor { } }; +class CdistFunctor { + public: + CdistFunctor() { + op_ = CHECK_JUST(OpBuilder("cdist").Input("x1").Input("x2").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x1, const std::shared_ptr& x2, + const double& p, const std::string& compute_mode) const { + const int64_t x1_ndim = x1->ndim(); + const int64_t x2_ndim = x2->ndim(); + CHECK_OR_RETURN(x1_ndim >= 2) << "cdist only supports at least 2D tensors, X1 got: " + << x1->ndim() << "D"; + CHECK_OR_RETURN(x2_ndim >= 2) << "cdist only supports at least 2D tensors, X2 got: " + << x2->ndim() << "D"; + CHECK_OR_RETURN(x1->dim(x1_ndim - 1) == x2->dim(x2_ndim - 1)) + << "X1 and X2 must have the same number of columns. X1: " << x1->dim(x1_ndim - 1) + << " X2: " << x2->dim(x2_ndim - 1); + CHECK_OR_RETURN(p >= 0) << "cdist only supports non-negative p values, got " << p; + + int32_t mode = 0; + if (compute_mode == "use_mm_for_euclid_dist_if_necessary") { + mode = 0; + } else if (compute_mode == "use_mm_for_euclid_dist") { + mode = 1; + } else if (compute_mode == "donot_use_mm_for_euclid_dist") { + mode = 2; + } else { + THROW(RuntimeError) << compute_mode << " is not a valid value for compute_mode"; + } + + int64_t r1 = x1->dim(x1_ndim - 2); + int64_t r2 = x2->dim(x2_ndim - 2); + int64_t d = x1->dim(x1_ndim - 1); + + Shape x1_batch_shape = Shape(DimVector({x1->shape()->begin(), x1->shape()->end() - 2})); + Shape x2_batch_shape = Shape(DimVector({x2->shape()->begin(), x2->shape()->end() - 2})); + Shape max_batch_shape = Shape::Ones(std::max(x1_batch_shape.NumAxes(), x2_batch_shape.NumAxes())); + { + for (int64_t i = max_batch_shape.NumAxes() - 1; i >= 0; i--) { + int64_t offset = max_batch_shape.NumAxes() - 1 - i; + int64_t dim_x = x1_batch_shape.NumAxes() - 1 - offset; + int64_t dim_y = x2_batch_shape.NumAxes() - 1 - offset; + int64_t size_x = (dim_x >= 0) ? x1_batch_shape.At(dim_x) : 1; + int64_t size_y = (dim_y >= 0) ? x2_batch_shape.At(dim_y) : 1; + if (!(size_x == size_y || size_x == 1 || size_y == 1)) { + return Error::RuntimeError() + << "The size of tensor a (" << size_x << ") must match the size of tensor b (" + << size_y << ") at non-singleton dimension " << i; + } + max_batch_shape.Set(i, std::max(size_x, size_y)); + } + } + // auto max_batch_shape = JUST(InferShapeForDistance(x1_batch_shape, x2_batch_shape)); + Shape x1_expand_shape(max_batch_shape); + Shape x2_expand_shape(max_batch_shape); + std::cout << x1_expand_shape.DebugStr() << std::endl; + std::cout << "r1: " << r1 << std::endl; + x1_expand_shape.emplace_back(r1); + x1_expand_shape.emplace_back(d); + x2_expand_shape.emplace_back(r2); + x2_expand_shape.emplace_back(d); + + std::cout << "run1: " << x1_expand_shape.DebugStr() << std::endl; + const auto x1_expand = JUST(Expand(x1, x1_expand_shape)); + std::cout << "run2: " << x2_expand_shape.DebugStr() << std::endl; + const auto x2_expand = JUST(Expand(x2, x2_expand_shape)); + + TensorProcessor tensor_processor; + JUST(tensor_processor.PromoteInputsToCommonDtype(true) + .AddInputs({x1_expand, x2_expand}) + .Apply()); + + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p", "mode"); + attrs.SetAllAttrs(p, mode); + return OpInterpUtil::Dispatch(*op_, {x1, x2}, attrs); + } + + private: + std::shared_ptr op_; +}; + class L2NormalizeFunctor { public: L2NormalizeFunctor() { @@ -4538,6 +4618,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("PairwiseDistance"); m.add_functor("CosineSimilarity"); m.add_functor("Normalize"); + m.add_functor("CDist"); m.add_functor("L2Normalize"); m.add_functor("L2NormalizeGrad"); m.add_functor("FusedBiasAddGelu"); diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index ad43b6cfee9..2a46574d2c0 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -3924,6 +3924,23 @@ def OneFlow_AbsGradOp : OneFlow_BaseOp<"abs_grad", [NoSideEffect, DeclareOpInter let has_data_type_infer_fn = 1; } +def OneFlow_CdistOp : OneFlow_BaseOp<"cdist", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x1, + OneFlow_Tensor:$x2 + ); + let output = (outs + OneFlow_Tensor:$out + ); + let attrs = (ins + DefaultValuedAttr:$p + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def OneFlow_ErfOp : OneFlow_BaseOp<"erf", [NoSideEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x diff --git a/oneflow/user/kernels/cdist_kernel.cpp b/oneflow/user/kernels/cdist_kernel.cpp new file mode 100644 index 00000000000..ec2fd6a357e --- /dev/null +++ b/oneflow/user/kernels/cdist_kernel.cpp @@ -0,0 +1,174 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include +#include +#include "oneflow/core/ep/cpu/cpu_stream.h" +#include "oneflow/core/ep/include/primitive/memset.h" +#include "oneflow/core/ep/include/stream.h" +#include "oneflow/core/framework/user_op_hob.h" +#include "oneflow/core/ndarray/xpu_util.h" +#include "oneflow/core/thread/thread_manager.h" + +namespace oneflow { + +template +struct ZeroDist { + static inline T map(const T& diff, const T& p) { return diff == T(0) ? diff : T(1); } + static inline T reduce(const T& agg, const T& up) { return agg + up; } + static inline T finish(const T agg, const T p) { return agg; } +}; + +template +struct OneDist { + static inline T map(const T& diff, const T& p) { return diff; } + static inline T reduce(const T& agg, const T& up) { return agg + up; } + static inline T finish(const T agg, const T p) { return agg; } + static inline T backward(const T& diff, const T grad, const T dist, const T& p) { + return T(grad) * diff > 0 ? T(1) : T(-1); + } +}; + +template +struct TwoDist { + static inline T map(const T& diff, const T& p) { return diff * diff; } + static inline T reduce(const T& agg, const T& up) { return agg + up; } + static inline T finish(const T agg, const T p) { return std::sqrt(agg); } + static inline T backward(const T& diff, const T grad, const T dist, const T& p) { + return dist == 0.0 ? T(0) : T(grad) * diff / T(dist); + } +}; + +template +struct InfiDist { + static inline T map(const T& diff, const T& p) { return diff; } + static inline T reduce(const T& agg, const T& up) { return std::max(agg, up); } + static inline T finish(const T agg, const T p) { return agg; } + static inline T backward(const T& diff, const T grad, const T dist, const T& p) { + return dist == 0.0 ? T(0) : T(grad) * diff / T(dist); + } +}; + +template +struct PDist { + static inline T map(const T& diff, const T& p) { return std::pow(diff, p); } + static inline T reduce(const T& agg, const T& up) { return agg + up; } + static inline T finish(const T agg, const T p) { return std::pow(agg, 1.0 / p); } +}; + +template +void CpuCdistForward(ep::CpuStream* stream, const T* x1, const T* x2, T* out, int64_t size_out, + int64_t d, int64_t r1, int64_t r2, int64_t c, double p) { + // x1 shape: (d1, d2, ..., dn, r1, c), treated as (d1 * ... * dn, r1 * c) + // x2 shape: (d1, d2, ..., dn, r2, c), treated as (d1 * ... * dn, r2 * c) + // out shape: (d1, d2, ..., dn, r1, r2), treated as (d1 * ... * dn, r1 * r2) + // d = d1 * ... * dn + stream->ParallelFor( + 0, size_out, + [x1, x2, out, d, r1, r2, c, p](int64_t begin, int64_t end) { + // begin is a multiple of c + T* out_begin = out + begin; + const T* out_end = out + end; + + int64_t d = r1 * r2; + int64_t batch_idx = begin / d; + int64_t vec_out_idx = begin - d * batch_idx; + int64_t vec1_idx = (vec_out_idx / r2); + int64_t vec2_idx = vec_out_idx - vec1_idx * r2; + int64_t vec1_begin = vec1_idx * c; + int64_t vec2_begin = vec2_idx * c; + int64_t size1 = r1 * c; + int64_t size2 = r1 * c; + + while (out_begin != out_end) { + T agg = 0; + const T* x1_begin = x1 + batch_idx * size1 + vec1_begin; + const T* x2_begin = x2 + batch_idx * size2 + vec2_begin; + FOR_RANGE(int32_t, idx, 0, c) { + T a = *(x1_begin + idx); + T b = *(x2_begin + idx); + agg = Dist::reduce(agg, Dist::map(std::abs(a - b), p)); + } + *out_begin = Dist::finish(agg, p); + out_begin += 1; + vec2_begin += c; + if (vec2_begin == r2 * c) { + vec2_begin = 0; + vec1_begin += c; + if (vec1_begin == r1 * c) { + vec1_begin = 0; + batch_idx += 1; + } + } + } + }, + c); +} + +template +class CpuCdistKernel final : public user_op::OpKernel { + public: + CpuCdistKernel() = default; + ~CpuCdistKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("x1", 0); + const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("x2", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + double p = ctx->Attr("p"); + int64_t ndim = x1->shape_view().NumAxes(); + int64_t r1 = x1->shape_view().At(ndim - 2); + int64_t r2 = x2->shape_view().At(ndim - 2); + int64_t c = x1->shape_view().At(ndim - 1); + int64_t d = x1->shape_view().Count(0, ndim - 2); + + const T* x1_ptr = x1->dptr(); + const T* x2_ptr = x2->dptr(); + T* out_ptr = out->mut_dptr(); + + if (p == 0) { + CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + out->shape_view().elem_cnt(), d, r1, r2, c, p); + } else if (p == 1) { + CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + out->shape_view().elem_cnt(), d, r1, r2, c, p); + } else if (p == 2) { + CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + out->shape_view().elem_cnt(), d, r1, r2, c, p); + } else if (std::isinf(p)) { + CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + out->shape_view().elem_cnt(), d, r1, r2, c, p); + } else { + CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + out->shape_view().elem_cnt(), d, r1, r2, c, p); + }; + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CPU_CDIST_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cdist").SetCreateFn>().SetIsMatchedHob( \ + (user_op::HobDeviceType() == DeviceType::kCPU) \ + && (user_op::HobDataType("x1", 0) == GetDataType::value) \ + && (user_op::HobDataType("x2", 0) == GetDataType::value) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)); + +REGISTER_CPU_CDIST_KERNEL(float) +REGISTER_CPU_CDIST_KERNEL(double) +#undef REGISTER_CPU_CDIST_KERNEL + +} // namespace oneflow diff --git a/oneflow/user/ops/cdist_op.cpp b/oneflow/user/ops/cdist_op.cpp new file mode 100644 index 00000000000..c60a053d847 --- /dev/null +++ b/oneflow/user/ops/cdist_op.cpp @@ -0,0 +1,63 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" + +namespace oneflow { + +namespace { + +Maybe InferTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x1_desc = ctx->InputTensorDesc("x1", 0); + const user_op::TensorDesc& x2_desc = ctx->InputTensorDesc("x2", 0); + user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("out", 0); + + int64_t ndim = x1_desc.shape().NumAxes(); + Shape output_shape(x1_desc.shape().begin(), x1_desc.shape().end() - 2); + output_shape.emplace_back(x1_desc.shape().At(ndim - 2)); + output_shape.emplace_back(x2_desc.shape().At(ndim - 2)); + output_desc->set_shape(Shape(output_shape)); + + return Maybe::Ok(); +} + +} // namespace + +/* static */ Maybe CdistOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc(ctx); +} + +/*static*/ Maybe CdistOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CdistOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe CdistOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x1_desc = ctx->InputTensorDesc("x1", 0); + user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("out", 0); + if (IsIntegralDataType(x1_desc.data_type())) { + output_desc->set_data_type(DataType::kFloat); + } else { + output_desc->set_data_type(x1_desc.data_type()); + } + return Maybe::Ok(); +} + +} // namespace oneflow From f1bc1a10175bcdd9b3551c6253d2935a2de24396 Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 8 Nov 2022 12:15:11 +0800 Subject: [PATCH 02/26] refine code --- oneflow/core/functional/impl/nn_functor.cpp | 11 ++++------- oneflow/user/ops/cdist_op.cpp | 6 +----- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 6ef0b02b1e6..3e534961fef 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -3003,7 +3003,8 @@ class CdistFunctor { Shape x1_batch_shape = Shape(DimVector({x1->shape()->begin(), x1->shape()->end() - 2})); Shape x2_batch_shape = Shape(DimVector({x2->shape()->begin(), x2->shape()->end() - 2})); - Shape max_batch_shape = Shape::Ones(std::max(x1_batch_shape.NumAxes(), x2_batch_shape.NumAxes())); + Shape max_batch_shape = + Shape::Ones(std::max(x1_batch_shape.NumAxes(), x2_batch_shape.NumAxes())); { for (int64_t i = max_batch_shape.NumAxes() - 1; i >= 0; i--) { int64_t offset = max_batch_shape.NumAxes() - 1 - i; @@ -3013,8 +3014,8 @@ class CdistFunctor { int64_t size_y = (dim_y >= 0) ? x2_batch_shape.At(dim_y) : 1; if (!(size_x == size_y || size_x == 1 || size_y == 1)) { return Error::RuntimeError() - << "The size of tensor a (" << size_x << ") must match the size of tensor b (" - << size_y << ") at non-singleton dimension " << i; + << "The size of tensor a (" << size_x << ") must match the size of tensor b (" + << size_y << ") at non-singleton dimension " << i; } max_batch_shape.Set(i, std::max(size_x, size_y)); } @@ -3022,16 +3023,12 @@ class CdistFunctor { // auto max_batch_shape = JUST(InferShapeForDistance(x1_batch_shape, x2_batch_shape)); Shape x1_expand_shape(max_batch_shape); Shape x2_expand_shape(max_batch_shape); - std::cout << x1_expand_shape.DebugStr() << std::endl; - std::cout << "r1: " << r1 << std::endl; x1_expand_shape.emplace_back(r1); x1_expand_shape.emplace_back(d); x2_expand_shape.emplace_back(r2); x2_expand_shape.emplace_back(d); - std::cout << "run1: " << x1_expand_shape.DebugStr() << std::endl; const auto x1_expand = JUST(Expand(x1, x1_expand_shape)); - std::cout << "run2: " << x2_expand_shape.DebugStr() << std::endl; const auto x2_expand = JUST(Expand(x2, x2_expand_shape)); TensorProcessor tensor_processor; diff --git a/oneflow/user/ops/cdist_op.cpp b/oneflow/user/ops/cdist_op.cpp index c60a053d847..87709b63fe4 100644 --- a/oneflow/user/ops/cdist_op.cpp +++ b/oneflow/user/ops/cdist_op.cpp @@ -52,11 +52,7 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { /* static */ Maybe CdistOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x1_desc = ctx->InputTensorDesc("x1", 0); user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("out", 0); - if (IsIntegralDataType(x1_desc.data_type())) { - output_desc->set_data_type(DataType::kFloat); - } else { - output_desc->set_data_type(x1_desc.data_type()); - } + output_desc->set_data_type(x1_desc.data_type()); return Maybe::Ok(); } From 374d1fe9e708df44982312458db1abc391113a15 Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 9 Nov 2022 17:20:01 +0800 Subject: [PATCH 03/26] add backward func, register autograd --- .../core/autograd/gradient_funcs/cdist.cpp | 95 +++++++++++ oneflow/core/functional/functional_api.yaml | 4 + .../core/functional/impl/nn_grad_functor.cpp | 26 +++ oneflow/ir/include/OneFlow/OneFlowUserOps.td | 20 +++ oneflow/user/kernels/cdist_kernel.cpp | 149 ++++++++++++++++-- oneflow/user/ops/cdist_op.cpp | 42 ++++- python/oneflow/__init__.py | 1 + 7 files changed, 323 insertions(+), 14 deletions(-) create mode 100644 oneflow/core/autograd/gradient_funcs/cdist.cpp diff --git a/oneflow/core/autograd/gradient_funcs/cdist.cpp b/oneflow/core/autograd/gradient_funcs/cdist.cpp new file mode 100644 index 00000000000..8dbd4b96c2d --- /dev/null +++ b/oneflow/core/autograd/gradient_funcs/cdist.cpp @@ -0,0 +1,95 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_expr_grad_function.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/functional/functional.h" + +namespace oneflow { +namespace one { + +namespace { + +struct CDistCaptureState : public AutoGradCaptureState { + bool requires_grad = false; + size_t x1_index = 0; + size_t x2_index = 0; + size_t out_index = 0; + double p = 0.0; +}; + +class CDistGrad : public OpExprGradFunction { + public: + virtual ~CDistGrad() = default; + + using OpExprGradFunction::Init; + + Maybe Init(const OpExpr& op) override; + Maybe Capture(CDistCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, + const AttrMap& attrs) const override; + Maybe Apply(const CDistCaptureState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const override; + + private: + AttrMap base_attrs_; +}; + +Maybe CDistGrad::Init(const OpExpr& op) { + const auto* fw_op_expr = dynamic_cast(&op); + CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg) + base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto()); + return Maybe::Ok(); +} + +Maybe CDistGrad::Capture(CDistCaptureState* ctx, const TensorTuple& inputs, + const TensorTuple& outputs, const AttrMap& attrs) const { + ctx->requires_grad = inputs.at(0)->requires_grad(); + if (!ctx->requires_grad) { return Maybe::Ok(); } + + ctx->x1_index = ctx->SaveTensorForBackward(inputs.at(0)); + ctx->x2_index = ctx->SaveTensorForBackward(inputs.at(1)); + ctx->out_index = ctx->SaveTensorForBackward(outputs.at(0)); + ComposedAttrMap composed_attrs(attrs, base_attrs_); + ctx->p = JUST(composed_attrs.GetAttr("p")); + + return Maybe::Ok(); +} + +Maybe CDistGrad::Apply(const CDistCaptureState* ctx, const TensorTuple& out_grads, + TensorTuple* in_grads) const { + if (!ctx->requires_grad) { return Maybe::Ok(); } + CHECK_LE_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg) + + const auto& x1 = ctx->SavedTensors().at(ctx->x1_index); + const auto& x2 = ctx->SavedTensors().at(ctx->x2_index); + const auto& out = ctx->SavedTensors().at(ctx->out_index); + const double p = ctx->p; + + in_grads->resize(2); + (*in_grads)[0] = JUST(functional::CDistGrad(x1, x2, out, out_grads.at(0), p))->at(0); + (*in_grads)[1] = JUST(functional::CDistGrad(x1, x2, out, out_grads.at(0), p))->at(1); + return Maybe::Ok(); +} + +} // namespace + +REGISTER_OP_EXPR_GRAD_FUNCTION("cdist", CDistGrad); + +} // namespace one +} // namespace oneflow diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 0c026eb504d..52ec38d9167 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -2319,6 +2319,10 @@ signature: 'Tensor (Tensor x1, Tensor x2, Double p=2.0, String compute_mode="use_mm_for_euclid_dist_if_necessary") => CDist' bind_python: True +- name: "cdist_grad" + signature: "TensorTuple (Tensor x1, Tensor x2, Tensor out, Tensor dy, Double p=2.0) => CDistGrad" + bind_python: True + - name: "normalize" signature: "Tensor (Tensor input, Float p=2.0, Int32 dim=1, Float eps=1e-12, Bool use_l2_norm_kernel=True) => Normalize" bind_python: True diff --git a/oneflow/core/functional/impl/nn_grad_functor.cpp b/oneflow/core/functional/impl/nn_grad_functor.cpp index 9e0e2ad644f..fa8b87712b9 100644 --- a/oneflow/core/functional/impl/nn_grad_functor.cpp +++ b/oneflow/core/functional/impl/nn_grad_functor.cpp @@ -641,6 +641,31 @@ class BinaryCrossEntropyWithLogitsReduceMeanLossTargetGradFunctor { } }; +class CDistGradFunctor { + public: + CDistGradFunctor() { + op_ = CHECK_JUST(one::OpBuilder("cdist_grad") + .Input("x1") + .Input("x2") + .Input("out") + .Input("dy") + .Output("dx1") + .Output("dx2") + .Build()); + } + Maybe operator()(const std::shared_ptr& x1, + const std::shared_ptr& x2, + const std::shared_ptr& out, + const std::shared_ptr& dy, const double& p) const { + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p"); + attrs.SetAllAttrs(p); + return OpInterpUtil::Dispatch(*op_, {x1, x2, out, dy}, attrs); + } + + private: + std::shared_ptr op_; +}; + class CombinedMarginLossGradFunctor { public: CombinedMarginLossGradFunctor() { @@ -1529,6 +1554,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("SparseSoftmaxCrossEntropyGrad"); m.add_functor("SparseSoftmaxCrossEntropyMsGrad"); m.add_functor("SmoothL1LossGrad"); + m.add_functor("CDistGrad"); m.add_functor("CombinedMarginLossGrad"); m.add_functor("AffineGridGrad"); m.add_functor("GridSampleGrad"); diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index 2a46574d2c0..68871647022 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -3941,6 +3941,26 @@ def OneFlow_CdistOp : OneFlow_BaseOp<"cdist", [NoSideEffect, DeclareOpInterfaceM let has_data_type_infer_fn = 1; } +def OneFlow_CdistGradOp : OneFlow_BaseOp<"cdist_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + OneFlow_Tensor:$x1, + OneFlow_Tensor:$x2, + OneFlow_Tensor:$out, + OneFlow_Tensor:$dy + ); + let output = (outs + OneFlow_Tensor:$dx1, + OneFlow_Tensor:$dx2 + ); + let attrs = (ins + DefaultValuedAttr:$p + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + def OneFlow_ErfOp : OneFlow_BaseOp<"erf", [NoSideEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x diff --git a/oneflow/user/kernels/cdist_kernel.cpp b/oneflow/user/kernels/cdist_kernel.cpp index ec2fd6a357e..a5ba563330a 100644 --- a/oneflow/user/kernels/cdist_kernel.cpp +++ b/oneflow/user/kernels/cdist_kernel.cpp @@ -30,6 +30,7 @@ struct ZeroDist { static inline T map(const T& diff, const T& p) { return diff == T(0) ? diff : T(1); } static inline T reduce(const T& agg, const T& up) { return agg + up; } static inline T finish(const T agg, const T p) { return agg; } + // backward always return 0 }; template @@ -38,7 +39,7 @@ struct OneDist { static inline T reduce(const T& agg, const T& up) { return agg + up; } static inline T finish(const T agg, const T p) { return agg; } static inline T backward(const T& diff, const T grad, const T dist, const T& p) { - return T(grad) * diff > 0 ? T(1) : T(-1); + return grad * (diff > T(0) ? T(1) : T(-1)); } }; @@ -48,7 +49,7 @@ struct TwoDist { static inline T reduce(const T& agg, const T& up) { return agg + up; } static inline T finish(const T agg, const T p) { return std::sqrt(agg); } static inline T backward(const T& diff, const T grad, const T dist, const T& p) { - return dist == 0.0 ? T(0) : T(grad) * diff / T(dist); + return dist == 0.0 ? T(0) : grad * diff / dist; } }; @@ -58,7 +59,8 @@ struct InfiDist { static inline T reduce(const T& agg, const T& up) { return std::max(agg, up); } static inline T finish(const T agg, const T p) { return agg; } static inline T backward(const T& diff, const T grad, const T dist, const T& p) { - return dist == 0.0 ? T(0) : T(grad) * diff / T(dist); + return (T(1) - std::min(std::ceil(std::abs(std::abs(diff) - dist)), T(1))) * grad + * (diff > T(0) ? T(1) : T(-1)); } }; @@ -67,18 +69,25 @@ struct PDist { static inline T map(const T& diff, const T& p) { return std::pow(diff, p); } static inline T reduce(const T& agg, const T& up) { return agg + up; } static inline T finish(const T agg, const T p) { return std::pow(agg, 1.0 / p); } + static inline T backward(const T& diff, const T grad, const T dist, const T& p) { + if (dist == 0.0) { + return T(0); + } else { + return diff * std::pow(std::abs(diff), p - T(2)) * grad / std::pow(dist, p - T(1)); + } + } }; template void CpuCdistForward(ep::CpuStream* stream, const T* x1, const T* x2, T* out, int64_t size_out, - int64_t d, int64_t r1, int64_t r2, int64_t c, double p) { + int64_t r1, int64_t r2, int64_t c, double p) { // x1 shape: (d1, d2, ..., dn, r1, c), treated as (d1 * ... * dn, r1 * c) // x2 shape: (d1, d2, ..., dn, r2, c), treated as (d1 * ... * dn, r2 * c) // out shape: (d1, d2, ..., dn, r1, r2), treated as (d1 * ... * dn, r1 * r2) // d = d1 * ... * dn stream->ParallelFor( 0, size_out, - [x1, x2, out, d, r1, r2, c, p](int64_t begin, int64_t end) { + [x1, x2, out, r1, r2, c, p](int64_t begin, int64_t end) { // begin is a multiple of c T* out_begin = out + begin; const T* out_end = out + end; @@ -91,7 +100,7 @@ void CpuCdistForward(ep::CpuStream* stream, const T* x1, const T* x2, T* out, in int64_t vec1_begin = vec1_idx * c; int64_t vec2_begin = vec2_idx * c; int64_t size1 = r1 * c; - int64_t size2 = r1 * c; + int64_t size2 = r2 * c; while (out_begin != out_end) { T agg = 0; @@ -118,6 +127,56 @@ void CpuCdistForward(ep::CpuStream* stream, const T* x1, const T* x2, T* out, in c); } +template +void CpuCdistBackward(ep::CpuStream* stream, const T* x1, const T* x2, const T* dist, const T* grad, + T* grad1, T* grad2, int64_t size_out, int64_t r1, int64_t r2, int64_t c, + double p) { + stream->ParallelFor( + 0, size_out, + [=](int64_t begin, int64_t end) { + const T* dist_begin = dist + begin; + const T* dist_end = dist + end; + const T* dist_grad = grad + begin; + + int64_t d = r1 * r2; + int64_t batch_idx = begin / d; + int64_t vec_out_idx = begin - d * batch_idx; + int64_t vec1_idx = (vec_out_idx / r2); + int64_t vec2_idx = vec_out_idx - vec1_idx * r2; + int64_t vec1_begin = vec1_idx * c; + int64_t vec2_begin = vec2_idx * c; + int64_t size1 = r1 * c; + int64_t size2 = r2 * c; + + while (dist_begin != dist_end) { + const T* x1_begin = x1 + batch_idx * size1 + vec1_begin; + const T* x2_begin = x2 + batch_idx * size2 + vec2_begin; + T* x1_grad_begin = grad1 + batch_idx * size1 + vec1_begin; + T* x2_grad_begin = grad2 + batch_idx * size2 + vec2_begin; + FOR_RANGE(int32_t, idx, 0, c) { + T a = *(x1_begin + idx); + T b = *(x2_begin + idx); + T diff = a - b; + *(x1_grad_begin + idx) += Dist::backward(diff, *dist_grad, *dist_begin, p); + *(x2_grad_begin + idx) += Dist::backward(-diff, *dist_grad, *dist_begin, p); + } + + dist_begin += 1; + dist_grad += 1; + vec2_begin += c; + if (vec2_begin == r2 * c) { + vec2_begin = 0; + vec1_begin += c; + if (vec1_begin == r1 * c) { + vec1_begin = 0; + batch_idx += 1; + } + } + } + }, + c); +} + template class CpuCdistKernel final : public user_op::OpKernel { public: @@ -134,7 +193,6 @@ class CpuCdistKernel final : public user_op::OpKernel { int64_t r1 = x1->shape_view().At(ndim - 2); int64_t r2 = x2->shape_view().At(ndim - 2); int64_t c = x1->shape_view().At(ndim - 1); - int64_t d = x1->shape_view().Count(0, ndim - 2); const T* x1_ptr = x1->dptr(); const T* x2_ptr = x2->dptr(); @@ -142,19 +200,74 @@ class CpuCdistKernel final : public user_op::OpKernel { if (p == 0) { CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, - out->shape_view().elem_cnt(), d, r1, r2, c, p); + out->shape_view().elem_cnt(), r1, r2, c, p); } else if (p == 1) { CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, - out->shape_view().elem_cnt(), d, r1, r2, c, p); + out->shape_view().elem_cnt(), r1, r2, c, p); } else if (p == 2) { CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, - out->shape_view().elem_cnt(), d, r1, r2, c, p); + out->shape_view().elem_cnt(), r1, r2, c, p); } else if (std::isinf(p)) { CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, - out->shape_view().elem_cnt(), d, r1, r2, c, p); + out->shape_view().elem_cnt(), r1, r2, c, p); } else { CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, - out->shape_view().elem_cnt(), d, r1, r2, c, p); + out->shape_view().elem_cnt(), r1, r2, c, p); + }; + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template +class CpuCdistGradKernel final : public user_op::OpKernel { + public: + CpuCdistGradKernel() = default; + ~CpuCdistGradKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("x1", 0); + const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("x2", 0); + const user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + user_op::Tensor* dx1 = ctx->Tensor4ArgNameAndIndex("dx1", 0); + user_op::Tensor* dx2 = ctx->Tensor4ArgNameAndIndex("dx2", 0); + double p = ctx->Attr("p"); + int64_t ndim = x1->shape_view().NumAxes(); + int64_t r1 = x1->shape_view().At(ndim - 2); + int64_t r2 = x2->shape_view().At(ndim - 2); + int64_t c = x1->shape_view().At(ndim - 1); + + const T* x1_ptr = x1->dptr(); + const T* x2_ptr = x2->dptr(); + const T* dist_ptr = out->dptr(); + const T* grad_ptr = dy->dptr(); + + T* dx1_ptr = dx1->mut_dptr(); + T* dx2_ptr = dx2->mut_dptr(); + + if (p == 0) { + std::unique_ptr memset_primitive = + ep::primitive::NewPrimitive(ctx->device_type()); + CHECK(memset_primitive); + memset_primitive->Launch(ctx->stream(), dx1_ptr, 0, dx1->shape_view().elem_cnt() * sizeof(T)); + memset_primitive->Launch(ctx->stream(), dx2_ptr, 0, dx2->shape_view().elem_cnt() * sizeof(T)); + } else if (p == 1) { + CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, + grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, + r2, c, p); + } else if (p == 2) { + CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, + grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, + r2, c, p); + } else if (std::isinf(p)) { + CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, + grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, + r2, c, p); + } else { + CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, + grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, + r2, c, p); }; } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } @@ -171,4 +284,16 @@ REGISTER_CPU_CDIST_KERNEL(float) REGISTER_CPU_CDIST_KERNEL(double) #undef REGISTER_CPU_CDIST_KERNEL +#define REGISTER_CPU_CDIST_GRAD_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cdist_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ + && (user_op::HobDataType("x1", 0) == GetDataType::value) \ + && (user_op::HobDataType("x2", 0) == GetDataType::value) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)); + +REGISTER_CPU_CDIST_GRAD_KERNEL(float) +REGISTER_CPU_CDIST_GRAD_KERNEL(double) +#undef REGISTER_CPU_CDIST_KERNEL + } // namespace oneflow diff --git a/oneflow/user/ops/cdist_op.cpp b/oneflow/user/ops/cdist_op.cpp index 87709b63fe4..2663c9eaa7f 100644 --- a/oneflow/user/ops/cdist_op.cpp +++ b/oneflow/user/ops/cdist_op.cpp @@ -21,7 +21,7 @@ namespace oneflow { namespace { -Maybe InferTensorDesc(user_op::InferContext* ctx) { +Maybe FwdInferTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x1_desc = ctx->InputTensorDesc("x1", 0); const user_op::TensorDesc& x2_desc = ctx->InputTensorDesc("x2", 0); user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("out", 0); @@ -38,7 +38,7 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { } // namespace /* static */ Maybe CdistOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { - return InferTensorDesc(ctx); + return FwdInferTensorDesc(ctx); } /*static*/ Maybe CdistOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { @@ -56,4 +56,42 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } +namespace { + +Maybe BwdInferTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x1_desc = ctx->InputTensorDesc("x1", 0); + const user_op::TensorDesc& x2_desc = ctx->InputTensorDesc("x2", 0); + user_op::TensorDesc* dx1_desc = ctx->MutOutputTensorDesc("dx1", 0); + user_op::TensorDesc* dx2_desc = ctx->MutOutputTensorDesc("dx2", 0); + + dx1_desc->set_shape(x1_desc.shape()); + dx2_desc->set_shape(x2_desc.shape()); + + return Maybe::Ok(); +} + +} // namespace + +/* static */ Maybe CdistGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return BwdInferTensorDesc(ctx); +} + +/*static*/ Maybe CdistGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CdistGradOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe CdistGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x1_desc = ctx->InputTensorDesc("x1", 0); + user_op::TensorDesc* dx1_desc = ctx->MutOutputTensorDesc("dx1", 0); + user_op::TensorDesc* dx2_desc = ctx->MutOutputTensorDesc("dx2", 0); + dx1_desc->set_data_type(x1_desc.data_type()); + dx2_desc->set_data_type(x1_desc.data_type()); + + return Maybe::Ok(); +} + } // namespace oneflow diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 25f5eec4bc7..92755c07d5f 100644 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -223,6 +223,7 @@ def is_deprecated(func_or_class): from oneflow._C import multinomial from oneflow._C import linalg_cross as cross from oneflow._C import bincount +from oneflow._C import cdist from oneflow._oneflow_internal import _set_num_threads as set_num_threads from . import sbp From e1047e14b6429ccdd94811d60fb7f952a5075f08 Mon Sep 17 00:00:00 2001 From: WangYi Date: Thu, 10 Nov 2022 10:41:35 +0800 Subject: [PATCH 04/26] fix bug of memory init, and grad func call, add eager unittest --- .../core/autograd/gradient_funcs/cdist.cpp | 5 +- oneflow/user/kernels/cdist_kernel.cpp | 12 ++-- python/oneflow/test/modules/test_cdist.py | 62 +++++++++++++++++++ 3 files changed, 72 insertions(+), 7 deletions(-) create mode 100644 python/oneflow/test/modules/test_cdist.py diff --git a/oneflow/core/autograd/gradient_funcs/cdist.cpp b/oneflow/core/autograd/gradient_funcs/cdist.cpp index 8dbd4b96c2d..20a8063bec0 100644 --- a/oneflow/core/autograd/gradient_funcs/cdist.cpp +++ b/oneflow/core/autograd/gradient_funcs/cdist.cpp @@ -82,8 +82,9 @@ Maybe CDistGrad::Apply(const CDistCaptureState* ctx, const TensorTuple& ou const double p = ctx->p; in_grads->resize(2); - (*in_grads)[0] = JUST(functional::CDistGrad(x1, x2, out, out_grads.at(0), p))->at(0); - (*in_grads)[1] = JUST(functional::CDistGrad(x1, x2, out, out_grads.at(0), p))->at(1); + auto results = JUST(functional::CDistGrad(x1, x2, out, out_grads.at(0), p)); + (*in_grads)[0] = results->at(0); + (*in_grads)[1] = results->at(1); return Maybe::Ok(); } diff --git a/oneflow/user/kernels/cdist_kernel.cpp b/oneflow/user/kernels/cdist_kernel.cpp index a5ba563330a..cb42ddfa428 100644 --- a/oneflow/user/kernels/cdist_kernel.cpp +++ b/oneflow/user/kernels/cdist_kernel.cpp @@ -246,12 +246,14 @@ class CpuCdistGradKernel final : public user_op::OpKernel { T* dx1_ptr = dx1->mut_dptr(); T* dx2_ptr = dx2->mut_dptr(); + std::unique_ptr memset_primitive = + ep::primitive::NewPrimitive(ctx->device_type()); + CHECK(memset_primitive); + memset_primitive->Launch(ctx->stream(), dx1_ptr, 0, dx1->shape_view().elem_cnt() * sizeof(T)); + memset_primitive->Launch(ctx->stream(), dx2_ptr, 0, dx2->shape_view().elem_cnt() * sizeof(T)); + if (p == 0) { - std::unique_ptr memset_primitive = - ep::primitive::NewPrimitive(ctx->device_type()); - CHECK(memset_primitive); - memset_primitive->Launch(ctx->stream(), dx1_ptr, 0, dx1->shape_view().elem_cnt() * sizeof(T)); - memset_primitive->Launch(ctx->stream(), dx2_ptr, 0, dx2->shape_view().elem_cnt() * sizeof(T)); + // grad is always zero } else if (p == 1) { CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py new file mode 100644 index 00000000000..178467a9afa --- /dev/null +++ b/python/oneflow/test/modules/test_cdist.py @@ -0,0 +1,62 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +from collections import OrderedDict + +import numpy as np + +import oneflow as flow +import oneflow.unittest + +from oneflow.test_utils.automated_test_util import * + +@flow.unittest.skip_unless_1n1d() +class TestCDist(flow.unittest.TestCase): + @autotest(n=2, check_graph=True) + def test_zero_cdist(test_case): + dim0 = random() + dim2 = random() + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + return torch.cdist(x1, x2, p=0) + + @autotest(n=2, check_graph=True) + def test_one_cdist(test_case): + dim0 = random() + dim2 = random() + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + return torch.cdist(x1, x2, p=1) + + @autotest(n=2, check_graph=True) + def test_two_cdist(test_case): + dim0 = random() + dim2 = random() + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + return torch.cdist(x1, x2, p=2) + + @autotest(n=2, check_graph=True) + def test_infi_cdist(test_case): + dim0 = random() + dim2 = random() + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + return torch.cdist(x1, x2, p=float("inf")) + +if __name__ == "__main__": + unittest.main() From b846c6920994130322033ddfbf0f6c0b2ec2d592 Mon Sep 17 00:00:00 2001 From: WangYi Date: Thu, 10 Nov 2022 14:28:38 +0800 Subject: [PATCH 05/26] add docs --- docs/source/oneflow.rst | 1 + python/oneflow/framework/docstr/distance.py | 47 +++++++++++++++++++++ python/oneflow/test/modules/test_cdist.py | 13 ++++-- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst index 479bb686bc8..7b6b70bda3e 100644 --- a/docs/source/oneflow.rst +++ b/docs/source/oneflow.rst @@ -370,6 +370,7 @@ Other Ops adaptive_avg_pool3d broadcast_like cast + cdist cumprod cumsum decode_onerec diff --git a/python/oneflow/framework/docstr/distance.py b/python/oneflow/framework/docstr/distance.py index cdc259ad2d1..65336a888f7 100644 --- a/python/oneflow/framework/docstr/distance.py +++ b/python/oneflow/framework/docstr/distance.py @@ -84,3 +84,50 @@ """, ) + +add_docstr( + oneflow._C.cdist, + r"""Computes batched the p-norm distance between each pair of the two collections of row vectors. + + The interface is consistent with PyTorch. + The documentation is referenced from: https://pytorch.org/docs/1.10/generated/torch.cdist.html. + + Args: + x1 (Tensor): input tensor of shape :math:`B \times P \times M`. + x2 (Tensor): input tensor of shape :math:`B \times R \times M`. + p: p value for the p-norm distance to calculate between each vector pair + :math:`\in [0, \infty]`. + compute_mode: + 'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate + euclidean distance (p = 2) if P > 25 or R > 25 + 'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate + euclidean distance (p = 2) + 'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate + euclidean distance (p = 2) + Default: use_mm_for_euclid_dist_if_necessary. + + If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the + output will have shape :math:`B \times P \times R`. + + This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)` + if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to + `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest + scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`. + + For example: + + .. code-block:: python + + >>> import oneflow as flow + >>> x = flow.Tensor([[1., 2], [3, 4]]) + >>> y = flow.Tensor([[5., 6], [7, 8]]) + >>> flow.cdist(x, y) + tensor([[5.6569, 8.4853], + [2.8284, 5.6569]], dtype=oneflow.float32) + >>> flow.cdist(x, y, p=1) + tensor([[ 8., 12.], + [ 4., 8.]], dtype=oneflow.float32) + + """ + +) \ No newline at end of file diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index 178467a9afa..3949ff33fd6 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -15,13 +15,9 @@ """ import unittest -from collections import OrderedDict - -import numpy as np import oneflow as flow import oneflow.unittest - from oneflow.test_utils.automated_test_util import * @flow.unittest.skip_unless_1n1d() @@ -58,5 +54,14 @@ def test_infi_cdist(test_case): x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) return torch.cdist(x1, x2, p=float("inf")) + @autotest(n=2, check_graph=True) + def test_random_p_cdist(test_case): + dim0 = random() + dim2 = random() + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + p = random(0, 4).to(float) + return torch.cdist(x1, x2, p=p) + if __name__ == "__main__": unittest.main() From d7b8b625196faa0924f39e0b641987cd3dbb942e Mon Sep 17 00:00:00 2001 From: WangYi Date: Thu, 10 Nov 2022 14:53:18 +0800 Subject: [PATCH 06/26] refine code --- python/oneflow/framework/docstr/distance.py | 5 ++--- python/oneflow/test/modules/test_cdist.py | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/oneflow/framework/docstr/distance.py b/python/oneflow/framework/docstr/distance.py index 65336a888f7..02e0969dbff 100644 --- a/python/oneflow/framework/docstr/distance.py +++ b/python/oneflow/framework/docstr/distance.py @@ -128,6 +128,5 @@ tensor([[ 8., 12.], [ 4., 8.]], dtype=oneflow.float32) - """ - -) \ No newline at end of file + """, +) diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index 3949ff33fd6..9f9162c3b77 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -20,6 +20,7 @@ import oneflow.unittest from oneflow.test_utils.automated_test_util import * + @flow.unittest.skip_unless_1n1d() class TestCDist(flow.unittest.TestCase): @autotest(n=2, check_graph=True) @@ -54,14 +55,15 @@ def test_infi_cdist(test_case): x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) return torch.cdist(x1, x2, p=float("inf")) - @autotest(n=2, check_graph=True) + @autotest(n=5, check_graph=True) def test_random_p_cdist(test_case): dim0 = random() dim2 = random() x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - p = random(0, 4).to(float) + p = random(0, 4).to(float) return torch.cdist(x1, x2, p=p) + if __name__ == "__main__": unittest.main() From 4a792bb19918263dc23923cfc10a490b7b5a2b32 Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 15 Nov 2022 14:40:07 +0800 Subject: [PATCH 07/26] add compute_mode unittest --- oneflow/core/functional/impl/nn_functor.cpp | 15 +++++++++ python/oneflow/test/modules/test_cdist.py | 36 ++++++++++++++++++--- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 840d3f68519..4928adf95d3 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -2878,6 +2878,17 @@ class CdistFunctor { CdistFunctor() { op_ = CHECK_JUST(OpBuilder("cdist").Input("x1").Input("x2").Output("out").Build()); } + Maybe euclidean_dist(const std::shared_ptr& x1, const std::shared_ptr& x2) const { + const auto& x1_norm = JUST(ReduceSum(JUST(ScalarPow(x1, 2, false)), {-1}, true)); + const auto& x2_norm = JUST(ReduceSum(JUST(ScalarPow(x2, 2, false)), {-1}, true)); + const auto& x1_ones = JUST(OnesLike(x1_norm)); + const auto& x2_ones = JUST(OnesLike(x2_norm)); + const auto& x1_cat = JUST(Concat({JUST(ScalarMul(x1, -2, false)), x1_norm, x1_ones}, -1)); + const auto& x2_cat = JUST(Concat({x2, x2_ones, x2_norm}, -1)); + const auto& result = JUST(MatMul(x1_cat, JUST(Transpose2dim(x2_cat, -1, -2)), false, false, 1.0)); + return Sqrt(JUST(ClampMin(result, 0))); + }; + Maybe operator()(const std::shared_ptr& x1, const std::shared_ptr& x2, const double& p, const std::string& compute_mode) const { const int64_t x1_ndim = x1->ndim(); @@ -2936,6 +2947,10 @@ class CdistFunctor { const auto x1_expand = JUST(Expand(x1, x1_expand_shape)); const auto x2_expand = JUST(Expand(x2, x2_expand_shape)); + if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) { + return euclidean_dist(x1_expand, x2_expand); + } + TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true) .AddInputs({x1_expand, x2_expand}) diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index 9f9162c3b77..0c92c7891cb 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -15,6 +15,7 @@ """ import unittest +import random as random_utils import oneflow as flow import oneflow.unittest @@ -27,42 +28,67 @@ class TestCDist(flow.unittest.TestCase): def test_zero_cdist(test_case): dim0 = random() dim2 = random() + mode = random_utils.choice([ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ]) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - return torch.cdist(x1, x2, p=0) + return torch.cdist(x1, x2, p=0, compute_mode=mode) @autotest(n=2, check_graph=True) def test_one_cdist(test_case): dim0 = random() dim2 = random() + mode = random_utils.choice([ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ]) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - return torch.cdist(x1, x2, p=1) + return torch.cdist(x1, x2, p=1, compute_mode=mode) @autotest(n=2, check_graph=True) def test_two_cdist(test_case): dim0 = random() dim2 = random() + mode = random_utils.choice([ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ]) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - return torch.cdist(x1, x2, p=2) + return torch.cdist(x1, x2, p=2, compute_mode=mode) @autotest(n=2, check_graph=True) def test_infi_cdist(test_case): dim0 = random() dim2 = random() + mode = random_utils.choice([ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ]) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - return torch.cdist(x1, x2, p=float("inf")) + return torch.cdist(x1, x2, p=float("inf"), compute_mode=mode) @autotest(n=5, check_graph=True) def test_random_p_cdist(test_case): dim0 = random() dim2 = random() + mode = random_utils.choice([ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ]) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) p = random(0, 4).to(float) - return torch.cdist(x1, x2, p=p) + return torch.cdist(x1, x2, p=p, compute_mode=mode) if __name__ == "__main__": From 015c0f2954ec01f4ec01b2c2ae9cff1ae9d927da Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 15 Nov 2022 14:42:45 +0800 Subject: [PATCH 08/26] format code --- oneflow/core/functional/impl/nn_functor.cpp | 7 +-- python/oneflow/test/modules/test_cdist.py | 60 ++++++++++++--------- 2 files changed, 39 insertions(+), 28 deletions(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 4928adf95d3..981a1a6a342 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -2878,14 +2878,16 @@ class CdistFunctor { CdistFunctor() { op_ = CHECK_JUST(OpBuilder("cdist").Input("x1").Input("x2").Output("out").Build()); } - Maybe euclidean_dist(const std::shared_ptr& x1, const std::shared_ptr& x2) const { + Maybe euclidean_dist(const std::shared_ptr& x1, + const std::shared_ptr& x2) const { const auto& x1_norm = JUST(ReduceSum(JUST(ScalarPow(x1, 2, false)), {-1}, true)); const auto& x2_norm = JUST(ReduceSum(JUST(ScalarPow(x2, 2, false)), {-1}, true)); const auto& x1_ones = JUST(OnesLike(x1_norm)); const auto& x2_ones = JUST(OnesLike(x2_norm)); const auto& x1_cat = JUST(Concat({JUST(ScalarMul(x1, -2, false)), x1_norm, x1_ones}, -1)); const auto& x2_cat = JUST(Concat({x2, x2_ones, x2_norm}, -1)); - const auto& result = JUST(MatMul(x1_cat, JUST(Transpose2dim(x2_cat, -1, -2)), false, false, 1.0)); + const auto& result = + JUST(MatMul(x1_cat, JUST(Transpose2dim(x2_cat, -1, -2)), false, false, 1.0)); return Sqrt(JUST(ClampMin(result, 0))); }; @@ -2936,7 +2938,6 @@ class CdistFunctor { max_batch_shape.Set(i, std::max(size_x, size_y)); } } - // auto max_batch_shape = JUST(InferShapeForDistance(x1_batch_shape, x2_batch_shape)); Shape x1_expand_shape(max_batch_shape); Shape x2_expand_shape(max_batch_shape); x1_expand_shape.emplace_back(r1); diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index 0c92c7891cb..53b5ea3a16e 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -28,11 +28,13 @@ class TestCDist(flow.unittest.TestCase): def test_zero_cdist(test_case): dim0 = random() dim2 = random() - mode = random_utils.choice([ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ]) + mode = random_utils.choice( + [ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ] + ) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) return torch.cdist(x1, x2, p=0, compute_mode=mode) @@ -41,11 +43,13 @@ def test_zero_cdist(test_case): def test_one_cdist(test_case): dim0 = random() dim2 = random() - mode = random_utils.choice([ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ]) + mode = random_utils.choice( + [ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ] + ) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) return torch.cdist(x1, x2, p=1, compute_mode=mode) @@ -54,11 +58,13 @@ def test_one_cdist(test_case): def test_two_cdist(test_case): dim0 = random() dim2 = random() - mode = random_utils.choice([ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ]) + mode = random_utils.choice( + [ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ] + ) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) return torch.cdist(x1, x2, p=2, compute_mode=mode) @@ -67,11 +73,13 @@ def test_two_cdist(test_case): def test_infi_cdist(test_case): dim0 = random() dim2 = random() - mode = random_utils.choice([ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ]) + mode = random_utils.choice( + [ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ] + ) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) return torch.cdist(x1, x2, p=float("inf"), compute_mode=mode) @@ -80,11 +88,13 @@ def test_infi_cdist(test_case): def test_random_p_cdist(test_case): dim0 = random() dim2 = random() - mode = random_utils.choice([ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ]) + mode = random_utils.choice( + [ + "use_mm_for_euclid_dist_if_necessary", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", + ] + ) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) p = random(0, 4).to(float) From 8f30598d8f2ca03297b47f41cdffee0b4f29e311 Mon Sep 17 00:00:00 2001 From: WangYi Date: Thu, 17 Nov 2022 11:24:48 +0800 Subject: [PATCH 09/26] add cuda kernel --- oneflow/user/kernels/cdist_kernel.cu | 310 ++++++++++++++++++++++ python/oneflow/test/modules/test_cdist.py | 25 +- 2 files changed, 325 insertions(+), 10 deletions(-) create mode 100644 oneflow/user/kernels/cdist_kernel.cu diff --git a/oneflow/user/kernels/cdist_kernel.cu b/oneflow/user/kernels/cdist_kernel.cu new file mode 100644 index 00000000000..eab87856487 --- /dev/null +++ b/oneflow/user/kernels/cdist_kernel.cu @@ -0,0 +1,310 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include +#include "oneflow/core/ep/include/primitive/memset.h" +#include "oneflow/core/ep/include/stream.h" +#include "oneflow/core/framework/user_op_hob.h" +#include "oneflow/core/ndarray/xpu_util.h" + +namespace oneflow { + +template +static __forceinline__ __device__ T sign(T val) { + return (0 < val) - (val < 0); +} + +template +static __forceinline__ __device__ T device_sqrt(T val); + +template<> +__forceinline__ __device__ float device_sqrt(float val) { + return ::sqrtf(val); +} + +template<> +__forceinline__ __device__ double device_sqrt(double val) { + return ::sqrt(val); +} + +// Zero norm +template +struct ZeroDist { + static __forceinline__ __device__ void inc(T& agg, const T diff, const double /*p*/) { + agg += diff != 0.0; + } + static __forceinline__ __device__ T finish(const T agg, const double /*p*/) { return agg; } + static __forceinline__ __device__ void agg(T& update, const T other) { update += other; } +}; + +// One norm +template +struct OneDist { + static __forceinline__ __device__ void inc(T& agg, const T diff, const double /*p*/) { + agg += diff; + } + static __forceinline__ __device__ T finish(const T agg, const double /*p*/) { return agg; } + static __forceinline__ __device__ void agg(T& update, const T other) { update += other; } + static __forceinline__ __device__ T backward(const T diff, const T grad, const T /*dist*/, + const double /*p*/) { + return grad * sign(diff); + } +}; + +// Two norm +template +struct TwoDist { + static __forceinline__ __device__ void inc(T& agg, const T diff, const double /*p*/) { + agg += diff * diff; + } + static __forceinline__ __device__ T finish(const T agg, const double /*p*/) { + return device_sqrt(agg); + } + static __forceinline__ __device__ void agg(T& update, const T other) { update += other; } + static __forceinline__ __device__ T backward(const T diff, const T grad, const T dist, + const double /*p*/) { + return dist == 0.0 ? 0 : grad * diff / dist; + } +}; + +// General p norm +template +struct PDist { + static __forceinline__ __device__ void inc(T& agg, const T diff, const double p) { + agg += std::pow(diff, p); + } + static __forceinline__ __device__ T finish(const T agg, const double p) { + return std::pow(agg, static_cast(1) / p); + } + static __forceinline__ __device__ void agg(T& update, const T other) { update += other; } + static __forceinline__ __device__ T backward(const T diff, const T grad, const T dist, + const double p) { + return dist == 0.0 ? 0 : diff * std::pow(std::abs(diff), p - 2) * grad / std::pow(dist, p - 1); + } +}; + +// Inf norm +template +struct InfiDist { + static __forceinline__ __device__ void inc(T& agg, const T diff, const double /*p*/) { + if (diff > agg) { agg = diff; } + } + static __forceinline__ __device__ T finish(const T agg, const double /*p*/) { return agg; } + static __forceinline__ __device__ void agg(T& update, const T other) { + if (other > update) { update = other; } + } + static __forceinline__ __device__ T backward(const T diff, const T grad, const T dist, + const double /*p*/) { + return grad * sign(diff) * (std::abs(diff) == dist); + } +}; + +template +struct DistReduce { + __forceinline__ __device__ T operator()(T a, T b) const { + Dist::agg(a, b); + return a; + } +}; + +template +__global__ static void CUDACDistForward(const T* x1, const T* x2, T* out, int64_t r1, int64_t r2, + int64_t c, int64_t r_size, int64_t r1_size, int64_t r2_size, + double p) { + const int64_t batch_idx = blockIdx.x / r_size; + const int64_t vec_out_idx = blockIdx.x - batch_idx * r_size; + const int64_t vec1_idx = vec_out_idx / r2; + const int64_t vec2_idx = vec_out_idx - vec1_idx * r2; + const int64_t stride = blockDim.x; + + const T* vec1_begin = x1 + batch_idx * r1_size + vec1_idx * c + threadIdx.x; + const T* vec1_end = x1 + batch_idx * r1_size + vec1_idx * c + c; + const T* vec2_begin = x2 + batch_idx * r2_size + vec2_idx * c + threadIdx.x; + + T agg = 0; + for (; vec1_begin < vec1_end; vec1_begin += stride, vec2_begin += stride) { + Dist::inc(agg, std::abs(*vec1_begin - *vec2_begin), p); + } + + __syncthreads(); + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + T result = BlockReduce(temp_storage).Reduce(agg, DistReduce()); + if (threadIdx.x == 0) { out[blockIdx.x] = Dist::finish(result, p); } +} + +template +__global__ static void CUDACDistBackward(const T* x1, const T* x2, const T* dist, const T* dist_grad, + T* grad1, T* grad2, int64_t r1, int64_t r2, int64_t c, + int64_t r_size, int64_t r1_size, int64_t r2_size, + double p) { + const int64_t batch_idx = blockIdx.x / r_size; + const int64_t vec_out_idx = blockIdx.x - batch_idx * r_size; + const int64_t vec1_idx = vec_out_idx / r2; + const int64_t vec2_idx = vec_out_idx - vec1_idx * r2; + const int64_t stride = blockDim.x; + + const T* vec1_begin = x1 + batch_idx * r1_size + vec1_idx * c + threadIdx.x; + const T* vec1_end = x1 + batch_idx * r1_size + vec1_idx * c + c; + const T* vec2_begin = x2 + batch_idx * r2_size + vec2_idx * c + threadIdx.x; + + T* grad1_begin = vec1_begin - x1 + grad1; + T* grad2_begin = vec2_begin - x2 + grad2; + T diff = *vec1_begin - *vec2_begin; + + atomicAdd(grad1_begin, Dist::backward(diff, *(dist_grad + blockIdx.x), *(dist + blockIdx.x), p)); + atomicAdd(grad2_begin, Dist::backward(-diff, *(dist_grad + blockIdx.x), *(dist + blockIdx.x), p)); +} + +template +class CUDACdistKernel final : public user_op::OpKernel { + public: + CUDACdistKernel() = default; + ~CUDACdistKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("x1", 0); + const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("x2", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + double p = ctx->Attr("p"); + int64_t ndim = x1->shape_view().NumAxes(); + int64_t r1 = x1->shape_view().At(ndim - 2); + int64_t r2 = x2->shape_view().At(ndim - 2); + int64_t c = x1->shape_view().At(ndim - 1); + + const int64_t r1_size = r1 * c; + const int64_t r2_size = r2 * c; + const int64_t r_size = r1 * r2; + + const T* x1_ptr = x1->dptr(); + const T* x2_ptr = x2->dptr(); + T* out_ptr = out->mut_dptr(); + + if (p == 0) { + CUDACDistForward><<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, out_ptr, r1, r2, c, r_size, r1_size, r2_size, p); + } else if (p == 1) { + CUDACDistForward><<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, out_ptr, r1, r2, c, r_size, r1_size, r2_size, p); + } else if (p == 2) { + CUDACDistForward><<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, out_ptr, r1, r2, c, r_size, r1_size, r2_size, p); + } else if (std::isinf(p)) { + CUDACDistForward><<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, out_ptr, r1, r2, c, r_size, r1_size, r2_size, p); + } else { + CUDACDistForward><<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, out_ptr, r1, r2, c, r_size, r1_size, r2_size, p); + } + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +template +class CUDACdistGradKernel final : public user_op::OpKernel { + public: + CUDACdistGradKernel() = default; + ~CUDACdistGradKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("x1", 0); + const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("x2", 0); + const user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0); + user_op::Tensor* dx1 = ctx->Tensor4ArgNameAndIndex("dx1", 0); + user_op::Tensor* dx2 = ctx->Tensor4ArgNameAndIndex("dx2", 0); + double p = ctx->Attr("p"); + int64_t ndim = x1->shape_view().NumAxes(); + int64_t r1 = x1->shape_view().At(ndim - 2); + int64_t r2 = x2->shape_view().At(ndim - 2); + int64_t c = x1->shape_view().At(ndim - 1); + + const T* x1_ptr = x1->dptr(); + const T* x2_ptr = x2->dptr(); + const T* dist_ptr = out->dptr(); + const T* grad_ptr = dy->dptr(); + + const int64_t r1_size = r1 * c; + const int64_t r2_size = r2 * c; + const int64_t r_size = r1 * r2; + + T* dx1_ptr = dx1->mut_dptr(); + T* dx2_ptr = dx2->mut_dptr(); + + std::unique_ptr memset_primitive = + ep::primitive::NewPrimitive(ctx->device_type()); + CHECK(memset_primitive); + memset_primitive->Launch(ctx->stream(), dx1_ptr, 0, dx1->shape_view().elem_cnt() * sizeof(T)); + memset_primitive->Launch(ctx->stream(), dx2_ptr, 0, dx2->shape_view().elem_cnt() * sizeof(T)); + + if (p == 0) { + // grad is always zero + } else if (p == 1) { + CUDACDistBackward><<shape_view().elem_cnt(), c, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, + p); + } else if (p == 2) { + CUDACDistBackward><<shape_view().elem_cnt(), c, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, + p); + } else if (std::isinf(p)) { + CUDACDistBackward><<shape_view().elem_cnt(), c, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, + p); + } else { + CUDACDistBackward><<shape_view().elem_cnt(), c, 0, + ctx->stream()->As()->cuda_stream()>>>( + x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, + p); + } + cudaDeviceSynchronize(); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CUDA_CDIST_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cdist").SetCreateFn>().SetIsMatchedHob( \ + (user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("x1", 0) == GetDataType::value) \ + && (user_op::HobDataType("x2", 0) == GetDataType::value) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)); + +REGISTER_CUDA_CDIST_KERNEL(float) +REGISTER_CUDA_CDIST_KERNEL(double) +#undef REGISTER_CUDA_CDIST_KERNEL + +#define REGISTER_CUDA_CDIST_GRAD_KERNEL(dtype) \ + REGISTER_USER_KERNEL("cdist_grad") \ + .SetCreateFn>() \ + .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ + && (user_op::HobDataType("x1", 0) == GetDataType::value) \ + && (user_op::HobDataType("x2", 0) == GetDataType::value) \ + && (user_op::HobDataType("out", 0) == GetDataType::value)); + +REGISTER_CUDA_CDIST_GRAD_KERNEL(float) +REGISTER_CUDA_CDIST_GRAD_KERNEL(double) +#undef REGISTER_CUDA_CDIST_KERNEL + +} // namespace oneflow diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index 53b5ea3a16e..dc7e9644105 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -26,6 +26,7 @@ class TestCDist(flow.unittest.TestCase): @autotest(n=2, check_graph=True) def test_zero_cdist(test_case): + device = random_device() dim0 = random() dim2 = random() mode = random_utils.choice( @@ -35,12 +36,13 @@ def test_zero_cdist(test_case): "donot_use_mm_for_euclid_dist", ] ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) return torch.cdist(x1, x2, p=0, compute_mode=mode) @autotest(n=2, check_graph=True) def test_one_cdist(test_case): + device = random_device() dim0 = random() dim2 = random() mode = random_utils.choice( @@ -50,12 +52,13 @@ def test_one_cdist(test_case): "donot_use_mm_for_euclid_dist", ] ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) return torch.cdist(x1, x2, p=1, compute_mode=mode) @autotest(n=2, check_graph=True) def test_two_cdist(test_case): + device = random_device() dim0 = random() dim2 = random() mode = random_utils.choice( @@ -65,12 +68,13 @@ def test_two_cdist(test_case): "donot_use_mm_for_euclid_dist", ] ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) return torch.cdist(x1, x2, p=2, compute_mode=mode) @autotest(n=2, check_graph=True) def test_infi_cdist(test_case): + device = random_device() dim0 = random() dim2 = random() mode = random_utils.choice( @@ -80,12 +84,13 @@ def test_infi_cdist(test_case): "donot_use_mm_for_euclid_dist", ] ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) return torch.cdist(x1, x2, p=float("inf"), compute_mode=mode) @autotest(n=5, check_graph=True) def test_random_p_cdist(test_case): + device = random_device() dim0 = random() dim2 = random() mode = random_utils.choice( @@ -95,8 +100,8 @@ def test_random_p_cdist(test_case): "donot_use_mm_for_euclid_dist", ] ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2) + x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) + x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) p = random(0, 4).to(float) return torch.cdist(x1, x2, p=p, compute_mode=mode) From 130e28e4a33a8c175347381693485fb1d7f25147 Mon Sep 17 00:00:00 2001 From: WangYi Date: Fri, 18 Nov 2022 17:39:09 +0800 Subject: [PATCH 10/26] add cuda backward kernel, refine code (Cdist=>CDist), refine unittest, remove use_mm option --- oneflow/core/functional/impl/nn_functor.cpp | 19 +++-- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 8 +- oneflow/user/kernels/cdist_kernel.cpp | 38 +++++----- oneflow/user/kernels/cdist_kernel.cu | 61 ++++++++++----- oneflow/user/ops/cdist_op.cpp | 16 ++-- python/oneflow/test/modules/test_cdist.py | 74 ++----------------- .../test/modules/test_large_size_tensor.py | 39 ++++++++++ 7 files changed, 131 insertions(+), 124 deletions(-) create mode 100644 python/oneflow/test/modules/test_large_size_tensor.py diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 2c22ca89914..552aa45fd23 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -2924,9 +2924,10 @@ class CosineSimilarityFunctor { } }; -class CdistFunctor { + +class CDistFunctor { public: - CdistFunctor() { + CDistFunctor() { op_ = CHECK_JUST(OpBuilder("cdist").Input("x1").Input("x2").Output("out").Build()); } Maybe euclidean_dist(const std::shared_ptr& x1, @@ -2939,7 +2940,7 @@ class CdistFunctor { const auto& x2_cat = JUST(Concat({x2, x2_ones, x2_norm}, -1)); const auto& result = JUST(MatMul(x1_cat, JUST(Transpose2dim(x2_cat, -1, -2)), false, false, 1.0)); - return Sqrt(JUST(ClampMin(result, 0))); + return Sqrt(JUST(ClampMin(result, 0.0))); }; Maybe operator()(const std::shared_ptr& x1, const std::shared_ptr& x2, @@ -2999,9 +3000,13 @@ class CdistFunctor { const auto x1_expand = JUST(Expand(x1, x1_expand_shape)); const auto x2_expand = JUST(Expand(x2, x2_expand_shape)); - if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) { - return euclidean_dist(x1_expand, x2_expand); - } + // mm_for_euclid_dist has accuracy issue + // if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) { + // shape output_shape(max_batch_shape); + // output_shape.emplace_back(r1); + // output_shape.emplace_back(r2); + // return JUST(Reshape(JUST(euclidean_dist(x1_expand, x2_expand)), output_shape)); + // } TensorProcessor tensor_processor; JUST(tensor_processor.PromoteInputsToCommonDtype(true) @@ -4771,7 +4776,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("PairwiseDistance"); m.add_functor("CosineSimilarity"); m.add_functor("Normalize"); - m.add_functor("CDist"); + m.add_functor("CDist"); m.add_functor("L2Normalize"); m.add_functor("L2NormalizeGrad"); m.add_functor("FusedBiasAddGelu"); diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index fb2be9a213c..f0ebe0d83dc 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -4008,8 +4008,8 @@ def OneFlow_SmoothL1LossGradOp : OneFlow_BaseOp<"smooth_l1_loss_grad", [NoSideEf #endif // GET_ONEFLOW_LOSS_OP_DEFINITIONS // Group: MATH -// abs_grad, ceil_grad, erf, erf_grad, exp, exp_grad, expm1, expand_grad, expm1, expm1_grad, floor_grad, floordiv_x_grad, floordiv_y_grad, truncdiv_x_grad, truncdiv_y_grad, lgamma, lgamma_grad, log, log1p, log1p_grad, log2_grad, log10_grad, log_grad, log_sigmoid, log_sigmoid_grad, negative_grad, reciprocal_grad, reciprocal_no_nan, reciprocal_no_nan_grad, rint_grad, round_grad, rsqrt, rsqrt_grad, sigmoid_v2, sigmoid_v2_grad, sign_grad, softplus, softplus_grad, softsign_grad, var, sqrt, sqrt_grad, square, square_grad, xlogy_x_grad, xlogy_y_grad, cumsum, erfinv -// Total: 48 +// abs_grad, ceil_grad, erf, erf_grad, exp, exp_grad, expm1, expand_grad, expm1, expm1_grad, floor_grad, floordiv_x_grad, floordiv_y_grad, truncdiv_x_grad, truncdiv_y_grad, lgamma, lgamma_grad, log, log1p, log1p_grad, log2_grad, log10_grad, log_grad, log_sigmoid, log_sigmoid_grad, negative_grad, reciprocal_grad, reciprocal_no_nan, reciprocal_no_nan_grad, rint_grad, round_grad, rsqrt, rsqrt_grad, sigmoid_v2, sigmoid_v2_grad, sign_grad, softplus, softplus_grad, softsign_grad, var, sqrt, sqrt_grad, square, square_grad, xlogy_x_grad, xlogy_y_grad, cumsum, erfinv, cdist, cdist_grad +// Total: 50 #ifdef GET_ONEFLOW_MATH_OP_DEFINITIONS @@ -4027,7 +4027,7 @@ def OneFlow_AbsGradOp : OneFlow_BaseOp<"abs_grad", [NoSideEffect, DeclareOpInter let has_data_type_infer_fn = 1; } -def OneFlow_CdistOp : OneFlow_BaseOp<"cdist", [NoSideEffect, DeclareOpInterfaceMethods]> { +def OneFlow_CDistOp : OneFlow_BaseOp<"cdist", [NoSideEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x1, OneFlow_Tensor:$x2 @@ -4044,7 +4044,7 @@ def OneFlow_CdistOp : OneFlow_BaseOp<"cdist", [NoSideEffect, DeclareOpInterfaceM let has_data_type_infer_fn = 1; } -def OneFlow_CdistGradOp : OneFlow_BaseOp<"cdist_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { +def OneFlow_CDistGradOp : OneFlow_BaseOp<"cdist_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { let input = (ins OneFlow_Tensor:$x1, OneFlow_Tensor:$x2, diff --git a/oneflow/user/kernels/cdist_kernel.cpp b/oneflow/user/kernels/cdist_kernel.cpp index cb42ddfa428..534af24118d 100644 --- a/oneflow/user/kernels/cdist_kernel.cpp +++ b/oneflow/user/kernels/cdist_kernel.cpp @@ -79,7 +79,7 @@ struct PDist { }; template -void CpuCdistForward(ep::CpuStream* stream, const T* x1, const T* x2, T* out, int64_t size_out, +void CpuCDistForward(ep::CpuStream* stream, const T* x1, const T* x2, T* out, int64_t size_out, int64_t r1, int64_t r2, int64_t c, double p) { // x1 shape: (d1, d2, ..., dn, r1, c), treated as (d1 * ... * dn, r1 * c) // x2 shape: (d1, d2, ..., dn, r2, c), treated as (d1 * ... * dn, r2 * c) @@ -128,7 +128,7 @@ void CpuCdistForward(ep::CpuStream* stream, const T* x1, const T* x2, T* out, in } template -void CpuCdistBackward(ep::CpuStream* stream, const T* x1, const T* x2, const T* dist, const T* grad, +void CpuCDistBackward(ep::CpuStream* stream, const T* x1, const T* x2, const T* dist, const T* grad, T* grad1, T* grad2, int64_t size_out, int64_t r1, int64_t r2, int64_t c, double p) { stream->ParallelFor( @@ -178,10 +178,10 @@ void CpuCdistBackward(ep::CpuStream* stream, const T* x1, const T* x2, const T* } template -class CpuCdistKernel final : public user_op::OpKernel { +class CpuCDistKernel final : public user_op::OpKernel { public: - CpuCdistKernel() = default; - ~CpuCdistKernel() = default; + CpuCDistKernel() = default; + ~CpuCDistKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { @@ -199,19 +199,19 @@ class CpuCdistKernel final : public user_op::OpKernel { T* out_ptr = out->mut_dptr(); if (p == 0) { - CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + CpuCDistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); } else if (p == 1) { - CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + CpuCDistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); } else if (p == 2) { - CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + CpuCDistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); } else if (std::isinf(p)) { - CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + CpuCDistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); } else { - CpuCdistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, + CpuCDistForward>(ctx->stream()->As(), x1_ptr, x2_ptr, out_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); }; } @@ -219,10 +219,10 @@ class CpuCdistKernel final : public user_op::OpKernel { }; template -class CpuCdistGradKernel final : public user_op::OpKernel { +class CpuCDistGradKernel final : public user_op::OpKernel { public: - CpuCdistGradKernel() = default; - ~CpuCdistGradKernel() = default; + CpuCDistGradKernel() = default; + ~CpuCDistGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { @@ -255,19 +255,19 @@ class CpuCdistGradKernel final : public user_op::OpKernel { if (p == 0) { // grad is always zero } else if (p == 1) { - CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, + CpuCDistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); } else if (p == 2) { - CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, + CpuCDistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); } else if (std::isinf(p)) { - CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, + CpuCDistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); } else { - CpuCdistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, + CpuCDistBackward>(ctx->stream()->As(), x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, out->shape_view().elem_cnt(), r1, r2, c, p); }; @@ -276,7 +276,7 @@ class CpuCdistGradKernel final : public user_op::OpKernel { }; #define REGISTER_CPU_CDIST_KERNEL(dtype) \ - REGISTER_USER_KERNEL("cdist").SetCreateFn>().SetIsMatchedHob( \ + REGISTER_USER_KERNEL("cdist").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x1", 0) == GetDataType::value) \ && (user_op::HobDataType("x2", 0) == GetDataType::value) \ @@ -288,7 +288,7 @@ REGISTER_CPU_CDIST_KERNEL(double) #define REGISTER_CPU_CDIST_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("cdist_grad") \ - .SetCreateFn>() \ + .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \ && (user_op::HobDataType("x1", 0) == GetDataType::value) \ && (user_op::HobDataType("x2", 0) == GetDataType::value) \ diff --git a/oneflow/user/kernels/cdist_kernel.cu b/oneflow/user/kernels/cdist_kernel.cu index eab87856487..00cb9c47df4 100644 --- a/oneflow/user/kernels/cdist_kernel.cu +++ b/oneflow/user/kernels/cdist_kernel.cu @@ -14,8 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "oneflow/core/ep/cuda/cuda_stream.h" #include "oneflow/core/ep/include/primitive/memset.h" #include "oneflow/core/ep/include/stream.h" +#include "oneflow/core/device/cuda_util.h" #include "oneflow/core/framework/user_op_hob.h" #include "oneflow/core/ndarray/xpu_util.h" @@ -119,6 +121,22 @@ struct DistReduce { } }; +template +__global__ static void reduce_backward_buffer(T* buffer, T* grad, int64_t reduce_size) { + typedef cub::BlockReduce BlockReduce; + int32_t row_idx = blockIdx.x; + int32_t col_idx = threadIdx.x; + __shared__ typename BlockReduce::TempStorage temp_storage; + T agg = 0; + for(int32_t col = col_idx; col < reduce_size; col += blockDim.x) { + int idx = row_idx * reduce_size + col_idx; + agg += buffer[idx]; + } + T result = BlockReduce(temp_storage).Sum(agg); + if (threadIdx.x == 0) { grad[blockIdx.x] = result; } +} + + template __global__ static void CUDACDistForward(const T* x1, const T* x2, T* out, int64_t r1, int64_t r2, int64_t c, int64_t r_size, int64_t r1_size, int64_t r2_size, @@ -138,7 +156,6 @@ __global__ static void CUDACDistForward(const T* x1, const T* x2, T* out, int64_ Dist::inc(agg, std::abs(*vec1_begin - *vec2_begin), p); } - __syncthreads(); typedef cub::BlockReduce BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; T result = BlockReduce(temp_storage).Reduce(agg, DistReduce()); @@ -149,7 +166,7 @@ template __global__ static void CUDACDistBackward(const T* x1, const T* x2, const T* dist, const T* dist_grad, T* grad1, T* grad2, int64_t r1, int64_t r2, int64_t c, int64_t r_size, int64_t r1_size, int64_t r2_size, - double p) { + double p, T* buffer1, T* buffer2) { const int64_t batch_idx = blockIdx.x / r_size; const int64_t vec_out_idx = blockIdx.x - batch_idx * r_size; const int64_t vec1_idx = vec_out_idx / r2; @@ -164,15 +181,17 @@ __global__ static void CUDACDistBackward(const T* x1, const T* x2, const T* dist T* grad2_begin = vec2_begin - x2 + grad2; T diff = *vec1_begin - *vec2_begin; - atomicAdd(grad1_begin, Dist::backward(diff, *(dist_grad + blockIdx.x), *(dist + blockIdx.x), p)); - atomicAdd(grad2_begin, Dist::backward(-diff, *(dist_grad + blockIdx.x), *(dist + blockIdx.x), p)); + T* buffer1_idx = buffer1 + batch_idx * r_size * c + vec1_idx * r2 * c + threadIdx.x * r2 + vec2_idx; + T* buffer2_idx = buffer2 + batch_idx * r_size * c + vec2_idx * r1 * c + threadIdx.x * r1 + vec1_idx; + *buffer1_idx = Dist::backward(diff, *(dist_grad + blockIdx.x), *(dist + blockIdx.x), p); + *buffer2_idx = Dist::backward(-diff, *(dist_grad + blockIdx.x), *(dist + blockIdx.x), p); } template -class CUDACdistKernel final : public user_op::OpKernel { +class CUDACDistKernel final : public user_op::OpKernel { public: - CUDACdistKernel() = default; - ~CUDACdistKernel() = default; + CUDACDistKernel() = default; + ~CUDACDistKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { @@ -214,15 +233,16 @@ class CUDACdistKernel final : public user_op::OpKernel { ctx->stream()->As()->cuda_stream()>>>( x1_ptr, x2_ptr, out_ptr, r1, r2, c, r_size, r1_size, r2_size, p); } + } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; template -class CUDACdistGradKernel final : public user_op::OpKernel { +class CUDACDistGradKernel final : public user_op::OpKernel { public: - CUDACdistGradKernel() = default; - ~CUDACdistGradKernel() = default; + CUDACDistGradKernel() = default; + ~CUDACDistGradKernel() = default; private: void Compute(user_op::KernelComputeContext* ctx) const override { @@ -256,36 +276,43 @@ class CUDACdistGradKernel final : public user_op::OpKernel { memset_primitive->Launch(ctx->stream(), dx1_ptr, 0, dx1->shape_view().elem_cnt() * sizeof(T)); memset_primitive->Launch(ctx->stream(), dx2_ptr, 0, dx2->shape_view().elem_cnt() * sizeof(T)); + + T* buffer1 = nullptr; + T* buffer2 = nullptr; + OF_CUDA_CHECK(cudaMalloc(&buffer1, out->shape_view().elem_cnt() * c * sizeof(T))); + OF_CUDA_CHECK(cudaMalloc(&buffer2, out->shape_view().elem_cnt() * c * sizeof(T))); + if (p == 0) { // grad is always zero } else if (p == 1) { CUDACDistBackward><<shape_view().elem_cnt(), c, 0, ctx->stream()->As()->cuda_stream()>>>( x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, - p); + p, buffer1, buffer2); } else if (p == 2) { CUDACDistBackward><<shape_view().elem_cnt(), c, 0, ctx->stream()->As()->cuda_stream()>>>( x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, - p); + p, buffer1, buffer2); } else if (std::isinf(p)) { CUDACDistBackward><<shape_view().elem_cnt(), c, 0, ctx->stream()->As()->cuda_stream()>>>( x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, - p); + p, buffer1, buffer2); } else { CUDACDistBackward><<shape_view().elem_cnt(), c, 0, ctx->stream()->As()->cuda_stream()>>>( x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, - p); + p, buffer1, buffer2); } - cudaDeviceSynchronize(); + reduce_backward_buffer<<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>(buffer1, dx1_ptr, r2); + reduce_backward_buffer<<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>(buffer2, dx2_ptr, r1); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; #define REGISTER_CUDA_CDIST_KERNEL(dtype) \ - REGISTER_USER_KERNEL("cdist").SetCreateFn>().SetIsMatchedHob( \ + REGISTER_USER_KERNEL("cdist").SetCreateFn>().SetIsMatchedHob( \ (user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x1", 0) == GetDataType::value) \ && (user_op::HobDataType("x2", 0) == GetDataType::value) \ @@ -297,7 +324,7 @@ REGISTER_CUDA_CDIST_KERNEL(double) #define REGISTER_CUDA_CDIST_GRAD_KERNEL(dtype) \ REGISTER_USER_KERNEL("cdist_grad") \ - .SetCreateFn>() \ + .SetCreateFn>() \ .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \ && (user_op::HobDataType("x1", 0) == GetDataType::value) \ && (user_op::HobDataType("x2", 0) == GetDataType::value) \ diff --git a/oneflow/user/ops/cdist_op.cpp b/oneflow/user/ops/cdist_op.cpp index 2663c9eaa7f..6d947eedf07 100644 --- a/oneflow/user/ops/cdist_op.cpp +++ b/oneflow/user/ops/cdist_op.cpp @@ -37,19 +37,19 @@ Maybe FwdInferTensorDesc(user_op::InferContext* ctx) { } // namespace -/* static */ Maybe CdistOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { +/* static */ Maybe CDistOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return FwdInferTensorDesc(ctx); } -/*static*/ Maybe CdistOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { +/*static*/ Maybe CDistOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } -/* static */ Maybe CdistOp::GetSbp(user_op::SbpContext* ctx) { +/* static */ Maybe CDistOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } -/* static */ Maybe CdistOp::InferDataType(user_op::InferContext* ctx) { +/* static */ Maybe CDistOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x1_desc = ctx->InputTensorDesc("x1", 0); user_op::TensorDesc* output_desc = ctx->MutOutputTensorDesc("out", 0); output_desc->set_data_type(x1_desc.data_type()); @@ -72,19 +72,19 @@ Maybe BwdInferTensorDesc(user_op::InferContext* ctx) { } // namespace -/* static */ Maybe CdistGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { +/* static */ Maybe CDistGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { return BwdInferTensorDesc(ctx); } -/*static*/ Maybe CdistGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { +/*static*/ Maybe CDistGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { return InferLogicalTensorDesc(ctx); } -/* static */ Maybe CdistGradOp::GetSbp(user_op::SbpContext* ctx) { +/* static */ Maybe CDistGradOp::GetSbp(user_op::SbpContext* ctx) { return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } -/* static */ Maybe CdistGradOp::InferDataType(user_op::InferContext* ctx) { +/* static */ Maybe CDistGradOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& x1_desc = ctx->InputTensorDesc("x1", 0); user_op::TensorDesc* dx1_desc = ctx->MutOutputTensorDesc("dx1", 0); user_op::TensorDesc* dx2_desc = ctx->MutOutputTensorDesc("dx2", 0); diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index dc7e9644105..d31d0a3e882 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -24,85 +24,21 @@ @flow.unittest.skip_unless_1n1d() class TestCDist(flow.unittest.TestCase): - @autotest(n=2, check_graph=True) - def test_zero_cdist(test_case): + @autotest(n=10, check_graph=True) + def test_cdist(test_case): device = random_device() dim0 = random() - dim2 = random() + dim2 = random(2, 32) mode = random_utils.choice( [ "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", + "use_mm_for_euclid_dist" "donot_use_mm_for_euclid_dist", ] ) + p = random_utils.choice([0, 1, 2, float("inf"), random(0.5, 4).to(float)]) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - return torch.cdist(x1, x2, p=0, compute_mode=mode) - - @autotest(n=2, check_graph=True) - def test_one_cdist(test_case): - device = random_device() - dim0 = random() - dim2 = random() - mode = random_utils.choice( - [ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ] - ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - return torch.cdist(x1, x2, p=1, compute_mode=mode) - - @autotest(n=2, check_graph=True) - def test_two_cdist(test_case): - device = random_device() - dim0 = random() - dim2 = random() - mode = random_utils.choice( - [ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ] - ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - return torch.cdist(x1, x2, p=2, compute_mode=mode) - - @autotest(n=2, check_graph=True) - def test_infi_cdist(test_case): - device = random_device() - dim0 = random() - dim2 = random() - mode = random_utils.choice( - [ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ] - ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - return torch.cdist(x1, x2, p=float("inf"), compute_mode=mode) - - @autotest(n=5, check_graph=True) - def test_random_p_cdist(test_case): - device = random_device() - dim0 = random() - dim2 = random() - mode = random_utils.choice( - [ - "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist", - "donot_use_mm_for_euclid_dist", - ] - ) - x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) - p = random(0, 4).to(float) return torch.cdist(x1, x2, p=p, compute_mode=mode) diff --git a/python/oneflow/test/modules/test_large_size_tensor.py b/python/oneflow/test/modules/test_large_size_tensor.py new file mode 100644 index 00000000000..c41373e5a10 --- /dev/null +++ b/python/oneflow/test/modules/test_large_size_tensor.py @@ -0,0 +1,39 @@ +import random as random_util +import unittest + +import oneflow as flow +import oneflow.unittest +from oneflow.test_utils.automated_test_util import * +import numpy as np + +@flow.unittest.skip_unless_1n1d() +class TestLargeSizeTensor(flow.unittest.TestCase): + @autotest(n=1000, check_graph=False) + def test(test_case): + # size = random(2000, 3000) + # size = 5000 + # x = random_tensor(ndim=1,dim0=size).cuda().half().requires_grad_() + # y = random_tensor(ndim=1,dim0=size).cuda().half().requires_grad_() + # z = x + y + # weight = torch.randn_like(z) + # p = z * weight + # p.sum().backward() + # import ipdb; ipdb.set_trace() + # of_x = x.oneflow.grad.numpy() + # torch_x = x.pytorch.grad.numpy() + # diff = of_x - torch_x + # return x + y + size = random(200, 300) + # x = random tensor(ndim=3, dim2=size).to("cuda").to(torch.half) + # y = random tensor(ndim=3, dim2=size).to("cuda").to(torch.half) + x = torch.Tensor(np.load("np_x.npy")).to("cuda").to(torch.half).requires_grad_() + y = torch.Tensor(np.load("np_y.npy")).to("cuda").to(torch.half).requires_grad_() + # np x = x.oneflow.numpy()# np_y = y.oneflow.numpy() + # np.save("np x.npy",np x) + # np.save("np_y.npy",np_y) + return x + y + + + +if __name__ == "__main__": + unittest.main() From 01936d6e6fcee3d217c09664dabe0e2a073d962f Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Fri, 18 Nov 2022 09:42:49 +0000 Subject: [PATCH 11/26] auto format by CI --- oneflow/core/functional/impl/nn_functor.cpp | 1 - oneflow/user/kernels/cdist_kernel.cu | 27 ++++++++++--------- .../test/modules/test_large_size_tensor.py | 15 +++++++++++ 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 552aa45fd23..f5bbf4e2b22 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -2924,7 +2924,6 @@ class CosineSimilarityFunctor { } }; - class CDistFunctor { public: CDistFunctor() { diff --git a/oneflow/user/kernels/cdist_kernel.cu b/oneflow/user/kernels/cdist_kernel.cu index 00cb9c47df4..3d5d857728c 100644 --- a/oneflow/user/kernels/cdist_kernel.cu +++ b/oneflow/user/kernels/cdist_kernel.cu @@ -128,7 +128,7 @@ __global__ static void reduce_backward_buffer(T* buffer, T* grad, int64_t reduce int32_t col_idx = threadIdx.x; __shared__ typename BlockReduce::TempStorage temp_storage; T agg = 0; - for(int32_t col = col_idx; col < reduce_size; col += blockDim.x) { + for (int32_t col = col_idx; col < reduce_size; col += blockDim.x) { int idx = row_idx * reduce_size + col_idx; agg += buffer[idx]; } @@ -136,7 +136,6 @@ __global__ static void reduce_backward_buffer(T* buffer, T* grad, int64_t reduce if (threadIdx.x == 0) { grad[blockIdx.x] = result; } } - template __global__ static void CUDACDistForward(const T* x1, const T* x2, T* out, int64_t r1, int64_t r2, int64_t c, int64_t r_size, int64_t r1_size, int64_t r2_size, @@ -163,10 +162,10 @@ __global__ static void CUDACDistForward(const T* x1, const T* x2, T* out, int64_ } template -__global__ static void CUDACDistBackward(const T* x1, const T* x2, const T* dist, const T* dist_grad, - T* grad1, T* grad2, int64_t r1, int64_t r2, int64_t c, - int64_t r_size, int64_t r1_size, int64_t r2_size, - double p, T* buffer1, T* buffer2) { +__global__ static void CUDACDistBackward(const T* x1, const T* x2, const T* dist, + const T* dist_grad, T* grad1, T* grad2, int64_t r1, + int64_t r2, int64_t c, int64_t r_size, int64_t r1_size, + int64_t r2_size, double p, T* buffer1, T* buffer2) { const int64_t batch_idx = blockIdx.x / r_size; const int64_t vec_out_idx = blockIdx.x - batch_idx * r_size; const int64_t vec1_idx = vec_out_idx / r2; @@ -181,8 +180,10 @@ __global__ static void CUDACDistBackward(const T* x1, const T* x2, const T* dist T* grad2_begin = vec2_begin - x2 + grad2; T diff = *vec1_begin - *vec2_begin; - T* buffer1_idx = buffer1 + batch_idx * r_size * c + vec1_idx * r2 * c + threadIdx.x * r2 + vec2_idx; - T* buffer2_idx = buffer2 + batch_idx * r_size * c + vec2_idx * r1 * c + threadIdx.x * r1 + vec1_idx; + T* buffer1_idx = + buffer1 + batch_idx * r_size * c + vec1_idx * r2 * c + threadIdx.x * r2 + vec2_idx; + T* buffer2_idx = + buffer2 + batch_idx * r_size * c + vec2_idx * r1 * c + threadIdx.x * r1 + vec1_idx; *buffer1_idx = Dist::backward(diff, *(dist_grad + blockIdx.x), *(dist + blockIdx.x), p); *buffer2_idx = Dist::backward(-diff, *(dist_grad + blockIdx.x), *(dist + blockIdx.x), p); } @@ -233,7 +234,6 @@ class CUDACDistKernel final : public user_op::OpKernel { ctx->stream()->As()->cuda_stream()>>>( x1_ptr, x2_ptr, out_ptr, r1, r2, c, r_size, r1_size, r2_size, p); } - } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; @@ -276,7 +276,6 @@ class CUDACDistGradKernel final : public user_op::OpKernel { memset_primitive->Launch(ctx->stream(), dx1_ptr, 0, dx1->shape_view().elem_cnt() * sizeof(T)); memset_primitive->Launch(ctx->stream(), dx2_ptr, 0, dx2->shape_view().elem_cnt() * sizeof(T)); - T* buffer1 = nullptr; T* buffer2 = nullptr; OF_CUDA_CHECK(cudaMalloc(&buffer1, out->shape_view().elem_cnt() * c * sizeof(T))); @@ -305,8 +304,12 @@ class CUDACDistGradKernel final : public user_op::OpKernel { x1_ptr, x2_ptr, dist_ptr, grad_ptr, dx1_ptr, dx2_ptr, r1, r2, c, r_size, r1_size, r2_size, p, buffer1, buffer2); } - reduce_backward_buffer<<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>(buffer1, dx1_ptr, r2); - reduce_backward_buffer<<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, ctx->stream()->As()->cuda_stream()>>>(buffer2, dx2_ptr, r1); + reduce_backward_buffer + <<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, + ctx->stream()->As()->cuda_stream()>>>(buffer1, dx1_ptr, r2); + reduce_backward_buffer + <<shape_view().elem_cnt(), kCudaThreadsNumPerBlock, 0, + ctx->stream()->As()->cuda_stream()>>>(buffer2, dx2_ptr, r1); } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } }; diff --git a/python/oneflow/test/modules/test_large_size_tensor.py b/python/oneflow/test/modules/test_large_size_tensor.py index c41373e5a10..4686fc58027 100644 --- a/python/oneflow/test/modules/test_large_size_tensor.py +++ b/python/oneflow/test/modules/test_large_size_tensor.py @@ -1,3 +1,18 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import random as random_util import unittest From d231937169add4d3f18e9f8f588e00577893afc2 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Fri, 18 Nov 2022 17:44:11 +0800 Subject: [PATCH 12/26] Delete test_large_size_tensor.py --- .../test/modules/test_large_size_tensor.py | 54 ------------------- 1 file changed, 54 deletions(-) delete mode 100644 python/oneflow/test/modules/test_large_size_tensor.py diff --git a/python/oneflow/test/modules/test_large_size_tensor.py b/python/oneflow/test/modules/test_large_size_tensor.py deleted file mode 100644 index 4686fc58027..00000000000 --- a/python/oneflow/test/modules/test_large_size_tensor.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -""" -import random as random_util -import unittest - -import oneflow as flow -import oneflow.unittest -from oneflow.test_utils.automated_test_util import * -import numpy as np - -@flow.unittest.skip_unless_1n1d() -class TestLargeSizeTensor(flow.unittest.TestCase): - @autotest(n=1000, check_graph=False) - def test(test_case): - # size = random(2000, 3000) - # size = 5000 - # x = random_tensor(ndim=1,dim0=size).cuda().half().requires_grad_() - # y = random_tensor(ndim=1,dim0=size).cuda().half().requires_grad_() - # z = x + y - # weight = torch.randn_like(z) - # p = z * weight - # p.sum().backward() - # import ipdb; ipdb.set_trace() - # of_x = x.oneflow.grad.numpy() - # torch_x = x.pytorch.grad.numpy() - # diff = of_x - torch_x - # return x + y - size = random(200, 300) - # x = random tensor(ndim=3, dim2=size).to("cuda").to(torch.half) - # y = random tensor(ndim=3, dim2=size).to("cuda").to(torch.half) - x = torch.Tensor(np.load("np_x.npy")).to("cuda").to(torch.half).requires_grad_() - y = torch.Tensor(np.load("np_y.npy")).to("cuda").to(torch.half).requires_grad_() - # np x = x.oneflow.numpy()# np_y = y.oneflow.numpy() - # np.save("np x.npy",np x) - # np.save("np_y.npy",np_y) - return x + y - - - -if __name__ == "__main__": - unittest.main() From e7767e4c35810a16b7c9a75b23b90c5029abd166 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Fri, 18 Nov 2022 09:45:56 +0000 Subject: [PATCH 13/26] auto format by CI --- python/oneflow/test/modules/test_cdist.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index d31d0a3e882..a4514cb1196 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -32,8 +32,7 @@ def test_cdist(test_case): mode = random_utils.choice( [ "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist" - "donot_use_mm_for_euclid_dist", + "use_mm_for_euclid_dist" "donot_use_mm_for_euclid_dist", ] ) p = random_utils.choice([0, 1, 2, float("inf"), random(0.5, 4).to(float)]) From b5149fd6812479627dd731a4cda5f75cebef4baa Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Fri, 18 Nov 2022 17:46:47 +0800 Subject: [PATCH 14/26] Update test_cdist.py --- python/oneflow/test/modules/test_cdist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index a4514cb1196..6f68cce6efa 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -32,7 +32,8 @@ def test_cdist(test_case): mode = random_utils.choice( [ "use_mm_for_euclid_dist_if_necessary", - "use_mm_for_euclid_dist" "donot_use_mm_for_euclid_dist", + "use_mm_for_euclid_dist", + "donot_use_mm_for_euclid_dist", ] ) p = random_utils.choice([0, 1, 2, float("inf"), random(0.5, 4).to(float)]) From 5b1f1b059e77f674a6cf6f6ebc8c73b694f9f95a Mon Sep 17 00:00:00 2001 From: WangYi Date: Sun, 20 Nov 2022 22:18:19 +0800 Subject: [PATCH 15/26] set every possible param p in unittest to float --- python/oneflow/test/modules/test_cdist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index d31d0a3e882..337c0cd18f6 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -36,7 +36,7 @@ def test_cdist(test_case): "donot_use_mm_for_euclid_dist", ] ) - p = random_utils.choice([0, 1, 2, float("inf"), random(0.5, 4).to(float)]) + p = random_utils.choice([0.0, 1.0, 2.0, float("inf"), random(0.5, 4).to(float)]) x1 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) x2 = random_tensor(ndim=3, dim0=dim0, dim1=random(), dim2=dim2).to(device) return torch.cdist(x1, x2, p=p, compute_mode=mode) From f9acf95d5fdb5d13872229266baede8322e40f81 Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 21 Nov 2022 09:50:52 +0800 Subject: [PATCH 16/26] set p attribute from f32 to f64 --- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index f0ebe0d83dc..d6b67dc9eac 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -4036,7 +4036,7 @@ def OneFlow_CDistOp : OneFlow_BaseOp<"cdist", [NoSideEffect, DeclareOpInterfaceM OneFlow_Tensor:$out ); let attrs = (ins - DefaultValuedAttr:$p + DefaultValuedAttr:$p ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; @@ -4056,7 +4056,7 @@ def OneFlow_CDistGradOp : OneFlow_BaseOp<"cdist_grad", [NoSideEffect, DeclareOpI OneFlow_Tensor:$dx2 ); let attrs = (ins - DefaultValuedAttr:$p + DefaultValuedAttr:$p ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; From 011712e0bc5fe1d6ac8d525c0d64f05242e4f34f Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 23 Nov 2022 10:44:59 +0800 Subject: [PATCH 17/26] set p attribute from f64 to f32 --- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index d6b67dc9eac..f0ebe0d83dc 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -4036,7 +4036,7 @@ def OneFlow_CDistOp : OneFlow_BaseOp<"cdist", [NoSideEffect, DeclareOpInterfaceM OneFlow_Tensor:$out ); let attrs = (ins - DefaultValuedAttr:$p + DefaultValuedAttr:$p ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; @@ -4056,7 +4056,7 @@ def OneFlow_CDistGradOp : OneFlow_BaseOp<"cdist_grad", [NoSideEffect, DeclareOpI OneFlow_Tensor:$dx2 ); let attrs = (ins - DefaultValuedAttr:$p + DefaultValuedAttr:$p ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; From 0bd0d25b49a88be792881260c0201521de9df701 Mon Sep 17 00:00:00 2001 From: WangYi Date: Wed, 23 Nov 2022 19:06:20 +0800 Subject: [PATCH 18/26] Revert "set p attribute from f64 to f32" This reverts commit 011712e0bc5fe1d6ac8d525c0d64f05242e4f34f. --- oneflow/ir/include/OneFlow/OneFlowUserOps.td | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/ir/include/OneFlow/OneFlowUserOps.td b/oneflow/ir/include/OneFlow/OneFlowUserOps.td index c1adb9dfce0..fa08373e5e2 100644 --- a/oneflow/ir/include/OneFlow/OneFlowUserOps.td +++ b/oneflow/ir/include/OneFlow/OneFlowUserOps.td @@ -4137,7 +4137,7 @@ def OneFlow_CDistOp : OneFlow_BaseOp<"cdist", [NoSideEffect, DeclareOpInterfaceM OneFlow_Tensor:$out ); let attrs = (ins - DefaultValuedAttr:$p + DefaultValuedAttr:$p ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; @@ -4157,7 +4157,7 @@ def OneFlow_CDistGradOp : OneFlow_BaseOp<"cdist_grad", [NoSideEffect, DeclareOpI OneFlow_Tensor:$dx2 ); let attrs = (ins - DefaultValuedAttr:$p + DefaultValuedAttr:$p ); let has_logical_tensor_desc_infer_fn = 1; let has_physical_tensor_desc_infer_fn = 1; From 5ee010c54b393a32ca892f7a6c4dfe366800915c Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 28 Nov 2022 14:47:00 +0800 Subject: [PATCH 19/26] refine code --- oneflow/user/kernels/cdist_kernel.cpp | 32 +++++++++++------------ oneflow/user/kernels/cdist_kernel.cu | 2 +- python/oneflow/test/modules/test_cdist.py | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/oneflow/user/kernels/cdist_kernel.cpp b/oneflow/user/kernels/cdist_kernel.cpp index 534af24118d..d80540b6221 100644 --- a/oneflow/user/kernels/cdist_kernel.cpp +++ b/oneflow/user/kernels/cdist_kernel.cpp @@ -27,38 +27,38 @@ namespace oneflow { template struct ZeroDist { - static inline T map(const T& diff, const T& p) { return diff == T(0) ? diff : T(1); } - static inline T reduce(const T& agg, const T& up) { return agg + up; } - static inline T finish(const T agg, const T p) { return agg; } + static inline T map(const T& diff, const double& p) { return diff == T(0) ? diff : T(1); } + static inline T reduce(const T& agg, const double& up) { return agg + up; } + static inline T finish(const T agg, const double p) { return agg; } // backward always return 0 }; template struct OneDist { - static inline T map(const T& diff, const T& p) { return diff; } - static inline T reduce(const T& agg, const T& up) { return agg + up; } - static inline T finish(const T agg, const T p) { return agg; } - static inline T backward(const T& diff, const T grad, const T dist, const T& p) { + static inline T map(const T& diff, const double& p) { return diff; } + static inline T reduce(const T& agg, const double& up) { return agg + up; } + static inline T finish(const T agg, const double p) { return agg; } + static inline T backward(const T& diff, const T grad, const T dist, const double& p) { return grad * (diff > T(0) ? T(1) : T(-1)); } }; template struct TwoDist { - static inline T map(const T& diff, const T& p) { return diff * diff; } + static inline T map(const T& diff, const double& p) { return diff * diff; } static inline T reduce(const T& agg, const T& up) { return agg + up; } - static inline T finish(const T agg, const T p) { return std::sqrt(agg); } - static inline T backward(const T& diff, const T grad, const T dist, const T& p) { + static inline T finish(const T agg, const double p) { return std::sqrt(agg); } + static inline T backward(const T& diff, const T grad, const T dist, const double& p) { return dist == 0.0 ? T(0) : grad * diff / dist; } }; template struct InfiDist { - static inline T map(const T& diff, const T& p) { return diff; } + static inline T map(const T& diff, const double& p) { return diff; } static inline T reduce(const T& agg, const T& up) { return std::max(agg, up); } - static inline T finish(const T agg, const T p) { return agg; } - static inline T backward(const T& diff, const T grad, const T dist, const T& p) { + static inline T finish(const T agg, const double p) { return agg; } + static inline T backward(const T& diff, const T grad, const T dist, const double& p) { return (T(1) - std::min(std::ceil(std::abs(std::abs(diff) - dist)), T(1))) * grad * (diff > T(0) ? T(1) : T(-1)); } @@ -66,10 +66,10 @@ struct InfiDist { template struct PDist { - static inline T map(const T& diff, const T& p) { return std::pow(diff, p); } + static inline T map(const T& diff, const double& p) { return std::pow(diff, p); } static inline T reduce(const T& agg, const T& up) { return agg + up; } - static inline T finish(const T agg, const T p) { return std::pow(agg, 1.0 / p); } - static inline T backward(const T& diff, const T grad, const T dist, const T& p) { + static inline T finish(const T agg, const double p) { return std::pow(agg, 1.0 / p); } + static inline T backward(const T& diff, const T grad, const T dist, const double& p) { if (dist == 0.0) { return T(0); } else { diff --git a/oneflow/user/kernels/cdist_kernel.cu b/oneflow/user/kernels/cdist_kernel.cu index 3d5d857728c..176e91cb484 100644 --- a/oneflow/user/kernels/cdist_kernel.cu +++ b/oneflow/user/kernels/cdist_kernel.cu @@ -199,7 +199,7 @@ class CUDACDistKernel final : public user_op::OpKernel { const user_op::Tensor* x1 = ctx->Tensor4ArgNameAndIndex("x1", 0); const user_op::Tensor* x2 = ctx->Tensor4ArgNameAndIndex("x2", 0); user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); - double p = ctx->Attr("p"); + const double p = ctx->Attr("p"); int64_t ndim = x1->shape_view().NumAxes(); int64_t r1 = x1->shape_view().At(ndim - 2); int64_t r2 = x2->shape_view().At(ndim - 2); diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index 9f6cc0c21ee..80154b7b0b6 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -24,7 +24,7 @@ @flow.unittest.skip_unless_1n1d() class TestCDist(flow.unittest.TestCase): - @autotest(n=10, check_graph=True) + @autotest(n=1000, check_graph=True) def test_cdist(test_case): device = random_device() dim0 = random() From e120c593c0a6d5e43c2caa489f6a9dae779e11a5 Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 28 Nov 2022 18:35:35 +0800 Subject: [PATCH 20/26] remove attr 'mode' --- oneflow/core/functional/impl/nn_functor.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 52d7f43e1ee..4dfc82b600b 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -3119,8 +3119,8 @@ class CDistFunctor { .AddInputs({x1_expand, x2_expand}) .Apply()); - auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p", "mode"); - attrs.SetAllAttrs(p, mode); + auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p"); + attrs.SetAllAttrs(p); return OpInterpUtil::Dispatch(*op_, {x1, x2}, attrs); } From f8ca1e5724b483a52b9f7486f4abce8c9552813d Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 28 Nov 2022 21:05:23 +0800 Subject: [PATCH 21/26] remove useless variables --- oneflow/core/functional/functional_api.yaml | 2 +- oneflow/core/functional/impl/nn_functor.cpp | 15 +++++---------- python/oneflow/test/modules/test_cdist.py | 2 +- 3 files changed, 7 insertions(+), 12 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index ee1eeb417ce..9860587ef12 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -2394,7 +2394,7 @@ bind_python: True - name: "cdist" - signature: 'Tensor (Tensor x1, Tensor x2, Double p=2.0, String compute_mode="use_mm_for_euclid_dist_if_necessary") => CDist' + signature: 'Tensor (Tensor x1, Tensor x2, Double p=2.0, String compute_mode=None) => CDist' bind_python: True - name: "cdist_grad" diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 4dfc82b600b..71ffe95f15c 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -3050,7 +3050,7 @@ class CDistFunctor { }; Maybe operator()(const std::shared_ptr& x1, const std::shared_ptr& x2, - const double& p, const std::string& compute_mode) const { + const double& p, const Optional& compute_mode) const { const int64_t x1_ndim = x1->ndim(); const int64_t x2_ndim = x2->ndim(); CHECK_OR_RETURN(x1_ndim >= 2) << "cdist only supports at least 2D tensors, X1 got: " @@ -3062,15 +3062,10 @@ class CDistFunctor { << " X2: " << x2->dim(x2_ndim - 1); CHECK_OR_RETURN(p >= 0) << "cdist only supports non-negative p values, got " << p; - int32_t mode = 0; - if (compute_mode == "use_mm_for_euclid_dist_if_necessary") { - mode = 0; - } else if (compute_mode == "use_mm_for_euclid_dist") { - mode = 1; - } else if (compute_mode == "donot_use_mm_for_euclid_dist") { - mode = 2; - } else { - THROW(RuntimeError) << compute_mode << " is not a valid value for compute_mode"; + if (compute_mode.has_value()) { + OF_LOG_ONCE(LOG(WARNING) + << "'compute_mode' argument is not supported yet, cdist " + "will not use matrix multiplication approach to calculate euclidean distance"); } int64_t r1 = x1->dim(x1_ndim - 2); diff --git a/python/oneflow/test/modules/test_cdist.py b/python/oneflow/test/modules/test_cdist.py index 80154b7b0b6..9f6cc0c21ee 100644 --- a/python/oneflow/test/modules/test_cdist.py +++ b/python/oneflow/test/modules/test_cdist.py @@ -24,7 +24,7 @@ @flow.unittest.skip_unless_1n1d() class TestCDist(flow.unittest.TestCase): - @autotest(n=1000, check_graph=True) + @autotest(n=10, check_graph=True) def test_cdist(test_case): device = random_device() dim0 = random() From 5f4fe399a2ff645a9a3dddb5efe963badadecfe2 Mon Sep 17 00:00:00 2001 From: WangYi Date: Thu, 15 Dec 2022 15:33:22 +0800 Subject: [PATCH 22/26] add global test --- .../oneflow/test/modules/test_global_cdist.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 python/oneflow/test/modules/test_global_cdist.py diff --git a/python/oneflow/test/modules/test_global_cdist.py b/python/oneflow/test/modules/test_global_cdist.py new file mode 100644 index 00000000000..b29081e8a04 --- /dev/null +++ b/python/oneflow/test/modules/test_global_cdist.py @@ -0,0 +1,46 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import unittest +import oneflow as flow +import torch as torch_ori +import oneflow.unittest + +from oneflow.test_utils.automated_test_util import * + + +@autotest(n=1, auto_backward=False, check_graph=False) +def _test_cdist(test_case, ndim, placement, sbp): + dims = [random(1, 4) * 8 for i in range(ndim)] + x1 = random_tensor(ndim, *dims) + x1 = x1.to_global(placement=placement, sbp=sbp) + x2 = random_tensor(ndim, *dims) + x2 = x2.to_global(placement=placement, sbp=sbp) + z = torch.cdist(x1, x2) + return z + + +class TestCDistGlobal(flow.unittest.TestCase): + @globaltest + def test_cdist(test_case): + ndim = random(2, 5).to(int).value() + for placement in all_placement(): + for sbp in all_sbp(placement, max_dim=ndim): + _test_cdist(test_case, ndim, placement, sbp) + + +if __name__ == "__main__": + unittest.main() From 6a21f987d610adde62015f6372f019ec87a02667 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Wed, 11 Jan 2023 19:23:19 +0800 Subject: [PATCH 23/26] Update functional_api.yaml --- oneflow/core/functional/functional_api.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 9860587ef12..f497bb73b5e 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -2399,7 +2399,7 @@ - name: "cdist_grad" signature: "TensorTuple (Tensor x1, Tensor x2, Tensor out, Tensor dy, Double p=2.0) => CDistGrad" - bind_python: True + bind_python: False - name: "normalize" signature: "Tensor (Tensor input, Float p=2.0, Int32 dim=1, Float eps=1e-12, Bool use_l2_norm_kernel=True) => Normalize" From e86699419abec3a4ef3d0f5e0e2db9cbbe621dad Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Wed, 11 Jan 2023 11:25:07 +0000 Subject: [PATCH 24/26] auto format by CI --- oneflow/core/profiler/event_recorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/profiler/event_recorder.cpp b/oneflow/core/profiler/event_recorder.cpp index 994664620a8..dfa0adcd015 100644 --- a/oneflow/core/profiler/event_recorder.cpp +++ b/oneflow/core/profiler/event_recorder.cpp @@ -46,7 +46,7 @@ Maybe EventRecorder::CreateKernelEventRecorder( } return std::make_shared(event); } -#else // WITH_CUDA +#else // WITH_CUDA if (pmgr->use_cpu_) { return std::make_shared( KernelEvent::Create(name, pmgr->record_shapes_ ? shape_getter : nullptr)); From dde8ebe8036520107befbb2171bd44577e23d712 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Fri, 17 Feb 2023 15:14:02 +0800 Subject: [PATCH 25/26] Update event_recorder.cpp --- oneflow/core/profiler/event_recorder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oneflow/core/profiler/event_recorder.cpp b/oneflow/core/profiler/event_recorder.cpp index 2bc96634d34..db8e0b0c508 100644 --- a/oneflow/core/profiler/event_recorder.cpp +++ b/oneflow/core/profiler/event_recorder.cpp @@ -52,7 +52,7 @@ Maybe EventRecorder::CreateKernelEventRecorder( } return std::make_shared(event); } - +#else if (pmgr->use_cpu_) { return std::make_shared(KernelEvent::Create(name, description_getter())); } From 4c20eb5a3a8ca383c5d12050b13021e63686c090 Mon Sep 17 00:00:00 2001 From: WangYi Date: Fri, 17 Feb 2023 20:20:47 +0800 Subject: [PATCH 26/26] refine broadcast shape --- oneflow/core/functional/functional_api.yaml | 1 + .../core/functional/impl/array_functor.cpp | 1 + oneflow/core/functional/impl/nn_functor.cpp | 33 +++++-------------- 3 files changed, 11 insertions(+), 24 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 1995aac0143..34f997447e1 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -3160,6 +3160,7 @@ - name: "broadcast_to" signature: "Tensor (Tensor x, Shape shape) => BroadcastTo" bind_python: True + - name: "bincount" signature: "Tensor (Tensor input, Tensor weights=None, Int64 minlength=None) => BinCount" bind_python: True diff --git a/oneflow/core/functional/impl/array_functor.cpp b/oneflow/core/functional/impl/array_functor.cpp index 5a624efa58f..1979c711b18 100644 --- a/oneflow/core/functional/impl/array_functor.cpp +++ b/oneflow/core/functional/impl/array_functor.cpp @@ -3694,6 +3694,7 @@ class BroadcastTensorsFunctor { return outputs; } }; + class BinCountFunctor { public: BinCountFunctor() { diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp index 62ce865dc51..0fcbf131efd 100644 --- a/oneflow/core/functional/impl/nn_functor.cpp +++ b/oneflow/core/functional/impl/nn_functor.cpp @@ -3207,34 +3207,19 @@ class CDistFunctor { int64_t r2 = x2->dim(x2_ndim - 2); int64_t d = x1->dim(x1_ndim - 1); - Shape x1_batch_shape = Shape(DimVector({x1->shape()->begin(), x1->shape()->end() - 2})); - Shape x2_batch_shape = Shape(DimVector({x2->shape()->begin(), x2->shape()->end() - 2})); - Shape max_batch_shape = - Shape::Ones(std::max(x1_batch_shape.NumAxes(), x2_batch_shape.NumAxes())); - { - for (int64_t i = max_batch_shape.NumAxes() - 1; i >= 0; i--) { - int64_t offset = max_batch_shape.NumAxes() - 1 - i; - int64_t dim_x = x1_batch_shape.NumAxes() - 1 - offset; - int64_t dim_y = x2_batch_shape.NumAxes() - 1 - offset; - int64_t size_x = (dim_x >= 0) ? x1_batch_shape.At(dim_x) : 1; - int64_t size_y = (dim_y >= 0) ? x2_batch_shape.At(dim_y) : 1; - if (!(size_x == size_y || size_x == 1 || size_y == 1)) { - return Error::RuntimeError() - << "The size of tensor a (" << size_x << ") must match the size of tensor b (" - << size_y << ") at non-singleton dimension " << i; - } - max_batch_shape.Set(i, std::max(size_x, size_y)); - } - } - Shape x1_expand_shape(max_batch_shape); - Shape x2_expand_shape(max_batch_shape); + std::vector shape_vector = { + Shape(DimVector({x1->shape()->begin(), x1->shape()->end() - 2})), + Shape(DimVector({x2->shape()->begin(), x2->shape()->end() - 2})), + }; + auto broadcasted_shape = JUST(BroadcastShapes(shape_vector)); + Shape x1_expand_shape(*broadcasted_shape); x1_expand_shape.emplace_back(r1); x1_expand_shape.emplace_back(d); - x2_expand_shape.emplace_back(r2); - x2_expand_shape.emplace_back(d); + broadcasted_shape->emplace_back(r2); + broadcasted_shape->emplace_back(d); const auto x1_expand = JUST(Expand(x1, x1_expand_shape)); - const auto x2_expand = JUST(Expand(x2, x2_expand_shape)); + const auto x2_expand = JUST(Expand(x2, *broadcasted_shape)); // mm_for_euclid_dist has accuracy issue // if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {