feat(llama.cpp): expose cache_type_k and cache_type_v for quant of kv cache (#4329)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
mudler authored Dec 6, 2024
1 parent 88737e1 commit d4c1746
Showing 4 changed files with 15 additions and 2 deletions.
3 changes: 3 additions & 0 deletions backend/backend.proto
@@ -242,6 +242,9 @@ message ModelOptions {
  repeated float LoraScales = 61;

  repeated string Options = 62;
+
+  string CacheTypeKey = 63;
+  string CacheTypeValue = 64;
}

message Result {
6 changes: 6 additions & 0 deletions backend/cpp/llama/grpc-server.cpp
@@ -2241,6 +2241,12 @@ static void params_parse(const backend::ModelOptions* request,
    }
    // params.model_alias ??
    params.model_alias = request->modelfile();
+    if (!request->cachetypekey().empty()) {
+        params.cache_type_k = request->cachetypekey();
+    }
+    if (!request->cachetypevalue().empty()) {
+        params.cache_type_v = request->cachetypevalue();
+    }
    params.n_ctx = request->contextsize();
    //params.memory_f16 = request->f16memory();
    params.cpuparams.n_threads = request->threads();
2 changes: 2 additions & 0 deletions core/backend/options.go
@@ -151,6 +151,8 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
		TensorParallelSize: int32(c.TensorParallelSize),
		MMProj:             c.MMProj,
		FlashAttention:     c.FlashAttention,
+		CacheTypeKey:       c.CacheTypeK,
+		CacheTypeValue:     c.CacheTypeV,
		NoKVOffload:        c.NoKVOffloading,
		YarnExtFactor:      c.YarnExtFactor,
		YarnAttnFactor:     c.YarnAttnFactor,
6 changes: 4 additions & 2 deletions core/config/backend_config.go
@@ -155,8 +155,10 @@ type LLMConfig struct {
	TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
	MMProj string `yaml:"mmproj"`

-	FlashAttention bool `yaml:"flash_attention"`
-	NoKVOffloading bool `yaml:"no_kv_offloading"`
+	FlashAttention bool   `yaml:"flash_attention"`
+	NoKVOffloading bool   `yaml:"no_kv_offloading"`
+	CacheTypeK     string `yaml:"cache_type_k"`
+	CacheTypeV     string `yaml:"cache_type_v"`

	RopeScaling string `yaml:"rope_scaling"`
	ModelType string `yaml:"type"`
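With these fields wired through proto, backend options, and the llama.cpp parameter parser, a model's YAML configuration can request a quantized KV cache. A minimal sketch, assuming the llama-cpp backend and hypothetical model and file names; the q8_0 values are only examples of cache type strings llama.cpp accepts (f16, q8_0, q4_0, ...):

    name: my-model                # hypothetical model name
    backend: llama-cpp
    parameters:
      model: my-model.Q4_K_M.gguf # hypothetical weights file
    flash_attention: true         # llama.cpp generally needs flash attention to quantize the V cache
    cache_type_k: q8_0            # forwarded as CacheTypeKey -> params.cache_type_k
    cache_type_v: q8_0            # forwarded as CacheTypeValue -> params.cache_type_v

The strings are passed through unchanged, so any cache type accepted by llama.cpp's --cache-type-k/--cache-type-v flags should work here.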
