From 0dfa9321fa68b001cfcc6595823f610519ee738e Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 7 Aug 2023 13:51:26 +0800 Subject: [PATCH 01/13] memory coalescing for last dim flip --- oneflow/user/kernels/flip_kernel.cu | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index 316abdbf154..a0ab7bbe9e0 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "oneflow/core/common/container_util.h" #include "oneflow/core/framework/framework.h" #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/common/nd_index_offset_helper.h" @@ -51,6 +52,21 @@ __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims, } } +template +__global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size, + const T* in_dptr, T* out_dptr) { + __shared__ T shm[kCudaThreadsNumPerBlock]; + CUDA_1D_KERNEL_LOOP(i, element) { + int32_t end_idx = min(blockDim.x, element); + shm[end_idx - threadIdx.x - 1] = in_dptr[i]; + __syncthreads(); + int32_t i_ori = i - 2 * threadIdx.x + end_idx - 1; + int32_t row = i_ori / last_dim_size; + int32_t col = last_dim_size - (i_ori - row * last_dim_size) - 1; + out_dptr[row * last_dim_size + col] = shm[threadIdx.x]; + } +} + } // namespace template @@ -72,14 +88,19 @@ class FlipGpuKernel final : public user_op::OpKernel { VIS vis; for (auto x : dims) { vis.val[x] = true; } + if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { + FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( + elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), y_tensor->mut_dptr()); + return; + } + SIZE_V sizes_v; for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape_view().At(i); } - // TODO(bbuf) delete strides caluculate, after tensor strides supported SIZE_V strides_v; - strides_v.val[total_dims - 1] = 1; - for (int32_t i = total_dims - 2; i >= 0; i--) { - strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape_view().At(i + 1); + for (int32_t i = 0; i < total_dims; i++) { + strides_v.val[i] = CHECK_JUST(VectorAt(y_tensor->stride(), i)); } RUN_CUDA_KERNEL((FlipGpuForward), ctx->stream(), elem_cnt, elem_cnt, total_dims, sizes_v, vis, strides_v, x_tensor->dptr(), y_tensor->mut_dptr()); From e2927e293009b31e260c53c48fbb7d20ca66ccf6 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Mon, 7 Aug 2023 06:09:49 +0000 Subject: [PATCH 02/13] auto format by CI --- oneflow/user/kernels/flip_kernel.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index a0ab7bbe9e0..e632d9a1882 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -54,7 +54,7 @@ __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims, template __global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size, - const T* in_dptr, T* out_dptr) { + const T* in_dptr, T* out_dptr) { __shared__ T shm[kCudaThreadsNumPerBlock]; CUDA_1D_KERNEL_LOOP(i, element) { int32_t end_idx = min(blockDim.x, element); @@ -90,8 +90,9 @@ class FlipGpuKernel final : public user_op::OpKernel { if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( - elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), y_tensor->mut_dptr()); + ctx->stream()->As()->cuda_stream()>>>( + elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), + y_tensor->mut_dptr()); return; } From 8eb495129c277895a1ce4b7293e38cefd5a021b1 Mon Sep 17 00:00:00 2001 From: WangYi Date: Mon, 7 Aug 2023 14:44:42 +0800 Subject: [PATCH 03/13] refine, fix bug of final block unactive thread --- oneflow/user/kernels/flip_kernel.cu | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index a0ab7bbe9e0..ccf3b6d1fd6 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -54,13 +54,14 @@ __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims, template __global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size, - const T* in_dptr, T* out_dptr) { + const T* in_dptr, T* out_dptr) { __shared__ T shm[kCudaThreadsNumPerBlock]; CUDA_1D_KERNEL_LOOP(i, element) { - int32_t end_idx = min(blockDim.x, element); - shm[end_idx - threadIdx.x - 1] = in_dptr[i]; + int32_t block_begin_idx = blockDim.x * blockIdx.x; + int32_t thread_end_idx = min(block_begin_idx + blockDim.x, element) - block_begin_idx; + shm[threadIdx.x] = in_dptr[thread_end_idx - i + 2 * block_begin_idx - 1]; __syncthreads(); - int32_t i_ori = i - 2 * threadIdx.x + end_idx - 1; + int32_t i_ori = i - 2 * threadIdx.x + thread_end_idx - 1; int32_t row = i_ori / last_dim_size; int32_t col = last_dim_size - (i_ori - row * last_dim_size) - 1; out_dptr[row * last_dim_size + col] = shm[threadIdx.x]; @@ -90,8 +91,9 @@ class FlipGpuKernel final : public user_op::OpKernel { if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( - elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), y_tensor->mut_dptr()); + ctx->stream()->As()->cuda_stream()>>>( + elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), + y_tensor->mut_dptr()); return; } From df1344c9bfa7a85affec1520a30460609622ae44 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Fri, 11 Aug 2023 17:28:09 +0800 Subject: [PATCH 04/13] rm modification in flip_kernel.cu to test ci ci always crushed of cuda oom, even if I changed nothing in unittest --- oneflow/user/kernels/flip_kernel.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index ccf3b6d1fd6..bb6c7db0356 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -90,11 +90,11 @@ class FlipGpuKernel final : public user_op::OpKernel { for (auto x : dims) { vis.val[x] = true; } if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { - FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( - elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), - y_tensor->mut_dptr()); - return; +# FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( +# elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), +# y_tensor->mut_dptr()); +# return; } SIZE_V sizes_v; From e3708d44de39db22d78f22872efab66fe17718b2 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Fri, 11 Aug 2023 17:28:51 +0800 Subject: [PATCH 05/13] Update flip_kernel.cu --- oneflow/user/kernels/flip_kernel.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index bb6c7db0356..68e85e08dd2 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -90,11 +90,11 @@ class FlipGpuKernel final : public user_op::OpKernel { for (auto x : dims) { vis.val[x] = true; } if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { -# FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( -# elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), -# y_tensor->mut_dptr()); -# return; +// FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( +// elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), +// y_tensor->mut_dptr()); +// return; } SIZE_V sizes_v; From 0fc137891b8a7b1e27b3db6acc163ff36cef0bb9 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Fri, 11 Aug 2023 09:30:50 +0000 Subject: [PATCH 06/13] auto format by CI --- oneflow/user/kernels/flip_kernel.cu | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index 68e85e08dd2..b9d48957928 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -90,11 +90,12 @@ class FlipGpuKernel final : public user_op::OpKernel { for (auto x : dims) { vis.val[x] = true; } if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { -// FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( -// elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), -// y_tensor->mut_dptr()); -// return; + // FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( + // elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), + // y_tensor->mut_dptr()); + // return; } SIZE_V sizes_v; From 585b93a10d6532b7a9c172bca5a8af61cc6f2471 Mon Sep 17 00:00:00 2001 From: Wang Yi <53533850+marigoold@users.noreply.github.com> Date: Fri, 11 Aug 2023 20:35:38 +0800 Subject: [PATCH 07/13] Update flip_kernel.cu --- oneflow/user/kernels/flip_kernel.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index b9d48957928..cafdfe0bc3b 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -90,12 +90,12 @@ class FlipGpuKernel final : public user_op::OpKernel { for (auto x : dims) { vis.val[x] = true; } if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { - // FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( - // elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), - // y_tensor->mut_dptr()); - // return; + FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( + elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), + y_tensor->mut_dptr()); + return; } SIZE_V sizes_v; From 8d5adbf190882afeb11e4448de33f9588bd3cea1 Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Fri, 11 Aug 2023 12:37:47 +0000 Subject: [PATCH 08/13] auto format by CI --- oneflow/user/kernels/flip_kernel.cu | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index cafdfe0bc3b..ccf3b6d1fd6 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -90,12 +90,11 @@ class FlipGpuKernel final : public user_op::OpKernel { for (auto x : dims) { vis.val[x] = true; } if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { - FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( - elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), - y_tensor->mut_dptr()); - return; + FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( + elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), + y_tensor->mut_dptr()); + return; } SIZE_V sizes_v; From 43502d56b07bf39893c5b0599fc09ffda7fa307d Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 15 Aug 2023 01:07:40 +0800 Subject: [PATCH 09/13] test ci oom[DONT MERGE] --- oneflow/user/kernels/flip_kernel.cu | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index ccf3b6d1fd6..1c1006cdb3a 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -55,7 +55,7 @@ __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims, template __global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size, const T* in_dptr, T* out_dptr) { - __shared__ T shm[kCudaThreadsNumPerBlock]; + __shared__ T shm[256]; CUDA_1D_KERNEL_LOOP(i, element) { int32_t block_begin_idx = blockDim.x * blockIdx.x; int32_t thread_end_idx = min(block_begin_idx + blockDim.x, element) - block_begin_idx; @@ -90,10 +90,13 @@ class FlipGpuKernel final : public user_op::OpKernel { for (auto x : dims) { vis.val[x] = true; } if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { - FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( - elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), - y_tensor->mut_dptr()); + // FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( + // elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), + // y_tensor->mut_dptr()); + RUN_CUDA_KERNEL((FlipLastDimGpuForward), ctx->stream(), elem_cnt, elem_cnt, + x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), + y_tensor->mut_dptr()); return; } From 5e6bb2c23b3e6e66dad697f919125060456c9af4 Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 15 Aug 2023 09:27:29 +0800 Subject: [PATCH 10/13] refine code --- oneflow/user/kernels/flip_kernel.cu | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index 1c1006cdb3a..9634706194d 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -55,7 +55,7 @@ __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims, template __global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size, const T* in_dptr, T* out_dptr) { - __shared__ T shm[256]; + __shared__ T shm[ep::CudaStream::kDefaultBlockSize]; CUDA_1D_KERNEL_LOOP(i, element) { int32_t block_begin_idx = blockDim.x * blockIdx.x; int32_t thread_end_idx = min(block_begin_idx + blockDim.x, element) - block_begin_idx; @@ -90,10 +90,6 @@ class FlipGpuKernel final : public user_op::OpKernel { for (auto x : dims) { vis.val[x] = true; } if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) { - // FlipLastDimGpuForward<<stream()->As()->cuda_stream()>>>( - // elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), - // y_tensor->mut_dptr()); RUN_CUDA_KERNEL((FlipLastDimGpuForward), ctx->stream(), elem_cnt, elem_cnt, x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr(), y_tensor->mut_dptr()); From e41ea55c042fb77c3c1280f708a26e79a2144b4e Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 15 Aug 2023 10:46:23 +0800 Subject: [PATCH 11/13] add unittest, refine code --- oneflow/user/kernels/flip_kernel.cu | 4 ++-- python/oneflow/test/modules/test_flip.py | 9 +++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index 9634706194d..23ff6f762a4 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -59,9 +59,9 @@ __global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_ CUDA_1D_KERNEL_LOOP(i, element) { int32_t block_begin_idx = blockDim.x * blockIdx.x; int32_t thread_end_idx = min(block_begin_idx + blockDim.x, element) - block_begin_idx; - shm[threadIdx.x] = in_dptr[thread_end_idx - i + 2 * block_begin_idx - 1]; + int32_t i_ori = block_begin_idx + (thread_end_idx - threadIdx.x - 1); + shm[threadIdx.x] = in_dptr[i_ori]; __syncthreads(); - int32_t i_ori = i - 2 * threadIdx.x + thread_end_idx - 1; int32_t row = i_ori / last_dim_size; int32_t col = last_dim_size - (i_ori - row * last_dim_size) - 1; out_dptr[row * last_dim_size + col] = shm[threadIdx.x]; diff --git a/python/oneflow/test/modules/test_flip.py b/python/oneflow/test/modules/test_flip.py index 2b3b21c2eb5..f5f23930cf0 100644 --- a/python/oneflow/test/modules/test_flip.py +++ b/python/oneflow/test/modules/test_flip.py @@ -55,9 +55,18 @@ def test_flow_flip_bool_tuple_with_random_data(test_case): y = torch.flip(x, constant((0, 1, 2))) return y + def test_flow_flip_list_lastdim_with_random_data(test_case): + device = random_device() + x = random_tensor( + ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int) + ).to(device) + y = torch.flip(x, [-1,]) + return y + @profile(torch.flip) def profile_flip(test_case): torch.flip(torch.ones(100, 100, 100), [0, 1]) + torch.flip(torch.ones(1, 100000), [-1,]) if __name__ == "__main__": From 46b8546c752c2feeabf18665a06d30a918c3f1b6 Mon Sep 17 00:00:00 2001 From: WangYi Date: Tue, 15 Aug 2023 11:16:34 +0800 Subject: [PATCH 12/13] add comments --- oneflow/user/kernels/flip_kernel.cu | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index 23ff6f762a4..7b242e97b88 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -52,6 +52,34 @@ __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims, } } + +/* +Example tensor: +[[0, 1, 2, 3, 4, 5, 6, 7], + [8, 9, 10, 11, 12, 13, 14]] + +Given parameters: BlockSize=4, GridSize=4 +For each block_i, `block_begin_idx` is calculated as (i - 1) * BlockSize = (i - 1) * 4, +and `thread_end_idx` is set to 4 for all blocks except the final block. +In the final block, `thread_end_idx` is 2, representing the border index of the active thread. + +`i_ori` is an index referring to the original position of data stored in shm[threadIdx.x] before flipping. +For instance, consider block 1 and thread 2 (element 6). The element is located at row 0, column 7 in the tensor. +Its original index `i_ori` is 7, and after flipping, it is mapped to row 0, column 0. + + ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐ +global mem before: │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ 9 │ A │ B │ C │ D │ x │ x │ + └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘ + + block0 │ block1 │ block2 │ block3 + ┌───┬───┬───┬───┼───┬───┬───┬───┼───┬───┬───┬───┼───┬───┬───┬───┐ +shm after loading: │ 3 │ 2 │ 1 │ 0 │ 7 │ 6 │ 5 │ 4 │ B │ A │ 9 │ 8 │ D │ C │ x │ x │ + └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘ + + ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐ +global mem after: │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │ D │ C │ B │ A │ 9 │ 8 │ 7 │ x │ x │ + └───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘ +*/ template __global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size, const T* in_dptr, T* out_dptr) { From 5638907f7955ad9e117168c9b6cfb42d7686d2be Mon Sep 17 00:00:00 2001 From: oneflow-ci-bot Date: Tue, 15 Aug 2023 03:19:17 +0000 Subject: [PATCH 13/13] auto format by CI --- oneflow/user/kernels/flip_kernel.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/oneflow/user/kernels/flip_kernel.cu b/oneflow/user/kernels/flip_kernel.cu index 7b242e97b88..556323afc26 100644 --- a/oneflow/user/kernels/flip_kernel.cu +++ b/oneflow/user/kernels/flip_kernel.cu @@ -52,7 +52,6 @@ __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims, } } - /* Example tensor: [[0, 1, 2, 3, 4, 5, 6, 7], @@ -63,9 +62,10 @@ For each block_i, `block_begin_idx` is calculated as (i - 1) * BlockSize = (i - and `thread_end_idx` is set to 4 for all blocks except the final block. In the final block, `thread_end_idx` is 2, representing the border index of the active thread. -`i_ori` is an index referring to the original position of data stored in shm[threadIdx.x] before flipping. -For instance, consider block 1 and thread 2 (element 6). The element is located at row 0, column 7 in the tensor. -Its original index `i_ori` is 7, and after flipping, it is mapped to row 0, column 0. +`i_ori` is an index referring to the original position of data stored in shm[threadIdx.x] before +flipping. For instance, consider block 1 and thread 2 (element 6). The element is located at row 0, +column 7 in the tensor. Its original index `i_ori` is 7, and after flipping, it is mapped to row 0, +column 0. ┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐ global mem before: │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ 9 │ A │ B │ C │ D │ x │ x │