Skip to content

Commit

Permalink
1.07x speedup: merge MQA parallel sections as suggested by @veluca93
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 621772392
  • Loading branch information
jan-wassenberg authored and copybara-github committed Apr 4, 2024
1 parent ede337f commit 44e6274
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions gemma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -405,18 +405,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
});
} else {
// Multi-Query Attention
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
ProjQ(head, head * kQKVDim * kModelDim);
});

constexpr const size_t q_offset = kHeads * kQKVDim * kModelDim;
constexpr const size_t k_offset = q_offset + 0 * kQKVDim * kModelDim;
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;

ProjKV(k_offset, v_offset, kv_offset);

pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
ProjQ(head, head * kQKVDim * kModelDim);
Attn(head, 0);
});
}
Expand Down

0 comments on commit 44e6274

Please sign in to comment.