flip memory coalescing for last dim case #10310

Merged · 18 commits · Aug 15, 2023
58 changes: 54 additions & 4 deletions oneflow/user/kernels/flip_kernel.cu
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/container_util.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/common/nd_index_offset_helper.h"
@@ -51,6 +52,50 @@ __global__ void FlipGpuForward(const int32_t element, const int64_t total_dims,
  }
}

/*
Example tensor (2 rows x 7 columns):
  [[0, 1, 2, 3, 4, 5, 6],
   [7, 8, 9, 10, 11, 12, 13]]

Given parameters: BlockSize=4, GridSize=4.
For each block i (0-based), `block_begin_idx` is i * BlockSize = i * 4, and
`thread_end_idx` is the number of active threads in the block: 4 for every block
except the last one, where only min(12 + 4, 14) - 12 = 2 threads are active.

`i_ori` is the original index of the data stored in shm[threadIdx.x] before
flipping; each block loads its window of the input in reverse order. For instance,
in block 1, thread 1 loads i_ori = 4 + (4 - 1 - 1) = 6, i.e. element 6, located at
row 0, column 6 of the tensor. After flipping, it is mapped to row 0, column 0.

┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐
global mem before: │ 0 │ 1 │ 2 │ 3 │ 4 │ 5 │ 6 │ 7 │ 8 │ 9 │ A │ B │ C │ D │ x │ x │
└───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘

block0 │ block1 │ block2 │ block3
┌───┬───┬───┬───┼───┬───┬───┬───┼───┬───┬───┬───┼───┬───┬───┬───┐
shm after loading: │ 3 │ 2 │ 1 │ 0 │ 7 │ 6 │ 5 │ 4 │ B │ A │ 9 │ 8 │ D │ C │ x │ x │
└───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘

┌───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┬───┐
global mem after: │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │ D │ C │ B │ A │ 9 │ 8 │ 7 │ x │ x │
└───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┴───┘
*/
template<typename T>
__global__ void FlipLastDimGpuForward(const int32_t element, const int64_t last_dim_size,
                                      const T* in_dptr, T* out_dptr) {
  __shared__ T shm[ep::CudaStream::kDefaultBlockSize];
  CUDA_1D_KERNEL_LOOP(i, element) {
    int32_t block_begin_idx = blockDim.x * blockIdx.x;
    // Threads [0, thread_end_idx) carry elements; only the tail block is partially filled.
    int32_t thread_end_idx = min(block_begin_idx + blockDim.x, element) - block_begin_idx;
    // Stage the block's window in reverse order: the reads of consecutive threads
    // still fall into one contiguous segment, so they stay coalesced.
    int32_t i_ori = block_begin_idx + (thread_end_idx - threadIdx.x - 1);
    shm[threadIdx.x] = in_dptr[i_ori];
    __syncthreads();
    // Scatter to the flipped position: within a row, consecutive threads write
    // consecutive addresses, so the stores stay coalesced as well.
    int32_t row = i_ori / last_dim_size;
    int32_t col = last_dim_size - (i_ori - row * last_dim_size) - 1;
    out_dptr[row * last_dim_size + col] = shm[threadIdx.x];
  }
}

} // namespace

template<typename T>
@@ -72,14 +117,19 @@ class FlipGpuKernel final : public user_op::OpKernel {
    VIS vis;
    for (auto x : dims) { vis.val[x] = true; }

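    // Fast path: flipping only the innermost (contiguous) dim dispatches to the
    // coalescing shared-memory kernel above.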
    if (dims.size() == 1 && dims[0] == x_tensor->shape_view().NumAxes() - 1) {
      RUN_CUDA_KERNEL((FlipLastDimGpuForward<T>), ctx->stream(), elem_cnt, elem_cnt,
                      x_tensor->shape_view().At(total_dims - 1), x_tensor->dptr<T>(),
                      y_tensor->mut_dptr<T>());
      return;
    }

    SIZE_V sizes_v;
    for (int32_t i = 0; i < total_dims; i++) { sizes_v.val[i] = y_tensor->shape_view().At(i); }

-   // TODO(bbuf) delete strides caluculate, after tensor strides supported
    SIZE_V strides_v;
-   strides_v.val[total_dims - 1] = 1;
-   for (int32_t i = total_dims - 2; i >= 0; i--) {
-     strides_v.val[i] = strides_v.val[i + 1] * y_tensor->shape_view().At(i + 1);
+   for (int32_t i = 0; i < total_dims; i++) {
+     strides_v.val[i] = CHECK_JUST(VectorAt(y_tensor->stride(), i));
    }
    RUN_CUDA_KERNEL((FlipGpuForward<T>), ctx->stream(), elem_cnt, elem_cnt, total_dims, sizes_v,
                    vis, strides_v, x_tensor->dptr<T>(), y_tensor->mut_dptr<T>());
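Aside (not part of the diff): the index arithmetic above can be checked outside of OneFlow with a minimal standalone CUDA sketch. It reproduces the 2 x 7 diagram from the kernel comment, under a few stated assumptions: a hypothetical kBlockSize of 4 stands in for ep::CudaStream::kDefaultBlockSize, a single bounds-guarded pass replaces CUDA_1D_KERNEL_LOOP (valid here because the grid covers all 14 elements), and FlipLastDimDemo plus the host scaffolding are illustrative names, not OneFlow API.

// flip_lastdim_demo.cu -- illustrative sketch only; mirrors the shape of the PR's kernel.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int kBlockSize = 4;  // stand-in for ep::CudaStream::kDefaultBlockSize

__global__ void FlipLastDimDemo(int element, int last_dim_size, const int* in, int* out) {
  __shared__ int shm[kBlockSize];
  int block_begin_idx = blockDim.x * blockIdx.x;
  // Number of active threads in this block (smaller in the tail block).
  int thread_end_idx = min(block_begin_idx + blockDim.x, element) - block_begin_idx;
  bool active = threadIdx.x < thread_end_idx;
  int i_ori = block_begin_idx + (thread_end_idx - threadIdx.x - 1);
  // Stage the block's window in reverse order; the reads still touch one
  // contiguous segment per block.
  if (active) { shm[threadIdx.x] = in[i_ori]; }
  __syncthreads();
  if (active) {
    int row = i_ori / last_dim_size;
    int col = last_dim_size - (i_ori - row * last_dim_size) - 1;
    out[row * last_dim_size + col] = shm[threadIdx.x];
  }
}

int main() {
  const int kRows = 2, kCols = 7, kElem = kRows * kCols;  // the tensor from the diagram
  int host_in[kElem], host_out[kElem];
  for (int i = 0; i < kElem; ++i) { host_in[i] = i; }

  int *dev_in = nullptr, *dev_out = nullptr;
  cudaMalloc(&dev_in, kElem * sizeof(int));
  cudaMalloc(&dev_out, kElem * sizeof(int));
  cudaMemcpy(dev_in, host_in, kElem * sizeof(int), cudaMemcpyHostToDevice);

  int grid = (kElem + kBlockSize - 1) / kBlockSize;  // 4 blocks, as in the diagram
  FlipLastDimDemo<<<grid, kBlockSize>>>(kElem, kCols, dev_in, dev_out);
  cudaMemcpy(host_out, dev_out, kElem * sizeof(int), cudaMemcpyDeviceToHost);

  for (int r = 0; r < kRows; ++r) {
    for (int c = 0; c < kCols; ++c) { printf("%3d", host_out[r * kCols + c]); }
    printf("\n");
  }
  cudaFree(dev_in);
  cudaFree(dev_out);
  return 0;
}

Compiled with `nvcc flip_lastdim_demo.cu` and run, this should print `6 5 4 3 2 1 0` over `13 12 11 10 9 8 7`, matching the "global mem after" row in the diagram. The demo splits the load and store under an `active` guard instead of returning early, so `__syncthreads()` is reached by every thread in the block.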
9 changes: 9 additions & 0 deletions python/oneflow/test/modules/test_flip.py
@@ -55,9 +55,18 @@ def test_flow_flip_bool_tuple_with_random_data(test_case):
        y = torch.flip(x, constant((0, 1, 2)))
        return y

    def test_flow_flip_list_lastdim_with_random_data(test_case):
        device = random_device()
        x = random_tensor(
            ndim=4, dim1=random().to(int), dim2=random().to(int), dim3=random().to(int)
        ).to(device)
        y = torch.flip(x, [-1,])
        return y

    @profile(torch.flip)
    def profile_flip(test_case):
        torch.flip(torch.ones(100, 100, 100), [0, 1])
        torch.flip(torch.ones(1, 100000), [-1,])


if __name__ == "__main__":