
[QST] How to do matmul(A, B^T)? #129

Open
mgrabban opened this issue Aug 29, 2024 · 1 comment
Labels
question Further information is requested

Comments

@mgrabban

How to do $$matmul(A, B^T)$$?

I was trying to modify this sycl/pvc example, which I believe does $$matmul(A, B)$$, to do $$matmul(A, B^T)$$ (i.e., my B input is transposed) but the verification is failing.
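To be precise about what I mean by $$matmul(A, B^T)$$: A is M×K, and B is handed to me already transposed, i.e. stored as N×K, so I want C = A·Bᵀ. A minimal plain-C++ reference for that computation (my own hypothetical check, not CUTLASS code) would be:

```cpp
// Reference for matmul(A, B^T):
// A is M x K (row-major), B is stored as N x K (i.e. already transposed),
// so C(m, n) = sum_k A(m, k) * B(n, k).
void matmul_a_bt_reference(const float* A, const float* B, float* C,
                           int M, int N, int K) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[n * K + k];  // B indexed as (n, k)
      }
      C[m * N + n] = acc;
    }
  }
}
```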

I made these changes:

  • changed the line `using LayoutB = cutlass::layout::RowMajor;` to `using LayoutB = cutlass::layout::ColumnMajor;`, and
  • changed the line `cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N}));` to `cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({N, K}));`

This line could also be relevant, so in addition to the above two changes, I also tried both with and without the following change (all three edits are collected in the sketch after this list):

  • changed the line `stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));` to `stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L));`
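Put together, the edits I tried look roughly like this (a sketch of my changes against the example, using the example's own identifiers; only the changed lines are shown):

```cpp
// B is transposed, so describe it as column-major instead of row-major.
using LayoutB = cutlass::layout::ColumnMajor;   // was: cutlass::layout::RowMajor

// Reference tensor used for verification, extent swapped to match.
cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({N, K}));  // was: LayoutB::packed({K, N})

// Tried both with and without this third change:
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L));  // was: make_shape(N, K, L)
```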

Is it possible to modify that example code to do $$matmul(A, B^T)$$? If yes, could you please help me modify the code correctly?

I also suspect it might not be possible since the MMA atom `MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>` might not support it, but I'm not sure.

Complementary questions:
Does PVC support only BF16 inputs when using Xe cores, or are other input types also supported? If other input types are supported, what is your plan for creating MMA_Atoms for them?
Thanks!

@mgrabban mgrabban added the question Further information is requested label Aug 29, 2024
@mehdi-goli
Collaborator

mehdi-goli commented Aug 29, 2024

> This line could also be relevant, so in addition to the above two changes, I also tried both with and without the following change:
>
>   • changed the line `stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));` to `stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(K, N, L));`

You don't need this change: by converting RowMajor to ColumnMajor, tag_to_stride_B will select the correct make_cute_packed_stride behavior, with N as the stride. So this change is not needed.
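For context, in the CUTLASS 3.x-style examples StrideB is usually derived from the layout tag, which is why the packed-stride call can stay untouched once the tag changes. A rough sketch of that convention (exact type names may differ slightly in the PVC example):

```cpp
// StrideB follows from the layout tag, so switching the tag is enough.
using LayoutB = cutlass::layout::ColumnMajor;            // B is transposed
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;  // column-major stride type

// Unchanged call: the packed stride now encodes the column-major view of B.
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
```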

> Is it possible to modify that example code to do $$matmul(A, B^T)$$? If yes, could you please help me modify the code correctly?
> I also suspect it might not be possible since the MMA atom `MMA_Atom<XE_8x16x16_F32BF16BF16F32_TT>` might not support it, but I'm not sure.

At the moment, the functionality to load a transposed B through xe_copy is not integrated into the pipeline. We are in the process of adding it and will let you know once it is done.

> Complementary questions: Does PVC support only BF16 inputs when using Xe cores, or are other input types also supported? If other input types are supported, what is your plan for creating MMA_Atoms for them? Thanks!

Currently it is BF16 only; adding other types is in progress.

@github-staff github-staff deleted a comment from louseee Oct 7, 2024
@github-staff github-staff deleted a comment from louseee Oct 7, 2024