Skip to content

Commit

Permalink
Expose byte fallback for unigram
Browse files Browse the repository at this point in the history
  • Loading branch information
cigrainger committed Dec 13, 2023
1 parent 83ad8e8 commit 170ceac
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 8 deletions.
2 changes: 2 additions & 0 deletions lib/tokenizers/model/bpe.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ defmodule Tokenizers.Model.BPE do
@typedoc """
Options for model initialisation.
* `:byte_fallback` - whether to use the byte fallback trick
* `:cache_capacity` - the number of words that the BPE cache can
contain. The cache speeds up the process by keeping
the results of the merge operations for a number of words.
Expand Down
2 changes: 2 additions & 0 deletions lib/tokenizers/model/unigram.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ defmodule Tokenizers.Model.Unigram do
@typedoc """
Options for model initialisation.
* `:byte_fallback` - whether to use the byte fallback trick
* `:unk_id` - the unknown token id to be used by the model
  (a non-negative integer; the native side decodes it as an unsigned integer)
"""
@type options() :: [
        byte_fallback: boolean(),
        unk_id: non_neg_integer()
      ]

Expand Down
28 changes: 20 additions & 8 deletions native/ex_tokenizers/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ pub fn models_info(model: ExTokenizersModel) -> Info {
},
ModelWrapper::Unigram(model) => new_info! {
model_type: "unigram",
min_score: model.min_score
min_score: model.min_score,
byte_fallback: model.byte_fallback()
},
}
}
Expand Down Expand Up @@ -362,23 +363,34 @@ pub fn models_wordlevel_from_file(
/// Options accepted by `models_unigram_init`.
///
/// Each variant carries the value of one option from the `options` list
/// passed to the NIF.
#[derive(NifTaggedEnum)]
pub enum UnigramOption {
    /// Id of the unknown token, forwarded to the Unigram constructor.
    UnkId(usize),
    /// Whether the byte fallback trick is enabled for the model.
    ByteFallback(bool),
}

#[rustler::nif]
pub fn models_unigram_init(
vocab: Vec<(String, f64)>,
options: Vec<UnigramOption>,
) -> Result<ExTokenizersModel, ExTokenizersError> {
let unk_id = if !options.is_empty() {
match options[0] {
UnigramOption::UnkId(unk_id) => Some(unk_id),
}
} else {
None
let unk_id = match options
.iter()
.find(|opt| matches!(opt, UnigramOption::UnkId(_)))
.unwrap()
{
UnigramOption::UnkId(unk_id) => Some(*unk_id),
_ => None,
};

let byte_fallback = match options
.iter()
.find(|opt| matches!(opt, UnigramOption::ByteFallback(_)))
.unwrap()
{
UnigramOption::ByteFallback(byte_fallback) => *byte_fallback,
_ => false,
};

Ok(ExTokenizersModel::new(
tokenizers::models::unigram::Unigram::from(vocab, unk_id, false)?,
tokenizers::models::unigram::Unigram::from(vocab, unk_id, byte_fallback)?,
))
}

Expand Down

0 comments on commit 170ceac

Please sign in to comment.