Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 29, 2024
1 parent 7878c31 commit e9a4363
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions transformer_engine/jax/csrc/extensions/attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,8 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty

auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

size_t wkspace_size = std::accumulate(workspace_buf->dimensions().begin(),
workspace_buf->dimensions().end(),
1, std::multiplies<>());
size_t wkspace_size = std::accumulate(workspace_buf->dimensions().begin(),
workspace_buf->dimensions().end(), 1, std::multiplies<>());
DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type());
DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

Expand Down

0 comments on commit e9a4363

Please sign in to comment.