diff --git a/configs.h b/configs.h index bf25596..7b420b5 100644 --- a/configs.h +++ b/configs.h @@ -23,11 +23,17 @@ #define GEMMA_MAX_SEQLEN 4096 #endif // !GEMMA_MAX_SEQLEN +// Allow changing k parameter of `SampleTopK` as a compiler flag +#ifndef GEMMA_TOPK +#define GEMMA_TOPK 1 +#endif // !GEMMA_TOPK + #include namespace gcpp { static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN; +static constexpr size_t kTopK = GEMMA_TOPK; struct ConfigGemma7B { static constexpr int kSeqLen = gcpp::kSeqLen; @@ -38,7 +44,7 @@ struct ConfigGemma7B { static constexpr int kHeads = 16; static constexpr int kKVHeads = 16; // standard MHA static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = 1; + static constexpr int kTopK = gcpp::kTopK; }; struct ConfigGemma2B { @@ -50,7 +56,7 @@ struct ConfigGemma2B { static constexpr int kHeads = 8; static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support static constexpr int kQKVDim = 256; // query size == key size == value size - static constexpr int kTopK = 1; + static constexpr int kTopK = gcpp::kTopK; }; } // namespace gcpp