diff --git a/.github/scripts/apply_torch_pr.py b/.github/scripts/apply_torch_pr.py
index bbe89ed7d..a9befbc14 100644
--- a/.github/scripts/apply_torch_pr.py
+++ b/.github/scripts/apply_torch_pr.py
@@ -13,6 +13,8 @@
         "https://github.com/pytorch/pytorch/pull/126516",
         # Modify the tolerance level in TIMM benchmark
         "https://github.com/pytorch/pytorch/pull/143739",
+        # Fix build error caused by incorrect namespace change by #144014
+        "https://github.com/pytorch/pytorch/pull/144450",
     ]
 )
 parser.add_argument('--extra-pr-list', '-e', nargs='+',default=[])
diff --git a/src/ATen/native/xpu/TensorAdvancedIndexing.cpp b/src/ATen/native/xpu/TensorAdvancedIndexing.cpp
index a2e3a1375..bd24aa3a0 100644
--- a/src/ATen/native/xpu/TensorAdvancedIndexing.cpp
+++ b/src/ATen/native/xpu/TensorAdvancedIndexing.cpp
@@ -7,17 +7,24 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
+//#include 
 #include 
 #include 
+#include 
+#include 
 #include 
 #include 
 #include 
 #include 
+#include 
 #include 
+#include  //generated
+//#include  //generated

 namespace at {

@@ -42,6 +49,7 @@ REGISTER_XPU_DISPATCH(index_fill_stub, &xpu::index_fill_kernel);
 REGISTER_XPU_DISPATCH(index_copy_stub, &xpu::index_copy_kernel);
 REGISTER_XPU_DISPATCH(put_stub, &xpu::put_kernel);
 REGISTER_XPU_DISPATCH(take_stub, &xpu::take_kernel);
+// REGISTER_XPU_DISPATCH(index_reduce_stub, &xpu::index_reduce_kernel);

 TORCH_IMPL_FUNC(index_add_xpu_out)
 (const Tensor& self,
@@ -126,5 +134,44 @@ Tensor count_nonzero_xpu(const Tensor& self, IntArrayRef dims) {
   return (self != 0).sum(dims);
 }

+TORCH_IMPL_FUNC(index_reduce_xpu_out)
+(const Tensor& self,
+ int64_t dim,
+ const Tensor& index,
+ const Tensor& source,
+ const c10::string_view reduce,
+ bool include_self,
+ const Tensor& result) {
+  TORCH_WARN_ONCE(
+      "index_reduce() is in beta and the API may change at any time.");
+  if (reduce == "prod") {
+    xpu::index_reduce_prod_kernel(
+        self, dim, index, source, include_self, ReductionType::PROD, result);
+  } else if (reduce == "mean") {
+    xpu::index_reduce_mean_kernel(
+        self, dim, index, source, include_self, ReductionType::MEAN, result);
+    auto counts = include_self ? ones_like(result) : zeros_like(result);
+    counts.index_add_(dim, index, ones_like(source));
+    counts.masked_fill_(counts == 0, 1);
+    if (result.is_floating_point() || result.is_complex()) {
+      result.div_(counts);
+    } else {
+      result.div_(counts, "floor");
+    }
+  } else if (reduce == "amax") {
+    xpu::index_reduce_amax_kernel(
+        self, dim, index, source, include_self, ReductionType::MAX, result);
+  } else if (reduce == "amin") {
+    xpu::index_reduce_amin_kernel(
+        self, dim, index, source, include_self, ReductionType::MIN, result);
+  } else {
+    TORCH_CHECK(
+        false,
+        "Only support prod, mean, amax or amin reduce operator. Input was ",
+        reduce,
+        ".");
+  }
+}
+
 } // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 72f2aacdd..62e5770ba 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -163,7 +163,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
     "_fft_r2c",
     "_flash_attention_forward",
     "geqrf",
-    "index_reduce.out",
     "linalg_cholesky_ex.L",
     "_linalg_det.result",
     "linalg_eig",
diff --git a/src/ATen/native/xpu/sycl/Atomics.h b/src/ATen/native/xpu/sycl/Atomics.h
index b381ed5b0..d6cc1fe77 100644
--- a/src/ATen/native/xpu/sycl/Atomics.h
+++ b/src/ATen/native/xpu/sycl/Atomics.h
@@ -360,6 +360,8 @@ SYCL_ATOMIC_INTEGER(Mul, std::multiplies()(a, b), int8_t)
 SYCL_ATOMIC_INTEGER(Mul, std::multiplies()(a, b), int16_t)
 SYCL_ATOMIC_INTEGER(Mul, std::multiplies()(a, b), int32_t)
 SYCL_ATOMIC_INTEGER(Mul, std::multiplies()(a, b), int64_t)
+SYCL_ATOMIC_INTEGER(Mul, std::multiplies()(a, b), uint32_t)
+SYCL_ATOMIC_INTEGER(Mul, std::multiplies()(a, b), uint64_t)
 SYCL_ATOMIC_FP(Mul, std::multiplies()(a, b), float)
 SYCL_ATOMIC_FP(Mul, std::multiplies()(a, b), double)

@@ -391,6 +393,8 @@ SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int8_t)
 SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int16_t)
 SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int32_t)
 SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), int64_t)
+SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), uint32_t)
+SYCL_ATOMIC_INTEGER(Max, safe_max(a, b), uint64_t)
 SYCL_ATOMIC_FP(Max, safe_max(a, b), float)
 SYCL_ATOMIC_FP(Max, safe_max(a, b), double)

@@ -403,6 +407,8 @@ SYCL_ATOMIC_INTEGER(Min, safe_min(a, b), int8_t)
 SYCL_ATOMIC_INTEGER(Min, safe_min(a, b), int16_t)
 SYCL_ATOMIC_INTEGER(Min, safe_min(a, b), int32_t)
 SYCL_ATOMIC_INTEGER(Min, safe_min(a, b), int64_t)
+SYCL_ATOMIC_INTEGER(Min, safe_min(a, b), uint32_t)
+SYCL_ATOMIC_INTEGER(Min, safe_min(a, b), uint64_t)
 SYCL_ATOMIC_FP(Min, safe_min(a, b), float)
 SYCL_ATOMIC_FP(Min, safe_min(a, b), double)

diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp
index 44086f778..10ac1b67c 100644
--- a/src/ATen/native/xpu/sycl/Indexing.cpp
+++ b/src/ATen/native/xpu/sycl/Indexing.cpp
@@ -7,6 +7,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -1151,6 +1152,329 @@ void put_kernel(
   });
 }

+template <
+    typename T,
+    typename IndicesType,
+    typename IndexType,
+    bool IndexIsMajor,
+    typename func_t>
+struct IndexFuncLargeIndexFunctor {
+  void operator()(sycl::nd_item<1> item) const {
+    // We stride over the output including the indexed dimension
+    // (totalSize), and calculate the destination index point based on that
+    auto local_range = item.get_local_range(0);
+    for (IndexType linearIndex =
+             item.get_group(0) * local_range + item.get_local_id(0);
+         linearIndex < totalSize_;
+         linearIndex += item.get_group_range(0) * local_range) {
+      IndexType srcIndex, elementInSlice;
+      if (IndexIsMajor) {
+        srcIndex = linearIndex / innerSize_;
+        elementInSlice = linearIndex % innerSize_;
+      } else {
+        elementInSlice = linearIndex / innerSize_;
+        srcIndex = linearIndex % innerSize_;
+      }
+
+      // Lua indices begin at 1
+      IndexType dstIndex =
+          indices_.data[IndexToOffset<IndicesType, IndexType>::get(
+              srcIndex, indices_)];
+      CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize_);
+
+      IndexType dstOffset =
+          IndexToOffset<T, IndexType>::get(elementInSlice, dst_);
+      dstOffset += dstIndex * dst_.strides[dstAddDim_];
+
+      IndexType srcOffset =
+          IndexToOffset<T, IndexType>::get(elementInSlice, src_);
+      srcOffset += srcIndex * src_.strides[srcAddDim_];
+
+      T val = src_.data[srcOffset] * alpha_;
+      op_(dst_.data, dstOffset, dstNumel_, &val);
+    }
+  }
+  IndexFuncLargeIndexFunctor(
+      TensorInfo<T, IndexType> dst,
+      TensorInfo<T, IndexType> src,
+      TensorInfo<IndicesType, IndexType> indices,
+      int dstAddDim,
+      int srcAddDim,
+      IndexType totalSize,
+      IndexType innerSize,
+      int64_t dstAddDimSize,
+      int64_t dstNumel,
+      func_t op,
+      T alpha)
+      : dst_(dst),
+        src_(src),
+        indices_(indices),
+        dstAddDim_(dstAddDim),
+        srcAddDim_(srcAddDim),
+        totalSize_(totalSize),
+        innerSize_(innerSize),
+        dstAddDimSize_(dstAddDimSize),
+        dstNumel_(dstNumel),
+        op_(op),
+        alpha_(alpha) {}
+
+ private:
+  TensorInfo<T, IndexType> dst_;
+  TensorInfo<T, IndexType> src_;
+  TensorInfo<IndicesType, IndexType> indices_;
+  int dstAddDim_;
+  int srcAddDim_;
+  IndexType totalSize_;
+  IndexType innerSize_;
+  int64_t dstAddDimSize_;
+  int64_t dstNumel_;
+  func_t op_;
+  T alpha_;
+};
+
+template <typename func_t>
+void index_reduce_func_xpu_template(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const func_t& reduce_func,
+    const Tensor& result) {
+  globalContext().alertNotDeterministic("index_reduce_xpu");
+
+  if (!result.is_same(self))
+    result.copy_(self);
+
+  // Scalars are treated as 1-d tensor
+  Tensor self_ = (result.dim() == 0) ? result.view(1) : result;
+  Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
+
+  TORCH_CHECK(
+      result.dim() <= XPU_MAX_TENSORINFO_DIMS,
+      "tensor has too many (>",
+      XPU_MAX_TENSORINFO_DIMS,
+      ") dims");
+  TORCH_CHECK(
+      source.dim() <= XPU_MAX_TENSORINFO_DIMS,
+      "tensor has too many (>",
+      XPU_MAX_TENSORINFO_DIMS,
+      ") dims");
+  TORCH_CHECK(
+      index.dim() <= XPU_MAX_TENSORINFO_DIMS,
+      "tensor has too many (>",
+      XPU_MAX_TENSORINFO_DIMS,
+      ") dims");
+
+  if (!include_self) {
+    AT_DISPATCH_ALL_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        self.scalar_type(),
+        "index_reduce_func_xpu_exclude_input_init",
+        [&] {
+          scalar_t init_val;
+          switch (reduce) {
+            case ReductionType::PROD:
+              init_val = (scalar_t)1;
+              break;
+            case ReductionType::MAX:
+              init_val = std::numeric_limits<scalar_t>::has_infinity
+                  ? -std::numeric_limits<scalar_t>::infinity()
+                  : std::numeric_limits<scalar_t>::lowest();
+              break;
+            case ReductionType::MIN:
+              init_val = std::numeric_limits<scalar_t>::has_infinity
+                  ? std::numeric_limits<scalar_t>::infinity()
+                  : std::numeric_limits<scalar_t>::max();
+              break;
+            default:
+              init_val = (scalar_t)0;
+              break;
+          }
+          // index_fill_ requires index to be a LongTensor
+          self_.index_fill_(dim, index.to(at::ScalarType::Long), init_val);
+        });
+  }
+
+  uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
+  uint64_t sourceTotalSize = source.numel();
+  uint64_t selfReduceDimSize = self_.size(dim);
+  // uint64_t numIndex = index.numel();
+  uint64_t selfNumel = self_.numel();
+  if (sliceSize == 0) {
+    return;
+  }
+
+  {
+    AT_DISPATCH_ALL_TYPES_AND2(
+        at::ScalarType::Half,
+        at::ScalarType::BFloat16,
+        self.scalar_type(),
+        "index_reduce",
+        [&] {
+          TensorInfo<scalar_t, uint64_t> selfInfo =
+              getTensorInfo<scalar_t, uint64_t>(self_);
+          int selfReduceDim = selfInfo.collapseDims(dim);
+          selfInfo.reduceDim(selfReduceDim);
+          auto alpha_value = (scalar_t)1;
+
+          TensorInfo<scalar_t, uint64_t> sourceInfo =
+              getTensorInfo<scalar_t, uint64_t>(source_);
+          int sourceReduceDim = sourceInfo.collapseDims(dim);
+          sourceInfo.reduceDim(sourceReduceDim);
+
+          AT_DISPATCH_INDEX_TYPES(
+              index.scalar_type(), "index_reduce_xpu", [&]() {
+                TensorInfo<index_t, uint64_t> indexInfo =
+                    getTensorInfo<index_t, uint64_t>(index);
+                indexInfo.collapseDims();
+                auto caller = IndexFuncLargeIndexFunctor<
+                    scalar_t,
+                    index_t,
+                    uint64_t,
+                    true,
+                    func_t>(
+                    selfInfo,
+                    sourceInfo,
+                    indexInfo,
+                    selfReduceDim,
+                    sourceReduceDim,
+                    sourceTotalSize,
+                    sliceSize,
+                    selfReduceDimSize,
+                    selfNumel,
+                    reduce_func,
+                    alpha_value);
+                int defaultMaxGroupThreads = syclMaxWorkGroupSize(caller);
+                int sgc = syclMaxNumSubGroups();
+                size_t num_wg = std::min(
+                    ceil_div(sourceTotalSize, (uint64_t)128),
+                    (uint64_t)(sgc * 8));
+                size_t wg_size = (sourceTotalSize < defaultMaxGroupThreads)
+                    ? sourceTotalSize
+                    : defaultMaxGroupThreads;
+                sycl_kernel_submit(
+                    num_wg * wg_size, wg_size, getCurrentSYCLQueue(), caller);
+              });
+        });
+  }
+}
+
+struct IndexReduceMultiplyFunctor {
+  template <typename scalar_t>
+  void operator()(
+      scalar_t* self_data_start,
+      int64_t index,
+      int64_t numel,
+      const scalar_t* src_data) const {
+    (void)numel; // suppress unused warning
+    atomicMul((sycl_global_ptr<scalar_t>)(self_data_start + index), *src_data);
+  }
+};
+static IndexReduceMultiplyFunctor index_reduce_multiply;
+
+struct IndexReduceMeanFunctor {
+  template <typename scalar_t>
+  void operator()(
+      scalar_t* self_data_start,
+      int64_t index,
+      int64_t numel,
+      const scalar_t* src_data) const {
+    atomicAdd((sycl_global_ptr<scalar_t>)(self_data_start + index), *src_data);
+  }
+};
+static IndexReduceMeanFunctor index_reduce_mean;
+
+struct IndexReduceMaxFunctor {
+  template <typename scalar_t>
+  void operator()(
+      scalar_t* self_data_start,
+      int64_t index,
+      int64_t numel,
+      const scalar_t* src_data) const {
+    (void)numel; // suppress unused warning
+    atomicMax((sycl_global_ptr<scalar_t>)(self_data_start + index), *src_data);
+  }
+};
+static IndexReduceMaxFunctor index_reduce_max;
+
+struct IndexReduceMinFunctor {
+  template <typename scalar_t>
+  void operator()(
+      scalar_t* self_data_start,
+      int64_t index,
+      int64_t numel,
+      const scalar_t* src_data) const {
+    (void)numel; // suppress unused warning
+    atomicMin((sycl_global_ptr<scalar_t>)(self_data_start + index), *src_data);
+  }
+};
+static IndexReduceMinFunctor index_reduce_min;
+
+void index_reduce_prod_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const Tensor& result) {
+  index_reduce_func_xpu_template(
+      self,
+      dim,
+      index,
+      source,
+      include_self,
+      reduce,
+      index_reduce_multiply,
+      result);
+}
+
+void index_reduce_mean_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const Tensor& result) {
+  index_reduce_func_xpu_template(
+      self,
+      dim,
+      index,
+      source,
+      include_self,
+      reduce,
+      index_reduce_mean,
+      result);
+}
+
+void index_reduce_amax_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const Tensor& result) {
+  index_reduce_func_xpu_template(
+      self, dim, index, source, include_self, reduce, index_reduce_max, result);
+}
+
+void index_reduce_amin_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const Tensor& result) {
+  index_reduce_func_xpu_template(
+      self, dim, index, source, include_self, reduce, index_reduce_min, result);
+}
+
 // ForwardIt: only legacy random access iterator is supported.
 template 
 static inline ForwardIt find_bound(
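Illustrative usage, not part of the patch: a minimal Python sketch of what the XPU index_reduce kernels above compute. It assumes a PyTorch build that includes this change and, ideally, an available "xpu" device (it falls back to CPU otherwise); the tensor shapes and values are made up for the example.

import torch

# Assumption: torch was built with XPU support and this patch applied.
device = "xpu" if torch.xpu.is_available() else "cpu"

self_t = torch.full((4, 3), 2.0, device=device)
index = torch.tensor([0, 2, 2], device=device)  # maps source rows -> self rows
source = torch.arange(1.0, 10.0, device=device).reshape(3, 3)

# "prod": each indexed row of self accumulates the product of the matching
# source rows (IndexReduceMultiplyFunctor / atomicMul in the kernel above).
out_prod = self_t.index_reduce(0, index, source, "prod", include_self=True)

# "mean": sums are accumulated on the device (IndexReduceMeanFunctor /
# atomicAdd), then index_reduce_xpu_out divides by the per-row counts.
out_mean = self_t.index_reduce(0, index, source, "mean", include_self=False)

# "amax" / "amin": keep the elementwise max / min seen for each indexed row.
out_amax = self_t.index_reduce(0, index, source, "amax", include_self=True)
out_amin = self_t.index_reduce(0, index, source, "amin", include_self=True)
print(out_prod, out_mean, out_amax, out_amin, sep="\n")

The reduce strings used here are the same ones index_reduce_xpu_out dispatches on in TensorAdvancedIndexing.cpp above.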
diff --git a/src/ATen/native/xpu/sycl/IndexingKernels.h b/src/ATen/native/xpu/sycl/IndexingKernels.h
index e5f434585..097d51a5d 100644
--- a/src/ATen/native/xpu/sycl/IndexingKernels.h
+++ b/src/ATen/native/xpu/sycl/IndexingKernels.h
@@ -1,5 +1,6 @@
 #pragma once
 #include 
+#include 

 namespace at::native::xpu {

@@ -65,6 +66,42 @@ TORCH_XPU_API void put_kernel(

 TORCH_XPU_API void take_kernel(TensorIterator& iter, const TensorBase& input);

+TORCH_XPU_API void index_reduce_prod_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const Tensor& result);
+
+TORCH_XPU_API void index_reduce_mean_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const Tensor& result);
+
+TORCH_XPU_API void index_reduce_amax_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const Tensor& result);
+
+TORCH_XPU_API void index_reduce_amin_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index,
+    const Tensor& source,
+    bool include_self,
+    const ReductionType& reduce,
+    const Tensor& result);
+
 TORCH_XPU_API Tensor index_select_sparse_kernel(
     const Tensor& self,
     int64_t dim,
diff --git a/src/comm/DeviceProperties.h b/src/comm/DeviceProperties.h
index b98281357..724f574f4 100644
--- a/src/comm/DeviceProperties.h
+++ b/src/comm/DeviceProperties.h
@@ -112,6 +112,12 @@ static inline int64_t syclMaxWorkItemsPerEU(
   return simd_width * hw_threads;
 }

+static inline int64_t syclMaxNumSubGroups(
+    at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
+  auto* dev_prop = at::xpu::getDeviceProperties(dev_id);
+  return dev_prop->max_num_sub_groups;
+}
+
 static inline int64_t syclMaxDSSNum(
     at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
   // TODO: We need to got this info from DPC++ Runtime
diff --git a/src/comm/TensorInfo.h b/src/comm/TensorInfo.h
index 6d602a632..67b5c5aa9 100644
--- a/src/comm/TensorInfo.h
+++ b/src/comm/TensorInfo.h
@@ -60,9 +60,9 @@ struct TensorInfo {
   // See note on [collapse dims].
   int collapseDims(const int excludeDim = -1);

-  int outerSize(const int dim);
+  IndexType outerSize(const int dim);

-  int innerSize(const int dim);
+  IndexType innerSize(const int dim);

   // Contiguous tensors of more than one dimension are collapsed down
   // to one tensor
@@ -104,7 +104,7 @@ TensorInfo<T, IndexType>::TensorInfo(
   TORCH_INTERNAL_ASSERT(dims <= XPU_MAX_TENSORINFO_DIMS);
   is_contiguous = true;

-  int z = 1;
+  IndexType z = 1;
   for (int i = dim - 1; i >= 0; i--) {
     sizes[i] = sz[i];
     strides[i] = st[i];
@@ -133,8 +133,8 @@ int TensorInfo<T, IndexType>::collapseDims(const int excludeDim) {
 }

 template <typename T, typename IndexType>
-int TensorInfo<T, IndexType>::innerSize(const int exclusive) {
-  int size = 1;
+IndexType TensorInfo<T, IndexType>::innerSize(const int exclusive) {
+  IndexType size = 1;
   for (int i = dims - 1; i > exclusive; i--) {
     size *= sizes[i];
   }
@@ -142,8 +142,8 @@ int TensorInfo<T, IndexType>::innerSize(const int exclusive) {
 }

 template <typename T, typename IndexType>
-int TensorInfo<T, IndexType>::outerSize(const int exclusive) {
-  int size = 1;
+IndexType TensorInfo<T, IndexType>::outerSize(const int exclusive) {
+  IndexType size = 1;
   for (int i = 0; i < exclusive; i++) {
     size *= sizes[i];
   }
diff --git a/test/xpu/extended/skip_list_common.py b/test/xpu/extended/skip_list_common.py
index 643d631eb..7cd960c33 100644
--- a/test/xpu/extended/skip_list_common.py
+++ b/test/xpu/extended/skip_list_common.py
@@ -198,5 +198,14 @@
         "test_compare_cpu_div_trunc_rounding_xpu_float16",
         "test_compare_cpu_div_floor_rounding_xpu_float16",
         "test_compare_cpu_div_floor_rounding_xpu_bfloat16",
+
+        # AssertionError: Tensor-likes are not close!
+        # Mismatched elements: 1 / 125 (0.8%)
+        # Greatest absolute difference: 0.0013427734375 at index (0, 2, 4) (up to 0.001 allowed)
+        # Greatest relative difference: 0.008453369140625 at index (0, 2, 4) (up to 0.001 allowed)
+        "test_compare_cpu_index_reduce_mean_xpu_bfloat16",
+        "test_compare_cpu_index_reduce_mean_xpu_float16",
+        "test_compare_cpu_index_reduce_prod_xpu_bfloat16",
+        "test_compare_cpu_index_reduce_prod_xpu_float16",
     ),
 }
diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py
index d036482ac..780f2efd7 100644
--- a/test/xpu/skip_list_common.py
+++ b/test/xpu/skip_list_common.py
@@ -2036,6 +2036,8 @@

         # All are oneDNN issues
         ### Error #0 in TestBwdGradientsXPU , totally 271 , RuntimeError: Double and complex datatype matmul is not supported in oneDNN
+        "test_fn_grad_index_reduce_prod_xpu_float64",
+        "test_inplace_grad_index_reduce_prod_xpu_float64",
         "test_fn_grad___rmatmul___xpu_complex128",
         "test_fn_grad___rmatmul___xpu_float64",
         "test_fn_grad_addbmm_xpu_float64",
@@ -2411,6 +2413,20 @@
         # internally uses index_put deterministic implementation
         # dependent on "test_index_put_non_accumulate_deterministic"
         "test_index_copy_deterministic",
+
+        # scatter_add needs handle XPU deterministic
+        # https://github.com/intel/torch-xpu-ops/issues/906
+        "test_gather_backward_deterministic_path_xpu",
+        "test_scatter_add_one_dim_deterministic_xpu",
+
+        # Precision error
+        # Fail occasionally
+        # Mismatched elements: 1 / 60 (1.7%)
+        # Greatest absolute difference: 0.0625 at index (2, 1, 4) (up to 1e-05 allowed)
+        # Greatest relative difference: 0.001125335693359375 at index (2, 1, 4) (up to 0.001 allowed)
+        "test_index_reduce_reduce_mean_xpu_bfloat16",
+        "test_index_reduce_reduce_mean_xpu_float16",
+        "test_index_reduce_reduce_prod_xpu_float16",
     ),

     "nn/test_multihead_attention_xpu.py": (
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index ec561f337..a7e583331 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -87,6 +87,7 @@
     "nn.functional.mish",
     "i0",
     "index_add",
+    "index_reduce",
     "index_fill",
     "index_put",
     "index_select",
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index db6821026..e9c399bd4 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -6125,6 +6125,22 @@
   dispatch:
     XPU: _compute_linear_combination_out

+- func: index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)
+  structured: True
+  variants: function
+  precomputed:
+  - dim -> int dim
+  dispatch:
+    XPU: index_reduce_xpu_out
+
+- func: index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!)
+  structured_delegate: index_reduce.out
+  variants: method
+
+- func: index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor
+  structured_delegate: index_reduce.out
+  variants: function, method
+
 - func: index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
   variants: method