From 83ec42954f28e8cb0043388d7b3acfaf1ec56eff Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Wed, 13 Mar 2024 13:55:37 +0800 Subject: [PATCH] Allow changing k parameter of `SampleTopK` as a compiler flag --- configs.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/configs.h b/configs.h index bf25596..7b420b5 100644 --- a/configs.h +++ b/configs.h @@ -23,11 +23,17 @@ #define GEMMA_MAX_SEQLEN 4096 #endif // !GEMMA_MAX_SEQLEN +// Allow changing k parameter of `SampleTopK` as a compiler flag +#ifndef GEMMA_TOPK +#define GEMMA_TOPK 1 +#endif // !GEMMA_TOPK + #include namespace gcpp { static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; +static constexpr size_t kTopK = GEMMA_TOPK; struct ConfigGemma7B { static constexpr int kSeqLen = gcpp::kSeqLen; @@ -38,7 +44,7 @@ struct ConfigGemma7B { static constexpr int kHeads = 16; static constexpr int kKVHeads = 16; // standard MHA static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = 1; + static constexpr int kTopK = gcpp::kTopK; }; struct ConfigGemma2B { @@ -50,7 +56,7 @@ struct ConfigGemma2B { static constexpr int kHeads = 8; static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = 1; + static constexpr int kTopK = gcpp::kTopK; }; } // namespace gcpp