Skip to content

Commit

Permalink
Add tokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Mar 1, 2024
1 parent 8412007 commit a4fcfc8
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 1 deletion.
1 change: 1 addition & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ defmodule Bumblebee do
"llama" => :llama,
"mistral" => :llama,
"mbart" => :mbart,
"phi" => :code_gen,
"roberta" => :roberta,
"t5" => :t5,
"whisper" => :whisper,
Expand Down
10 changes: 10 additions & 0 deletions lib/bumblebee/text/pre_trained_tokenizer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
clip: %{
special_tokens: %{unk: "<|endoftext|>", pad: "<|endoftext|>", eos: "<|endoftext|>"}
},
code_gen: %{
special_tokens: %{
unk: "<|endoftext|>",
bos: "<|endoftext|>",
eos: "<|endoftext|>",
# CodeGen doesn't originally have a pad token, however when necessary
# we pad with the EOS token
pad: "<|endoftext|>"
}
},
distilbert: %{
special_tokens: %{unk: "[UNK]", sep: "[SEP]", pad: "[PAD]", cls: "[CLS]", mask: "[MASK]"}
},
Expand Down
2 changes: 1 addition & 1 deletion test/bumblebee/text/generation_test.exs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
defmodule Bumblebee.Text.TextGenerationTest do
defmodule Bumblebee.Text.GenerationTest do
use ExUnit.Case, async: false

import Bumblebee.TestHelpers
Expand Down
18 changes: 18 additions & 0 deletions test/bumblebee/text/pre_trained_tokenizer_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,24 @@ defmodule Bumblebee.Text.PreTrainedTokenizerTest do
)
end

test ":code_gen" do
assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/phi-2"})

assert %Bumblebee.Text.PreTrainedTokenizer{type: :code_gen} = tokenizer

inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello everyobdy, how are you?"])

assert_equal(
inputs["input_ids"],
Nx.tensor([[15496, 790, 672, 9892, 11, 703, 389, 345, 30]])
)

assert_equal(
inputs["attention_mask"],
Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])
)
end

test ":distilbert" do
assert {:ok, tokenizer} =
Bumblebee.load_tokenizer({:hf, "distilbert/distilbert-base-uncased"})
Expand Down

0 comments on commit a4fcfc8

Please sign in to comment.