Add aten::index_reduce operator #1156

Merged
merged 27 commits into main from reduce_index_v2
Jan 9, 2025
Commits
27 commits
92738af
First reduced version of index_reduce.
cfgfung Nov 25, 2024
cbd7b03
Implemented reduce_prod.
cfgfung Nov 26, 2024
55daf53
Removed unnecessary if cases.
cfgfung Dec 2, 2024
eb1ad53
Added two reduce operators - amin and amax.
cfgfung Dec 3, 2024
f77e325
Add reduce_mean op.
cfgfung Dec 4, 2024
9066b5d
Skip 3 test cases. These are due to precision errors and the differen…
cfgfung Dec 19, 2024
fe08ee4
Merge branch 'main' into reduce_index_v2
xytintel Jan 3, 2025
bae2544
Merge branch 'main' into reduce_index_v2
xytintel Jan 6, 2025
c439c15
Fix pointer bug & refine code
xytintel Jan 6, 2025
4ddc3af
Update XPUFallback.template
xytintel Jan 6, 2025
41abd1b
Update xpu_test_utils.py
xytintel Jan 6, 2025
68eb7ae
Update TensorAdvancedIndexing.cpp
xytintel Jan 6, 2025
7435995
Update native_functions.yaml
xytintel Jan 6, 2025
e7c5c16
Update native_functions.yaml
xytintel Jan 6, 2025
a4ffaee
Update skip_list_common.py
xytintel Jan 6, 2025
a6f3fc6
Update Indexing.cpp
xytintel Jan 6, 2025
3e0724a
Update TensorInfo.h
xytintel Jan 7, 2025
1e31b0b
Add syclMaxNumSubGroups
xytintel Jan 8, 2025
01549d1
Update Indexing.cpp
xytintel Jan 8, 2025
60381af
Merge branch 'main' into reduce_index_v2
xytintel Jan 8, 2025
743851f
Update Indexing.cpp
xytintel Jan 8, 2025
6cbb613
Update ScatterGatherKernels.cpp
xytintel Jan 8, 2025
a6c36d7
Update skip_list_common.py
xytintel Jan 9, 2025
be54630
Update skip_list_common.py
xytintel Jan 9, 2025
f186ec8
Merge branch 'main' into reduce_index_v2
xytintel Jan 9, 2025
63acc5c
Merge branch 'main' into reduce_index_v2
xytintel Jan 9, 2025
dfab49c
Update apply_torch_pr.py
xytintel Jan 9, 2025
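
For context, the operator wired up here follows the public torch.Tensor.index_reduce API with the four reductions this PR implements (prod, mean, amax, amin). A minimal usage sketch, assuming a PyTorch build where the torch.xpu backend is available; the tensor shapes and values below are illustrative and not taken from this PR's tests:

```python
import torch

# Minimal sketch of the operator this PR enables on XPU, assuming an
# XPU-enabled PyTorch build. Shapes/values are illustrative only.
device = "xpu" if torch.xpu.is_available() else "cpu"

self_t = torch.ones(5, 3, device=device)
index = torch.tensor([0, 4, 2], device=device)
source = torch.arange(1.0, 10.0, device=device).reshape(3, 3)

# The four reductions handled by index_reduce_xpu_out: prod, mean, amax, amin.
for reduce in ("prod", "mean", "amax", "amin"):
    out = self_t.index_reduce(0, index, source, reduce, include_self=True)
    print(reduce, out)
```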
47 changes: 47 additions & 0 deletions src/ATen/native/xpu/TensorAdvancedIndexing.cpp
@@ -7,17 +7,24 @@
#include <ATen/core/op_registration/adaption.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/IndexKernel.h>
#include <ATen/native/ReductionType.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/TensorAdvancedIndexingUtils.h>
#include <ATen/native/TensorIterator.h>
//#include <ATen/native/TensorFactories.cpp>
#include <ATen/native/xpu/sycl/IndexingKernels.h>
#include <ATen/native/xpu/sycl/ScatterGatherKernels.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#include <comm/RegisterUtils.h>
#include <comm/xpu_aten.h>
#include <torch/library.h>

#include <ATen/ops/index_add_meta.h>
#include <ATen/ops/index_reduce_meta.h>
#include <xpu/ATen/ops/index_add_native.h>
#include <xpu/ATen/ops/index_reduce_native.h> //generated
//#include <xpu/ATen/ops/index_reduce_prod_native.h> //generated

namespace at {

@@ -42,6 +49,7 @@ REGISTER_XPU_DISPATCH(index_fill_stub, &xpu::index_fill_kernel);
REGISTER_XPU_DISPATCH(index_copy_stub, &xpu::index_copy_kernel);
REGISTER_XPU_DISPATCH(put_stub, &xpu::put_kernel);
REGISTER_XPU_DISPATCH(take_stub, &xpu::take_kernel);
// REGISTER_XPU_DISPATCH(index_reduce_stub, &xpu::index_reduce_kernel);

TORCH_IMPL_FUNC(index_add_xpu_out)
(const Tensor& self,
@@ -126,5 +134,44 @@ Tensor count_nonzero_xpu(const Tensor& self, IntArrayRef dims) {
return (self != 0).sum(dims);
}

TORCH_IMPL_FUNC(index_reduce_xpu_out)
(const Tensor& self,
int64_t dim,
const Tensor& index,
const Tensor& source,
const c10::string_view reduce,
bool include_self,
const Tensor& result) {
TORCH_WARN_ONCE(
"index_reduce() is in beta and the API may change at any time.");
if (reduce == "prod") {
xpu::index_reduce_prod_kernel(
self, dim, index, source, include_self, ReductionType::PROD, result);
} else if (reduce == "mean") {
xpu::index_reduce_mean_kernel(
self, dim, index, source, include_self, ReductionType::MEAN, result);
auto counts = include_self ? ones_like(result) : zeros_like(result);
counts.index_add_(dim, index, ones_like(source));
counts.masked_fill_(counts == 0, 1);
if (result.is_floating_point() || result.is_complex()) {
result.div_(counts);
} else {
result.div_(counts, "floor");
}
} else if (reduce == "amax") {
xpu::index_reduce_amax_kernel(
self, dim, index, source, include_self, ReductionType::MAX, result);
} else if (reduce == "amin") {
xpu::index_reduce_amin_kernel(
self, dim, index, source, include_self, ReductionType::MIN, result);
} else {
TORCH_CHECK(
false,
"Only support prod, mean, amax or amin reduce operator. Input was ",
reduce,
".");
}
}

} // namespace native
} // namespace at
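
The "mean" branch above leaves the summation to xpu::index_reduce_mean_kernel and then normalizes by per-slice counts on the host side. A rough Python rendering of just that post-processing step, for readers comparing against the CPU/CUDA behavior; the helper name and tensors are illustrative and not part of the PR:

```python
import torch

def _divide_by_counts(result, dim, index, source, include_self):
    # Mirrors the host-side "mean" post-processing in index_reduce_xpu_out:
    # count how many source slices were scattered into each output slice,
    # then divide the accumulated sums by those counts.
    counts = torch.ones_like(result) if include_self else torch.zeros_like(result)
    counts.index_add_(dim, index, torch.ones_like(source))
    counts.masked_fill_(counts == 0, 1)  # untouched slices divide by 1
    if result.is_floating_point() or result.is_complex():
        result.div_(counts)
    else:
        result.div_(counts, rounding_mode="floor")  # integer mean uses floor division
    return result
```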
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -163,7 +163,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_fft_r2c",
"_flash_attention_forward",
"geqrf",
"index_reduce.out",
"linalg_cholesky_ex.L",
"_linalg_det.result",
"linalg_eig",
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/sycl/Atomics.h
@@ -360,6 +360,8 @@ SYCL_ATOMIC_INTEGER(Mul, std::multiplies<int8_t>()(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<int16_t>()(a, b), int16_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<int32_t>()(a, b), int32_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<int64_t>()(a, b), int64_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<uint32_t>()(a, b), uint32_t)
SYCL_ATOMIC_INTEGER(Mul, std::multiplies<uint64_t>()(a, b), uint64_t)

SYCL_ATOMIC_FP(Mul, std::multiplies<float>()(a, b), float)
SYCL_ATOMIC_FP(Mul, std::multiplies<double>()(a, b), double)
@@ -391,6 +393,8 @@ SYCL_ATOMIC_INTEGER(Max, safe_max<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int16_t>(a, b), int16_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int32_t>(a, b), int32_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<int64_t>(a, b), int64_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<uint32_t>(a, b), uint32_t)
SYCL_ATOMIC_INTEGER(Max, safe_max<uint64_t>(a, b), uint64_t)

SYCL_ATOMIC_FP(Max, safe_max<float>(a, b), float)
SYCL_ATOMIC_FP(Max, safe_max<double>(a, b), double)
@@ -403,6 +407,8 @@ SYCL_ATOMIC_INTEGER(Min, safe_min<int8_t>(a, b), int8_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<int16_t>(a, b), int16_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<int32_t>(a, b), int32_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<int64_t>(a, b), int64_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<uint32_t>(a, b), uint32_t)
SYCL_ATOMIC_INTEGER(Min, safe_min<uint64_t>(a, b), uint64_t)

SYCL_ATOMIC_FP(Min, safe_min<float>(a, b), float)
SYCL_ATOMIC_FP(Min, safe_min<double>(a, b), double)