Skip to content

Commit

Permalink
feat: add set_embeddings method to Context
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Jun 22, 2024
1 parent 138ba89 commit 3bcc7bc
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
6 changes: 6 additions & 0 deletions ext/llama_cpp/dummy.rb
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,12 @@ def embeddings_ith(i); end
# @return [Array<Float>] shape: (n_embd)
def embeddings_seq(seq_id); end

# Sets whether the model is in embeddings model or not.
#
# @param embd [Boolean] The flag to return embeddings.
# @return [NilClass]
def set_embeddings(embd); end

# Sets the number of threads used for decoding.
#
# @param n_threads [Integer] The number of threads.
Expand Down
11 changes: 11 additions & 0 deletions ext/llama_cpp/llama_cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2133,6 +2133,7 @@ class RbLLaMAContext {
rb_define_method(rb_cLLaMAContext, "embeddings", RUBY_METHOD_FUNC(_llama_context_embeddings), 0);
rb_define_method(rb_cLLaMAContext, "embeddings_ith", RUBY_METHOD_FUNC(_llama_context_embeddings_ith), 1);
rb_define_method(rb_cLLaMAContext, "embeddings_seq", RUBY_METHOD_FUNC(_llama_context_embeddings_seq), 1);
rb_define_method(rb_cLLaMAContext, "set_embeddings", RUBY_METHOD_FUNC(_llama_context_set_embeddings), 1);
rb_define_method(rb_cLLaMAContext, "set_n_threads", RUBY_METHOD_FUNC(_llama_context_set_n_threads), -1);
rb_define_method(rb_cLLaMAContext, "n_ctx", RUBY_METHOD_FUNC(_llama_context_n_ctx), 0);
rb_define_method(rb_cLLaMAContext, "n_batch", RUBY_METHOD_FUNC(_llama_context_n_batch), 0);
Expand Down Expand Up @@ -2357,6 +2358,16 @@ class RbLLaMAContext {
return output;
}

static VALUE _llama_context_set_embeddings(VALUE self, VALUE embs) {
LLaMAContextWrapper* ptr = get_llama_context(self);
if (ptr->ctx == NULL) {
rb_raise(rb_eArgError, "LLaMA context is not initialized");
return Qnil;
}
llama_set_embeddings(ptr->ctx, RTEST(embs) ? true : false);
return Qnil;
}

static VALUE _llama_context_set_n_threads(int argc, VALUE* argv, VALUE self) {
VALUE kw_args = Qnil;
ID kw_table[2] = { rb_intern("n_threads"), rb_intern("n_threads_batch") };
Expand Down
1 change: 1 addition & 0 deletions sig/llama_cpp.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ module LLaMACpp
def embeddings_seq: (Integer) -> Array[Float]
def decode: (::LLaMACpp::Batch) -> void
def logits: () -> Array[Float]
def set_embeddings: (bool) -> void
def set_n_threads: (n_threads: Integer, n_threads_batch: Integer) -> void
def n_ctx: () -> Integer
def n_batch: () -> Integer
Expand Down

0 comments on commit 3bcc7bc

Please sign in to comment.