diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 18196f3374..84189170da 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -368,9 +368,8 @@ Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Ty auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; - size_t wkspace_size = std::accumulate(workspace_buf->dimensions().begin(), - workspace_buf->dimensions().end(), - 1, std::multiplies<>()); + size_t wkspace_size = std::accumulate(workspace_buf->dimensions().begin(), + workspace_buf->dimensions().end(), 1, std::multiplies<>()); DType dtype = convert_ffi_datatype_to_te_dtype(q_buf.element_type()); DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());