From a4fcfc8f0aec371a9e8a6a878b426f4a89c1827f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonatan=20K=C5=82osko?=
Date: Sat, 2 Mar 2024 00:53:07 +0700
Subject: [PATCH] Add tokenizer

---
 lib/bumblebee.ex                              |  1 +
 lib/bumblebee/text/pre_trained_tokenizer.ex   | 10 ++++++++++
 test/bumblebee/text/generation_test.exs       |  2 +-
 .../text/pre_trained_tokenizer_test.exs       | 18 ++++++++++++++++++
 4 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex
index a78eedea..c7fd810d 100644
--- a/lib/bumblebee.ex
+++ b/lib/bumblebee.ex
@@ -238,6 +238,7 @@ defmodule Bumblebee do
     "llama" => :llama,
     "mistral" => :llama,
     "mbart" => :mbart,
+    "phi" => :code_gen,
     "roberta" => :roberta,
     "t5" => :t5,
     "whisper" => :whisper,
diff --git a/lib/bumblebee/text/pre_trained_tokenizer.ex b/lib/bumblebee/text/pre_trained_tokenizer.ex
index 353baa3a..ea2c20a9 100644
--- a/lib/bumblebee/text/pre_trained_tokenizer.ex
+++ b/lib/bumblebee/text/pre_trained_tokenizer.ex
@@ -127,6 +127,16 @@ defmodule Bumblebee.Text.PreTrainedTokenizer do
     clip: %{
       special_tokens: %{unk: "<|endoftext|>", pad: "<|endoftext|>", eos: "<|endoftext|>"}
     },
+    code_gen: %{
+      special_tokens: %{
+        unk: "<|endoftext|>",
+        bos: "<|endoftext|>",
+        eos: "<|endoftext|>",
+        # CodeGen doesn't originally have a pad token, however when necessary
+        # we pad with the EOS token
+        pad: "<|endoftext|>"
+      }
+    },
     distilbert: %{
       special_tokens: %{unk: "[UNK]", sep: "[SEP]", pad: "[PAD]", cls: "[CLS]", mask: "[MASK]"}
     },
diff --git a/test/bumblebee/text/generation_test.exs b/test/bumblebee/text/generation_test.exs
index 22ed5631..cf82ce51 100644
--- a/test/bumblebee/text/generation_test.exs
+++ b/test/bumblebee/text/generation_test.exs
@@ -1,4 +1,4 @@
-defmodule Bumblebee.Text.TextGenerationTest do
+defmodule Bumblebee.Text.GenerationTest do
   use ExUnit.Case, async: false
 
   import Bumblebee.TestHelpers
diff --git a/test/bumblebee/text/pre_trained_tokenizer_test.exs b/test/bumblebee/text/pre_trained_tokenizer_test.exs
index 91100e21..8f4f91f0 100644
--- a/test/bumblebee/text/pre_trained_tokenizer_test.exs
+++ b/test/bumblebee/text/pre_trained_tokenizer_test.exs
@@ -164,6 +164,24 @@ defmodule Bumblebee.Text.PreTrainedTokenizerTest do
       )
     end
 
+    test ":code_gen" do
+      assert {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "microsoft/phi-2"})
+
+      assert %Bumblebee.Text.PreTrainedTokenizer{type: :code_gen} = tokenizer
+
+      inputs = Bumblebee.apply_tokenizer(tokenizer, ["Hello everyobdy, how are you?"])
+
+      assert_equal(
+        inputs["input_ids"],
+        Nx.tensor([[15496, 790, 672, 9892, 11, 703, 389, 345, 30]])
+      )
+
+      assert_equal(
+        inputs["attention_mask"],
+        Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])
+      )
+    end
+
     test ":distilbert" do
       assert {:ok, tokenizer} =
                Bumblebee.load_tokenizer({:hf, "distilbert/distilbert-base-uncased"})