Add cdist op #9391
base: master
Conversation
Code got formatted by CI. Please request CI again if you still want to have this PR merged. If the PR is from a forked repo, please download the patch files from the GitHub Actions web page and apply them locally.
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9391/
This reverts commit 011712e.
// mm_for_euclid_dist has accuracy issue
// if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
//   shape output_shape(max_batch_shape);
//   output_shape.emplace_back(r1);
//   output_shape.emplace_back(r2);
//   return JUST(Reshape(JUST(euclidean_dist(x1_expand, x2_expand)), output_shape));
// }
Please remove the unused comments.
Please remove the unused comments.

This code path also exists in torch; it is commented out here only because it still has an accuracy issue, and it will be re-enabled once that is fixed.
Maybe<Tensor> euclidean_dist(const std::shared_ptr<Tensor>& x1,
                             const std::shared_ptr<Tensor>& x2) const {
  const auto& x1_norm = JUST(ReduceSum(JUST(ScalarPow(x1, 2, false)), {-1}, true));
  const auto& x2_norm = JUST(ReduceSum(JUST(ScalarPow(x2, 2, false)), {-1}, true));
  const auto& x1_ones = JUST(OnesLike(x1_norm));
  const auto& x2_ones = JUST(OnesLike(x2_norm));
  const auto& x1_cat = JUST(Concat({JUST(ScalarMul(x1, -2, false)), x1_norm, x1_ones}, -1));
  const auto& x2_cat = JUST(Concat({x2, x2_ones, x2_norm}, -1));
  const auto& result =
      JUST(MatMul(x1_cat, JUST(Transpose2dim(x2_cat, -1, -2)), false, false, 1.0));
  return Sqrt(JUST(ClampMin(result, 0.0)));
};
The only call site of this function is commented out below, so it is currently unused.
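For reference, the identity ||a − b||² = ||a||² − 2a·b + ||b||² that `euclidean_dist` evaluates with a single matmul can be sketched in NumPy (a hypothetical illustration, not OneFlow's API). The accuracy issue mentioned in the review is inherent to this trick: when the norms are large and the true distance is small, the subtraction cancels catastrophically.

```python
import numpy as np

def euclidean_dist_matmul(x1, x2):
    """Pairwise Euclidean distances via one matmul, mirroring the
    concatenation trick in the C++ snippet above."""
    x1_norm = (x1 ** 2).sum(-1, keepdims=True)  # [R1, 1]
    x2_norm = (x2 ** 2).sum(-1, keepdims=True)  # [R2, 1]
    # Row i of x1_cat is [-2*x1_i, ||x1_i||^2, 1]; row j of x2_cat is
    # [x2_j, 1, ||x2_j||^2]; their dot product is ||x1_i - x2_j||^2.
    x1_cat = np.concatenate([-2 * x1, x1_norm, np.ones_like(x1_norm)], axis=-1)
    x2_cat = np.concatenate([x2, np.ones_like(x2_norm), x2_norm], axis=-1)
    result = x1_cat @ x2_cat.T  # [R1, R2]
    # Rounding can drive tiny squared distances slightly negative, hence
    # the clamp before the square root.
    return np.sqrt(np.clip(result, 0.0, None))
```

This costs one GEMM instead of an O(R1·R2·C) elementwise pass, which is why torch gates it on `p == 2` and the problem size.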
x2 = x2.to_global(placement=placement, sbp=sbp)
z = torch.cdist(x1, x2)
return z
The tests here all use inputs with matching dims, but the implementation also contains broadcast logic, so please add a test that covers the broadcast case as well.
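As a hypothetical illustration of the broadcast case being requested (a NumPy reference sketch, not the OneFlow kernel): the leading batch dims of the two inputs broadcast together, while the trailing dim C must match.

```python
import numpy as np

def cdist_bcast(x1, x2, p=2.0):
    """Reference cdist with batch broadcasting.

    x1: [..., R1, C], x2: [..., R2, C] -> [broadcast(...), R1, R2].
    """
    # Insert singleton axes so the row dims pair up, then let NumPy
    # broadcasting handle the batch dims: [..., R1, R2, C].
    diff = np.abs(x1[..., :, None, :] - x2[..., None, :, :])
    return (diff ** p).sum(-1) ** (1.0 / p)

x1 = np.ones((2, 1, 4, 3))   # batch dims (2, 1)
x2 = np.zeros((1, 5, 6, 3))  # batch dims (1, 5)
out = cdist_bcast(x1, x2)
print(out.shape)  # (2, 5, 4, 6)
```

A test along these lines, with `torch.cdist` in place of `cdist_bcast`, would exercise the broadcast path.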
For two inputs x1 (shape=[B, R1, C]) and x2 (shape=[B, R2, C]), cdist computes, within each batch, the p-norm distance between every pair of row vectors of x1 and x2, producing result (shape=[B, R1, R2]).
See the torch documentation: https://pytorch.org/docs/stable/generated/torch.cdist.html
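The definition above can be written out as a minimal NumPy sketch (function name and shapes are illustrative, not part of the PR):

```python
import numpy as np

def cdist_pnorm(x1, x2, p=2.0):
    """x1: [B, R1, C], x2: [B, R2, C] -> [B, R1, R2].

    out[b, i, j] = (sum_c |x1[b, i, c] - x2[b, j, c]|**p) ** (1/p)
    """
    diff = np.abs(x1[:, :, None, :] - x2[:, None, :, :])  # [B, R1, R2, C]
    return (diff ** p).sum(-1) ** (1.0 / p)

x1 = np.array([[[0.0, 0.0], [1.0, 1.0]]])  # [1, 2, 2]
x2 = np.array([[[3.0, 4.0]]])              # [1, 1, 2]
print(cdist_pnorm(x1, x2))  # [[[5.0], [3.605551...]]]
```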