Add cdist op #9391

Open

marigoold wants to merge 34 commits into master from dev_add_cdist

Commits (34)
4fdd7a3
add op, cpu kernel, functor
marigoold Nov 8, 2022
f1bc1a1
refine code
marigoold Nov 8, 2022
374d1fe
add backward func, register autograd
marigoold Nov 9, 2022
e1047e1
fix bug of memory init, and grad func call, add eager unittest
marigoold Nov 10, 2022
b846c69
add docs
marigoold Nov 10, 2022
d7b8b62
refine code
marigoold Nov 10, 2022
6f4d423
resolve conflict
marigoold Nov 11, 2022
4a792bb
add compute_mode unittest
marigoold Nov 15, 2022
015c0f2
format code
marigoold Nov 15, 2022
8f30598
add cuda kernel
marigoold Nov 17, 2022
d533e86
Merge branch 'master' into dev_add_cdist
marigoold Nov 18, 2022
130e28e
add cuda backward kernel, refine code (Cdist=>CDist), refine unittest…
marigoold Nov 18, 2022
01936d6
auto format by CI
oneflow-ci-bot Nov 18, 2022
d231937
Delete test_large_size_tensor.py
marigoold Nov 18, 2022
e7767e4
auto format by CI
oneflow-ci-bot Nov 18, 2022
b5149fd
Update test_cdist.py
marigoold Nov 18, 2022
5b1f1b0
set every possible param p in unittest to float
marigoold Nov 20, 2022
29b8ae2
Merge branch 'dev_add_cdist' of https://github.com/Oneflow-Inc/oneflo…
marigoold Nov 20, 2022
f9acf95
set p attribute from f32 to f64
marigoold Nov 21, 2022
011712e
set p attribute from f64 to f32
marigoold Nov 23, 2022
965f9ab
Merge branch 'master' into dev_add_cdist
marigoold Nov 23, 2022
0bd0d25
Revert "set p attribute from f64 to f32"
marigoold Nov 23, 2022
8c228dc
Merge branch 'master' into dev_add_cdist
marigoold Nov 28, 2022
5ee010c
refine code
marigoold Nov 28, 2022
e120c59
remove attr 'mode'
marigoold Nov 28, 2022
f8ca1e5
remove useless variables
marigoold Nov 28, 2022
5f4fe39
add global test
marigoold Dec 15, 2022
6a21f98
Update functional_api.yaml
marigoold Jan 11, 2023
e866994
auto format by CI
oneflow-ci-bot Jan 11, 2023
eb970bc
Merge branch 'master' into dev_add_cdist
marigoold Feb 17, 2023
dde8ebe
Update event_recorder.cpp
marigoold Feb 17, 2023
4c20eb5
refine broadcast shape
marigoold Feb 17, 2023
e96bd79
Merge branch 'master' into dev_add_cdist
marigoold Apr 11, 2023
c38d2fa
Merge branch 'master' into dev_add_cdist
marigoold Aug 7, 2023
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
@@ -387,6 +387,7 @@ Other Ops
adaptive_avg_pool3d
broadcast_like
cast
cdist
cumprod
cumsum
diag
96 changes: 96 additions & 0 deletions oneflow/core/autograd/gradient_funcs/cdist.cpp
@@ -0,0 +1,96 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr_grad_function.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {
namespace one {

namespace {

struct CDistCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
size_t x1_index = 0;
size_t x2_index = 0;
size_t out_index = 0;
double p = 0.0;
};

class CDistGrad : public OpExprGradFunction<CDistCaptureState> {
public:
virtual ~CDistGrad() = default;

using OpExprGradFunction<CDistCaptureState>::Init;

Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(CDistCaptureState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
const AttrMap& attrs) const override;
Maybe<void> Apply(const CDistCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

private:
AttrMap base_attrs_;
};

Maybe<void> CDistGrad::Init(const OpExpr& op) {
const auto* fw_op_expr = dynamic_cast<const UserOpExpr*>(&op);
CHECK_NOTNULL_OR_RETURN(fw_op_expr); // NOLINT(maybe-need-error-msg)
base_attrs_ = MakeAttrMapFromUserOpConf(fw_op_expr->proto());
return Maybe<void>::Ok();
}

Maybe<void> CDistGrad::Capture(CDistCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }

ctx->x1_index = ctx->SaveTensorForBackward(inputs.at(0));
ctx->x2_index = ctx->SaveTensorForBackward(inputs.at(1));
ctx->out_index = ctx->SaveTensorForBackward(outputs.at(0));
ComposedAttrMap composed_attrs(attrs, base_attrs_);
ctx->p = JUST(composed_attrs.GetAttr<double>("p"));

return Maybe<void>::Ok();
}

Maybe<void> CDistGrad::Apply(const CDistCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const {
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
CHECK_LE_OR_RETURN(out_grads.size(), 1); // NOLINT(maybe-need-error-msg)

const auto& x1 = ctx->SavedTensors().at(ctx->x1_index);
const auto& x2 = ctx->SavedTensors().at(ctx->x2_index);
const auto& out = ctx->SavedTensors().at(ctx->out_index);
const double p = ctx->p;

in_grads->resize(2);
auto results = JUST(functional::CDistGrad(x1, x2, out, out_grads.at(0), p));
(*in_grads)[0] = results->at(0);
(*in_grads)[1] = results->at(1);
return Maybe<void>::Ok();
}

} // namespace

REGISTER_OP_EXPR_GRAD_FUNCTION("cdist", CDistGrad);

} // namespace one
} // namespace oneflow
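The gradient function above saves both inputs and the forward output during capture, then delegates to `functional::CDistGrad` in `Apply`. A minimal eager-mode sketch of what this registration enables (hypothetical shapes; assumes a build of this branch installed as `oneflow`):

```python
import oneflow as flow

x1 = flow.randn(2, 5, 3, requires_grad=True)  # (batch, P, M)
x2 = flow.randn(2, 4, 3)                      # (batch, R, M)

out = flow.cdist(x1, x2, p=2.0)  # (batch, P, R) pairwise distances
out.sum().backward()             # dispatches to the registered "cdist" grad function

print(x1.grad.shape)  # oneflow.Size([2, 5, 3])
```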
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -2612,6 +2612,14 @@
signature: "Tensor (Tensor x, Tensor y, Int32 dim=1, Double eps=1e-8) => CosineSimilarity"
bind_python: True

- name: "cdist"
signature: "Tensor (Tensor x1, Tensor x2, Double p=2.0, String compute_mode=None) => CDist"
bind_python: True

- name: "cdist_grad"
signature: "TensorTuple (Tensor x1, Tensor x2, Tensor out, Tensor dy, Double p=2.0) => CDistGrad"
bind_python: False

- name: "normalize"
signature: "Tensor (Tensor input, Float p=2.0, Int32 dim=1, Float eps=1e-12, Bool use_l2_norm_kernel=True) => Normalize"
bind_python: True
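The `cdist` signature mirrors `torch.cdist`: the trailing two dimensions are treated as (rows, columns) and all leading dimensions are broadcast between the two inputs. A usage sketch of the broadcasting behavior (hypothetical shapes; assumes this branch's build):

```python
import oneflow as flow

a = flow.randn(3, 1, 8, 6)  # batch dims (3, 1), 8 rows, 6 columns
b = flow.randn(5, 9, 6)     # batch dims (5,),   9 rows, 6 columns

# Batch dims broadcast to (3, 5), so the output shape is (3, 5, 8, 9).
d2 = flow.cdist(a, b)         # default p=2.0 (Euclidean distance)
d1 = flow.cdist(a, b, p=1.0)  # p=1.0 (Manhattan distance)
```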
1 change: 1 addition & 0 deletions oneflow/core/functional/impl/array_functor.cpp
@@ -3801,6 +3801,7 @@ class BroadcastTensorsFunctor {
return outputs;
}
};

class BinCountFunctor {
public:
BinCountFunctor() {
78 changes: 78 additions & 0 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -3431,6 +3431,83 @@ class CosineSimilarityFunctor {
}
};

class CDistFunctor {
public:
CDistFunctor() {
op_ = CHECK_JUST(OpBuilder("cdist").Input("x1").Input("x2").Output("out").Build());
}
Maybe<Tensor> euclidean_dist(const std::shared_ptr<Tensor>& x1,
const std::shared_ptr<Tensor>& x2) const {
const auto& x1_norm = JUST(ReduceSum(JUST(ScalarPow(x1, 2, false)), {-1}, true));
const auto& x2_norm = JUST(ReduceSum(JUST(ScalarPow(x2, 2, false)), {-1}, true));
const auto& x1_ones = JUST(OnesLike(x1_norm));
const auto& x2_ones = JUST(OnesLike(x2_norm));
const auto& x1_cat = JUST(Concat({JUST(ScalarMul(x1, -2, false)), x1_norm, x1_ones}, -1));
const auto& x2_cat = JUST(Concat({x2, x2_ones, x2_norm}, -1));
const auto& result =
JUST(MatMul(x1_cat, JUST(Transpose2dim(x2_cat, -1, -2)), false, false, 1.0));
return Sqrt(JUST(ClampMin(result, 0.0)));
}
Comment on lines +3439 to +3450

Contributor: The only call to this function is commented out below, so it is currently unused.
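For context, the `euclidean_dist` helper evaluates all pairwise squared distances with a single matmul via the expansion (a sketch of the algebra, not part of the diff):

```latex
\|x_i - y_j\|^2 = \|x_i\|^2 - 2\,x_i \cdot y_j + \|y_j\|^2
                = \begin{bmatrix} -2x_i & \|x_i\|^2 & 1 \end{bmatrix}
                  \begin{bmatrix} y_j \\ 1 \\ \|y_j\|^2 \end{bmatrix}
```

which is exactly what the concatenations `x1_cat = [-2*x1, |x1|^2, 1]` and `x2_cat = [x2, 1, |x2|^2]` compute. The `ClampMin(result, 0.0)` guards against small negative values from floating-point cancellation, which is also the source of the accuracy concern discussed below.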


Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x1, const std::shared_ptr<Tensor>& x2,
const double& p, const Optional<std::string>& compute_mode) const {
const int64_t x1_ndim = x1->ndim();
const int64_t x2_ndim = x2->ndim();
CHECK_OR_RETURN(x1_ndim >= 2) << "cdist only supports at least 2D tensors, X1 got: "
<< x1->ndim() << "D";
CHECK_OR_RETURN(x2_ndim >= 2) << "cdist only supports at least 2D tensors, X2 got: "
<< x2->ndim() << "D";
CHECK_OR_RETURN(x1->dim(x1_ndim - 1) == x2->dim(x2_ndim - 1))
<< "X1 and X2 must have the same number of columns. X1: " << x1->dim(x1_ndim - 1)
<< " X2: " << x2->dim(x2_ndim - 1);
CHECK_OR_RETURN(p >= 0) << "cdist only supports non-negative p values, got " << p;

if (compute_mode.has_value()) {
OF_LOG_ONCE(LOG(WARNING)
<< "the 'compute_mode' argument is not supported yet; cdist "
"will not use the matrix multiplication approach to compute euclidean distance");
}

int64_t r1 = x1->dim(x1_ndim - 2);
int64_t r2 = x2->dim(x2_ndim - 2);
int64_t d = x1->dim(x1_ndim - 1);

std::vector<Shape> shape_vector = {
Shape(DimVector({x1->shape()->begin(), x1->shape()->end() - 2})),
Shape(DimVector({x2->shape()->begin(), x2->shape()->end() - 2})),
};
auto broadcasted_shape = JUST(BroadcastShapes(shape_vector));
Shape x1_expand_shape(*broadcasted_shape);
x1_expand_shape.emplace_back(r1);
x1_expand_shape.emplace_back(d);
broadcasted_shape->emplace_back(r2);
broadcasted_shape->emplace_back(d);

const auto x1_expand = JUST(Expand(x1, x1_expand_shape));
const auto x2_expand = JUST(Expand(x2, *broadcasted_shape));

// mm_for_euclid_dist has accuracy issue
// if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
// shape output_shape(max_batch_shape);
// output_shape.emplace_back(r1);
// output_shape.emplace_back(r2);
// return JUST(Reshape(JUST(euclidean_dist(x1_expand, x2_expand)), output_shape));
// }
Comment on lines +3489 to +3495

Contributor: Remove the unused comments.

Contributor Author:
> Remove the unused comments.

This code path exists in torch; it currently has an accuracy issue here, and the block will be un-commented once that is resolved.


TensorProcessor tensor_processor;
JUST(tensor_processor.PromoteInputsToCommonDtype(true)
.AddInputs({x1_expand, x2_expand})
.Apply());
// Dispatch the expanded, dtype-promoted tensors rather than the raw inputs,
// so that the broadcasting and promotion above actually take effect.
TensorTuple input_tuple = JUST(tensor_processor.GetInputs());

auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p");
attrs.SetAllAttrs(p);
return OpInterpUtil::Dispatch<Tensor>(*op_, input_tuple, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};
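The shape handling in `operator()` treats everything except the last two dimensions as batch, broadcasts the batch shapes, and expands both inputs before dispatch. A pure-Python sketch of the same shape logic (illustrative only, not the OneFlow API):

```python
def broadcast_batch(b1, b2):
    # Right-align and apply standard broadcasting rules, as BroadcastShapes does.
    n = max(len(b1), len(b2))
    b1 = [1] * (n - len(b1)) + list(b1)
    b2 = [1] * (n - len(b2)) + list(b2)
    out = []
    for a, b in zip(b1, b2):
        if a != b and a != 1 and b != 1:
            raise ValueError(f"incompatible batch dims: {a} vs {b}")
        out.append(max(a, b))
    return out

def cdist_output_shape(x1_shape, x2_shape):
    *b1, r1, d1 = x1_shape
    *b2, r2, d2 = x2_shape
    assert d1 == d2, "X1 and X2 must have the same number of columns"
    return (*broadcast_batch(b1, b2), r1, r2)

assert cdist_output_shape((3, 1, 8, 6), (5, 9, 6)) == (3, 5, 8, 9)
```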

class L2NormalizeFunctor {
public:
L2NormalizeFunctor() {
@@ -5501,6 +5578,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::PairwiseDistanceFunctor>("PairwiseDistance");
m.add_functor<impl::CosineSimilarityFunctor>("CosineSimilarity");
m.add_functor<impl::NormalizeFunctor>("Normalize");
m.add_functor<impl::CDistFunctor>("CDist");
m.add_functor<impl::L2NormalizeFunctor>("L2Normalize");
m.add_functor<impl::L2NormalizeGradFunctor>("L2NormalizeGrad");
m.add_functor<impl::FusedBiasAddGeluFunctor>("FusedBiasAddGelu");
26 changes: 26 additions & 0 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -659,6 +659,31 @@ class BinaryCrossEntropyWithLogitsReduceMeanLossTargetGradFunctor {
}
};

class CDistGradFunctor {
public:
CDistGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("cdist_grad")
.Input("x1")
.Input("x2")
.Input("out")
.Input("dy")
.Output("dx1")
.Output("dx2")
.Build());
}
Maybe<TensorTuple> operator()(const std::shared_ptr<Tensor>& x1,
const std::shared_ptr<Tensor>& x2,
const std::shared_ptr<Tensor>& out,
const std::shared_ptr<Tensor>& dy, const double& p) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("p");
attrs.SetAllAttrs(p);
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {x1, x2, out, dy}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class CombinedMarginLossGradFunctor {
public:
CombinedMarginLossGradFunctor() {
@@ -1692,6 +1717,7 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::SparseSoftmaxCrossEntropyGrad>("SparseSoftmaxCrossEntropyGrad");
m.add_functor<impl::SparseSoftmaxCrossEntropyMsGrad>("SparseSoftmaxCrossEntropyMsGrad");
m.add_functor<impl::SmoothL1LossGradFunctor>("SmoothL1LossGrad");
m.add_functor<impl::CDistGradFunctor>("CDistGrad");
m.add_functor<impl::CombinedMarginLossGradFunctor>("CombinedMarginLossGrad");
m.add_functor<impl::AffineGridGradFunctor>("AffineGridGrad");
m.add_functor<impl::GridSampleGradFunctor>("GridSampleGrad");
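For reference, the backward op registered here is expected to compute the usual closed-form gradient of the p-norm distance $d_{ij} = \lVert x^{(1)}_i - x^{(2)}_j \rVert_p$ (a sketch, assuming $0 < p < \infty$ and $d_{ij} \neq 0$):

```latex
\frac{\partial d_{ij}}{\partial x^{(1)}_{ik}}
  = \operatorname{sign}\!\left(x^{(1)}_{ik} - x^{(2)}_{jk}\right)
    \frac{\left|x^{(1)}_{ik} - x^{(2)}_{jk}\right|^{p-1}}{d_{ij}^{\,p-1}},
\qquad
dx^{(1)}_{ik} = \sum_j dy_{ij}\,\frac{\partial d_{ij}}{\partial x^{(1)}_{ik}}
```

with `dx2` given by the negated analogue summed over `i`, which is why the kernel takes `x1`, `x2`, the forward output `out`, and `dy` as inputs.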
37 changes: 37 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -4483,6 +4483,43 @@ def OneFlow_AbsGradOp : OneFlow_BaseOp<"abs_grad", [NoMemoryEffect, DeclareOpInt
let has_data_type_infer_fn = 1;
}

def OneFlow_CDistOp : OneFlow_BaseOp<"cdist", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x1,
OneFlow_Tensor:$x2
);
let output = (outs
OneFlow_Tensor:$out
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "2.0">:$p
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_CDistGradOp : OneFlow_BaseOp<"cdist_grad", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x1,
OneFlow_Tensor:$x2,
OneFlow_Tensor:$out,
OneFlow_Tensor:$dy
);
let output = (outs
OneFlow_Tensor:$dx1,
OneFlow_Tensor:$dx2
);
let attrs = (ins
DefaultValuedAttr<F64Attr, "2.0">:$p
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_ErfOp : OneFlow_BaseOp<"erf", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x