use llama binary instead of extending ruby
Using the binary simplifies the code, but the biggest reason for this change is that wrapping the binary is significantly faster. I'm not sure why this is, but the same C++ code compiled as a Ruby extension takes about 4x as long to run. I'm keeping the code in ext/ because it's sort of like a native extension in that it compiles code on the user's machine. This seems like the least confusing way to manage that. (E.g. if the compilation fails, the user gets an error saying "ERROR: Failed to build gem native extension" rather than "some random script failed.")
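The wrapper itself is small: build a shell-escaped command line, run the binary, and capture its output and exit status. A minimal sketch of that approach, where `echo` stands in for the compiled llama binary and `run_binary` is a hypothetical helper, not part of the gem:

```ruby
require 'open3'
require 'shellwords'

# Minimal sketch of the binary-wrapping approach: build a shell-escaped
# command line, run it, and capture stdout/stderr plus the exit status.
def run_binary(binary, *args)
  command = ([binary] + args).map { |a| Shellwords.escape(a) }.join(' ')
  stdout, stderr, status = Open3.capture3(command)
  raise "Error #{status.to_i}: #{stderr}" unless status.success?

  stdout
end

puts run_binary('echo', 'hello from the wrapped binary')
```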
Showing 11 changed files with 90 additions and 111 deletions.
```diff
@@ -10,7 +10,7 @@
 /tmp/
 
 # ext files
-/ext/llama/llama
+/bin/llama
 
 # rspec failure tracking
 .rspec_status
```
```diff
@@ -1,7 +1,7 @@
 PATH
   remote: .
   specs:
-    llama-rb (0.1.0)
+    llama-rb (0.2.0)
 
 GEM
   remote: https://rubygems.org/
```
```diff
@@ -5,5 +5,4 @@ set -vx
 
 bundle install
 
-cd ext/llama
-ruby extconf.rb
+ruby ext/extconf.rb
```
```diff
@@ -0,0 +1,4 @@
+# dummy file to make gem installer happy
+all:
+clean:
+install:
```
```diff
@@ -0,0 +1,10 @@
+require 'fileutils'
+
+root = File.dirname(__FILE__)
+llama_root = File.join(root, '..', 'llama.cpp')
+
+main = File.join(root, '..', 'bin', 'llama')
+llama_main = File.join(llama_root, 'main')
+
+Dir.chdir(llama_root) { system('make main', exception: true) }
+FileUtils.cp(llama_main, main)
```
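During `gem install`, RubyGems runs `extconf.rb` and then invokes `make`; because this script builds the binary directly instead of generating a Makefile via mkmf, the dummy Makefile keeps the installer happy. The `exception: true` option passed to `system` is what surfaces a failed build as the standard "Failed to build gem native extension" error. A small sketch of that behavior (available since Ruby 2.6):

```ruby
# `system` normally returns false on a nonzero exit status; with
# `exception: true` it raises instead, so a failed `make` aborts the
# gem install rather than silently continuing.
def run_step(*cmd)
  system(*cmd, exception: true)
end

run_step('true') # exit 0: returns true
begin
  run_step('false') # nonzero exit: raises RuntimeError
rescue RuntimeError => e
  puts "build step failed: #{e.message}"
end
```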
This file was deleted.
```diff
@@ -1,84 +1,58 @@
+require 'open3'
+require 'shellwords'
+
 module Llama
   class Model
-    # move methods defined in `model.cpp` from public to private
-    # private :initialize_cpp, :predict_cpp
-    #
-    # # rubocop:disable Metrics/MethodLength
-    # def self.new(
-    #   model, # path to model file, e.g. "models/7B/ggml-model-q4_0.bin"
-    #   n_ctx: 512, # context size
-    #   n_parts: -1, # amount of model parts (-1 = determine from model dimensions)
-    #   seed: Time.now.to_i, # RNG seed
-    #   memory_f16: true, # use f16 instead of f32 for memory kv
-    #   use_mlock: false # use mlock to keep model in memory
-    # )
-    #   instance = allocate
-    #
-    #   instance.instance_eval do
-    #     initialize
-    #
-    #     @model = model
-    #     @n_ctx = n_ctx
-    #     @n_parts = n_parts
-    #     @seed = seed
-    #     @memory_f16 = memory_f16
-    #     @use_mlock = use_mlock
-    #
-    #     capture_stderr do
-    #       initialize_cpp(
-    #         model,
-    #         n_ctx,
-    #         n_parts,
-    #         seed,
-    #         memory_f16,
-    #         use_mlock,
-    #       )
-    #     end
-    #   end
-    #
-    #   instance
-    # end
-    # # rubocop:enable Metrics/MethodLength
-    #
-    # def predict(
-    #   prompt, # string used as prompt
-    #   n_predict: 128 # number of tokens to predict
-    # )
-    #   text = ''
-    #
-    #   capture_stderr { text = predict_cpp(prompt, n_predict) }
-    #
-    #   process_text(text)
-    # end
-    #
-    # attr_reader :model, :n_ctx, :n_parts, :seed, :memory_f16, :use_mlock, :stderr
-    #
-    # private
-    #
-    # def capture_stderr
-    #   previous = $stderr.dup
-    #   tmp = Tempfile.open('llama-rb-stderr')
-    #
-    #   begin
-    #     $stderr.reopen(tmp)
-    #
-    #     yield
-    #
-    #     tmp.rewind
-    #     @stderr = tmp.read
-    #   ensure
-    #     tmp.close(true)
-    #     $stderr.reopen(previous)
-    #   end
-    # end
-    #
-    # def process_text(text)
-    #   text = text.force_encoding(Encoding.default_external)
-    #
-    #   # remove the space that was added as a tokenizer hack in model.cpp
-    #   text[0] = '' if text.size.positive?
-    #
-    #   text
-    # end
+    class ModelError < StandardError
+    end
+
+    def initialize(
+      model,
+      seed: Time.now.to_i,
+      n_predict: 128,
+      binary: default_binary
+    )
+      @model = model
+      @seed = seed
+      @n_predict = n_predict
+      @binary = binary
+    end
+
+    def predict(prompt)
+      stdout, @stderr, @status = Open3.capture3(command(prompt))
+
+      raise ModelError, "Error #{status.to_i}" unless status.success?
+
+      # remove the space that is added as a tokenizer hack in examples/main/main.cpp
+      stdout[0] = ''
+      stdout
+    end
+
+    attr_reader :model, :seed, :n_predict, :binary
+
+    private
+
+    attr_reader :stderr, :status
+
+    def default_binary
+      File.join(File.dirname(__FILE__), '..', '..', 'bin', 'llama')
+    end
+
+    def command(prompt)
+      escape_command(binary,
+                     model: model,
+                     prompt: prompt,
+                     seed: seed,
+                     n_predict: n_predict)
+    end
+
+    def escape_command(command, **flags)
+      flags_string = flags.map do |key, value|
+        "--#{Shellwords.escape(key)} #{Shellwords.escape(value)}"
+      end.join(' ')
+      command_string = Shellwords.escape(command)
+
+      "#{command_string} #{flags_string}"
+    end
   end
 end
```
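The escaping logic in `escape_command` can be exercised on its own, without the model binary. This standalone sketch reproduces it (with an explicit `to_s`, since the real method relies on `Shellwords.escape` coercing symbols) and shows how a prompt containing spaces and a quote is made shell-safe:

```ruby
require 'shellwords'

# Standalone copy of the flag-escaping logic from Model#escape_command:
# each flag name and value is passed through Shellwords.escape before
# being joined into a single command line.
def escape_command(command, **flags)
  flags_string = flags.map do |key, value|
    "--#{Shellwords.escape(key.to_s)} #{Shellwords.escape(value.to_s)}"
  end.join(' ')

  "#{Shellwords.escape(command)} #{flags_string}"
end

puts escape_command('bin/llama', prompt: "it's a test", seed: 2)
# → bin/llama --prompt it\'s\ a\ test --seed 2
```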
```diff
@@ -1,3 +1,3 @@
 module Llama
-  VERSION = '0.1.0'.freeze
+  VERSION = '0.2.0'.freeze
 end
```
```diff
@@ -1,11 +1,11 @@
 RSpec.describe Llama::Model do
-  subject(:model) { described_class.new('models/7B/ggml-model-q4_0.bin', seed: 2) }
+  subject(:model) { described_class.new('models/7B/ggml-model-q4_0.bin', seed: 10, n_predict: 1) }
 
   it 'predicts text' do
     expect(
-      model.predict('The most common words for testing a new programming language are: h', n_predict: 2),
+      model.predict('hello, wo'),
     ).to eq(
-      'The most common words for testing a new programming language are: hmmm',
+      'hello, woof',
    )
   end
 end
```