From 170ceac858063feb9ef1092e27917478489c5896 Mon Sep 17 00:00:00 2001 From: Christopher Grainger Date: Wed, 13 Dec 2023 17:17:56 +0100 Subject: [PATCH] Expose byte fallback for unigram --- lib/tokenizers/model/bpe.ex | 2 ++ lib/tokenizers/model/unigram.ex | 2 ++ native/ex_tokenizers/src/models.rs | 28 ++++++++++++++++++++-------- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/lib/tokenizers/model/bpe.ex b/lib/tokenizers/model/bpe.ex index 761c178..67f6578 100644 --- a/lib/tokenizers/model/bpe.ex +++ b/lib/tokenizers/model/bpe.ex @@ -2,6 +2,8 @@ defmodule Tokenizers.Model.BPE do @typedoc """ Options for model initialisation. + * `:byte_fallback`- whether to use the byte fallback trick + * `:cache_capacity` - the number of words that the BPE cache can contain. The cache allows to speed-up the process by keeping the result of the merge operations for a number of words. diff --git a/lib/tokenizers/model/unigram.ex b/lib/tokenizers/model/unigram.ex index fe91976..3d943a1 100644 --- a/lib/tokenizers/model/unigram.ex +++ b/lib/tokenizers/model/unigram.ex @@ -2,10 +2,12 @@ defmodule Tokenizers.Model.Unigram do @typedoc """ Options for model initialisation. + * `:byte_fallback`- whether to use the byte fallback trick * `:unk_id`- the unknown token id to be used by the model """ @type options() :: [ + byte_fallback: boolean(), unk_id: float() ] diff --git a/native/ex_tokenizers/src/models.rs b/native/ex_tokenizers/src/models.rs index cc8b942..21e2661 100644 --- a/native/ex_tokenizers/src/models.rs +++ b/native/ex_tokenizers/src/models.rs @@ -168,7 +168,8 @@ pub fn models_info(model: ExTokenizersModel) -> Info { }, ModelWrapper::Unigram(model) => new_info! { model_type: "unigram", - min_score: model.min_score + min_score: model.min_score, + byte_fallback: model.byte_fallback() }, } } @@ -362,6 +363,7 @@ pub fn models_wordlevel_from_file( #[derive(NifTaggedEnum)] pub enum UnigramOption { UnkId(usize), + ByteFallback(bool), } #[rustler::nif] @@ -369,16 +371,26 @@ pub fn models_unigram_init( vocab: Vec<(String, f64)>, options: Vec, ) -> Result { - let unk_id = if !options.is_empty() { - match options[0] { - UnigramOption::UnkId(unk_id) => Some(unk_id), - } - } else { - None + let unk_id = match options + .iter() + .find(|opt| matches!(opt, UnigramOption::UnkId(_))) + .unwrap() + { + UnigramOption::UnkId(unk_id) => Some(*unk_id), + _ => None, + }; + + let byte_fallback = match options + .iter() + .find(|opt| matches!(opt, UnigramOption::ByteFallback(_))) + .unwrap() + { + UnigramOption::ByteFallback(byte_fallback) => *byte_fallback, + _ => false, }; Ok(ExTokenizersModel::new( - tokenizers::models::unigram::Unigram::from(vocab, unk_id, false)?, + tokenizers::models::unigram::Unigram::from(vocab, unk_id, byte_fallback)?, )) }