diff --git a/ext/llama_cpp/dummy.rb b/ext/llama_cpp/dummy.rb index 55f4f37..96d6f51 100644 --- a/ext/llama_cpp/dummy.rb +++ b/ext/llama_cpp/dummy.rb @@ -779,6 +779,12 @@ def embeddings_ith(i); end # @return [Array] shape: (n_embd) def embeddings_seq(seq_id); end + # Sets whether the model is in embeddings mode or not. + # + # @param embd [Boolean] The flag to return embeddings. + # @return [NilClass] + def set_embeddings(embd); end + # Sets the number of threads used for decoding. # # @param n_threads [Integer] The number of threads. diff --git a/ext/llama_cpp/llama_cpp.cpp b/ext/llama_cpp/llama_cpp.cpp index 74d836e..6650c44 100644 --- a/ext/llama_cpp/llama_cpp.cpp +++ b/ext/llama_cpp/llama_cpp.cpp @@ -2133,6 +2133,7 @@ class RbLLaMAContext { rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0); rb_define_method(rb_cLLaMAContext, "embeddings_ith", RUBY_METHOD_FUNC(_llama_context_embeddings_ith), 1); rb_define_method(rb_cLLaMAContext, "embeddings_seq", RUBY_METHOD_FUNC(_llama_context_embeddings_seq), 1); + rb_define_method(rb_cLLaMAContext, "set_embeddings", RUBY_METHOD_FUNC(_llama_context_set_embeddings), 1); rb_define_method(rb_cLLaMAContext, "set_n_threads", RUBY_METHOD_FUNC(_llama_context_set_n_threads), -1); rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0); rb_define_method(rb_cLLaMAContext, "n_batch", RUBY_METHOD_FUNC(_llama_context_n_batch), 0); @@ -2357,6 +2358,16 @@ class RbLLaMAContext { return output; } + static VALUE _llama_context_set_embeddings(VALUE self, VALUE embs) { + LLaMAContextWrapper* ptr = get_llama_context(self); + if (ptr->ctx == NULL) { + rb_raise(rb_eArgError, "LLaMA context is not initialized"); + return Qnil; + } + llama_set_embeddings(ptr->ctx, RTEST(embs) ? 
true : false); + return Qnil; + } + static VALUE _llama_context_set_n_threads(int argc, VALUE* argv, VALUE self) { VALUE kw_args = Qnil; ID kw_table[2] = { rb_intern("n_threads"), rb_intern("n_threads_batch") }; diff --git a/sig/llama_cpp.rbs b/sig/llama_cpp.rbs index 24cc854..ba0e106 100644 --- a/sig/llama_cpp.rbs +++ b/sig/llama_cpp.rbs @@ -259,6 +259,7 @@ module LLaMACpp def embeddings_seq: (Integer) -> Array[Float] def decode: (::LLaMACpp::Batch) -> void def logits: () -> Array[Float] + def set_embeddings: (bool) -> void def set_n_threads: (n_threads: Integer, n_threads_batch: Integer) -> void def n_ctx: () -> Integer def n_batch: () -> Integer