From 3bcc7bc0f4ce57138efa9f995b521fbc423f13dc Mon Sep 17 00:00:00 2001 From: yoshoku Date: Sat, 22 Jun 2024 15:44:46 +0900 Subject: [PATCH] feat: add set_embeddings method to Context --- ext/llama_cpp/dummy.rb | 6 ++++++ ext/llama_cpp/llama_cpp.cpp | 11 +++++++++++ sig/llama_cpp.rbs | 1 + 3 files changed, 18 insertions(+) diff --git a/ext/llama_cpp/dummy.rb b/ext/llama_cpp/dummy.rb index 55f4f37..96d6f51 100644 --- a/ext/llama_cpp/dummy.rb +++ b/ext/llama_cpp/dummy.rb @@ -779,6 +779,12 @@ def embeddings_ith(i); end # @return [Array] shape: (n_embd) def embeddings_seq(seq_id); end + # Sets whether the model is in embeddings model or not. + # + # @param embd [Boolean] The flag to return embeddings. + # @return [NilClass] + def set_embeddings(embd); end + # Sets the number of threads used for decoding. # # @param n_threads [Integer] The number of threads. diff --git a/ext/llama_cpp/llama_cpp.cpp b/ext/llama_cpp/llama_cpp.cpp index 74d836e..6650c44 100644 --- a/ext/llama_cpp/llama_cpp.cpp +++ b/ext/llama_cpp/llama_cpp.cpp @@ -2133,6 +2133,7 @@ class RbLLaMAContext { rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0); rb_define_method(rb_cLLaMAContext, "embeddings_ith", RUBY_METHOD_FUNC(_llama_context_embeddings_ith), 1); rb_define_method(rb_cLLaMAContext, "embeddings_seq", RUBY_METHOD_FUNC(_llama_context_embeddings_seq), 1); + rb_define_method(rb_cLLaMAContext, "set_embeddings", RUBY_METHOD_FUNC(_llama_context_set_embeddings), 1); rb_define_method(rb_cLLaMAContext, "set_n_threads", RUBY_METHOD_FUNC(_llama_context_set_n_threads), -1); rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0); rb_define_method(rb_cLLaMAContext, "n_batch", RUBY_METHOD_FUNC(_llama_context_n_batch), 0); @@ -2357,6 +2358,16 @@ class RbLLaMAContext { return output; } + static VALUE _llama_context_set_embeddings(VALUE self, VALUE embs) { + LLaMAContextWrapper* ptr = get_llama_context(self); + if (ptr->ctx == NULL) { + rb_raise(rb_eArgError, "LLaMA context is not initialized"); + return Qnil; + } + llama_set_embeddings(ptr->ctx, RTEST(embs) ? true : false); + return Qnil; + } + static VALUE _llama_context_set_n_threads(int argc, VALUE* argv, VALUE self) { VALUE kw_args = Qnil; ID kw_table[2] = { rb_intern("n_threads"), rb_intern("n_threads_batch") }; diff --git a/sig/llama_cpp.rbs b/sig/llama_cpp.rbs index 24cc854..ba0e106 100644 --- a/sig/llama_cpp.rbs +++ b/sig/llama_cpp.rbs @@ -259,6 +259,7 @@ module LLaMACpp def embeddings_seq: (Integer) -> Array[Float] def decode: (::LLaMACpp::Batch) -> void def logits: () -> Array[Float] + def set_embeddings: (bool) -> void def set_n_threads: (n_threads: Integer, n_threads_batch: Integer) -> void def n_ctx: () -> Integer def n_batch: () -> Integer