From 3a615e667565b46ffbc904ff75473a28128da955 Mon Sep 17 00:00:00 2001
From: Daulet Zhanguzin
Date: Wed, 10 Apr 2024 16:25:04 -0700
Subject: [PATCH] expose encodeSpecialTokens functionality

---
 src/lib.rs                                 | 11 ++++--
 test/data/sentence-transformers-labse.json |  8 ++++-
 tokenizer.go                               | 41 ++++++++++++++++------
 tokenizer_test.go                          | 17 +++++++++
 tokenizers.h                               |  6 +++-
 5 files changed, 68 insertions(+), 15 deletions(-)

diff --git a/src/lib.rs b/src/lib.rs
index 9cb93c90..f0be9698 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,11 @@ use std::path::PathBuf;
 use std::ptr;
 use tokenizers::tokenizer::Tokenizer;
 
+#[repr(C)]
+pub struct TokenizerOptions {
+    encode_special_tokens: bool,
+}
+
 #[repr(C)]
 pub struct Buffer {
     ids: *mut u32,
@@ -14,12 +19,14 @@ pub struct Buffer {
 }
 
 #[no_mangle]
-pub extern "C" fn from_bytes(bytes: *const u8, len: u32) -> *mut Tokenizer {
+pub extern "C" fn from_bytes(bytes: *const u8, len: u32, opts: &TokenizerOptions) -> *mut Tokenizer {
     let bytes_slice = unsafe { std::slice::from_raw_parts(bytes, len as usize) };
-    let tokenizer = Tokenizer::from_bytes(bytes_slice).expect("failed to create tokenizer");
+    let mut tokenizer = Tokenizer::from_bytes(bytes_slice).expect("failed to create tokenizer");
+    tokenizer.set_encode_special_tokens(opts.encode_special_tokens);
     Box::into_raw(Box::new(tokenizer))
 }
 
+// TODO: merge with from_bytes and pass the truncation params via TokenizerOptions
 #[no_mangle]
 pub extern "C" fn from_bytes_with_truncation(bytes: *const u8, len: u32, max_len: usize, dir: u8) -> *mut Tokenizer {
     let bytes_slice = unsafe { std::slice::from_raw_parts(bytes, len as usize) };
diff --git a/test/data/sentence-transformers-labse.json b/test/data/sentence-transformers-labse.json
index 4495a894..bef55929 100644
--- a/test/data/sentence-transformers-labse.json
+++ b/test/data/sentence-transformers-labse.json
@@ -146,13 +146,19 @@
     "max_input_chars_per_word": 100,
     "vocab": {
       "[PAD]": 0,
+      "[CLS]": 101,
+      "[SEP]": 102,
       "brown": 51775,
       "fox": 193284,
       "jumps": 333915,
       "over": 15444,
       "the": 14985,
       "lazy": 221123,
-      "dog": 22452
+      "dog": 22452,
+      "[": 164,
+      "CLS": 304910,
+      "]": 166,
+      "SEP": 211703
     }
   }
 }
\ No newline at end of file
diff --git a/tokenizer.go b/tokenizer.go
index dda3b84e..8d7c9c11 100644
--- a/tokenizer.go
+++ b/tokenizer.go
@@ -19,6 +19,18 @@ type Tokenizer struct {
 	tokenizer unsafe.Pointer
 }
 
+type tokenizerOpts struct {
+	encodeSpecialTokens C.bool
+}
+
+type TokenizerOption func(to *tokenizerOpts)
+
+func WithEncodeSpecialTokens() TokenizerOption {
+	return func(to *tokenizerOpts) {
+		to.encodeSpecialTokens = C.bool(true)
+	}
+}
+
 type TruncationDirection int
 
 const (
@@ -28,8 +40,15 @@ const (
 
 var _ io.Closer = (*Tokenizer)(nil)
 
-func FromBytes(data []byte) (*Tokenizer, error) {
-	tokenizer := C.from_bytes((*C.uchar)(unsafe.Pointer(&data[0])), C.uint(len(data)))
+func FromBytes(data []byte, opts ...TokenizerOption) (*Tokenizer, error) {
+	allOpts := &tokenizerOpts{
+		// by default, special tokens are not encoded as regular text
+		encodeSpecialTokens: C.bool(false),
+	}
+	for _, opt := range opts {
+		opt(allOpts)
+	}
+	tokenizer := C.from_bytes((*C.uchar)(unsafe.Pointer(&data[0])), C.uint(len(data)), (*C.struct_TokenizerOptions)(unsafe.Pointer(allOpts)))
 	return &Tokenizer{tokenizer: tokenizer}, nil
 }
 
@@ -62,7 +81,7 @@ type Encoding struct {
 	Tokens []string
 }
 
-type EncodeOptions struct {
+type encodeOpts struct {
 	AddSpecialTokens C.bool
 
 	ReturnTypeIDs C.bool
@@ -71,7 +90,7 @@ type EncodeOptions struct {
 	ReturnAttentionMask C.bool
 }
 
-type EncodeOption func(eo *EncodeOptions)
+type EncodeOption func(eo *encodeOpts)
 
 func uintVecToSlice(arrPtr *C.uint, len int) []uint32 {
 	arr := unsafe.Slice(arrPtr, len)
@@ -85,7 +104,7 @@ func uintVecToSlice(arrPtr *C.uint, len int) []uint32 {
 func (t *Tokenizer) Encode(str string, addSpecialTokens bool) ([]uint32, []string) {
 	cStr := C.CString(str)
 	defer C.free(unsafe.Pointer(cStr))
-	options := EncodeOptions{
+	options := encodeOpts{
 		AddSpecialTokens: C.bool(addSpecialTokens),
 		ReturnTokens:     C.bool(true),
 	}
@@ -109,7 +128,7 @@ func (t *Tokenizer) Encode(str string, addSpecialTokens bool) ([]uint32, []strin
 }
 
 func WithReturnAllAttributes() EncodeOption {
-	return func(eo *EncodeOptions) {
+	return func(eo *encodeOpts) {
 		eo.ReturnTypeIDs = C.bool(true)
 		eo.ReturnSpecialTokensMask = C.bool(true)
 		eo.ReturnAttentionMask = C.bool(true)
@@ -118,25 +137,25 @@ func WithReturnAllAttributes() EncodeOption {
 }
 
 func WithReturnTypeIDs() EncodeOption {
-	return func(eo *EncodeOptions) {
+	return func(eo *encodeOpts) {
 		eo.ReturnTypeIDs = C.bool(true)
 	}
 }
 
 func WithReturnSpecialTokensMask() EncodeOption {
-	return func(eo *EncodeOptions) {
+	return func(eo *encodeOpts) {
 		eo.ReturnSpecialTokensMask = C.bool(true)
 	}
 }
 
 func WithReturnTokens() EncodeOption {
-	return func(eo *EncodeOptions) {
+	return func(eo *encodeOpts) {
 		eo.ReturnTokens = C.bool(true)
 	}
 }
 
 func WithReturnAttentionMask() EncodeOption {
-	return func(eo *EncodeOptions) {
+	return func(eo *encodeOpts) {
 		eo.ReturnAttentionMask = C.bool(true)
 	}
 }
@@ -145,7 +164,7 @@ func (t *Tokenizer) EncodeWithOptions(str string, addSpecialTokens bool, opts ..
 	cStr := C.CString(str)
 	defer C.free(unsafe.Pointer(cStr))
 
-	encOptions := EncodeOptions{
+	encOptions := encodeOpts{
 		AddSpecialTokens: C.bool(addSpecialTokens),
 	}
 	for _, opt := range opts {
diff --git a/tokenizer_test.go b/tokenizer_test.go
index 9cbb351f..6e4b55ff 100644
--- a/tokenizer_test.go
+++ b/tokenizer_test.go
@@ -144,6 +144,23 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) {
 	}
 }
 
+func TestEncodeSpecialTokens(t *testing.T) {
+	tk, err := tokenizers.FromBytes(embeddedBytes)
+	require.NoError(t, err)
+	// by default special tokens are not encoded as regular text: when the
+	// input matches a special token, that token's special ID is emitted
+	ids, _ := tk.Encode("[CLS]fox[SEP]", false)
+	assert.Equal(t, []uint32{101, 193284, 102}, ids)
+	tk.Close()
+
+	tk, err = tokenizers.FromBytes(embeddedBytes, tokenizers.WithEncodeSpecialTokens())
+	require.NoError(t, err)
+	ids, _ = tk.Encode("[CLS]fox[SEP]", false)
+	// special tokens are now split like regular text, so IDs 101 and 102 are absent
+	assert.Equal(t, []uint32{164, 304910, 166, 193284, 164, 211703, 166}, ids)
+	tk.Close()
+}
+
 func TestEncodeOptions(t *testing.T) {
 	tk, err := tokenizers.FromFile("./test/data/bert-base-uncased.json")
 	require.NoError(t, err)
diff --git a/tokenizers.h b/tokenizers.h
index 3fac620a..f89f61d2 100644
--- a/tokenizers.h
+++ b/tokenizers.h
@@ -9,6 +9,10 @@ struct EncodeOptions {
   bool return_attention_mask;
 };
 
+struct TokenizerOptions {
+  bool encode_special_tokens;
+};
+
 struct Buffer {
   uint32_t *ids;
   uint32_t *type_ids;
@@ -18,7 +22,7 @@ struct Buffer {
   uint32_t len;
 };
 
-void *from_bytes(const uint8_t *config, uint32_t len);
+void *from_bytes(const uint8_t *config, uint32_t len, const struct TokenizerOptions *options);
 void *from_bytes_with_truncation(const uint8_t *config, uint32_t len, uint32_t max_len, uint8_t direction);
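
For reviewers, a minimal sketch of how the new option is used from the Go side. This is not part of the diff: the import path github.com/daulet/tokenizers and the tokenizer.json file name are assumptions for illustration; everything else mirrors the API introduced above.

    package main

    import (
    	"fmt"
    	"os"

    	"github.com/daulet/tokenizers" // assumed import path
    )

    func main() {
    	// hypothetical config; any HuggingFace tokenizer.json should work
    	data, err := os.ReadFile("tokenizer.json")
    	if err != nil {
    		panic(err)
    	}

    	// default behavior: special tokens in the input are matched, so
    	// "[CLS]" and "[SEP]" come back as their special IDs (101 and 102
    	// with the LaBSE fixture above)
    	tk, err := tokenizers.FromBytes(data)
    	if err != nil {
    		panic(err)
    	}
    	ids, _ := tk.Encode("[CLS]fox[SEP]", false)
    	fmt.Println(ids)
    	tk.Close()

    	// with the new option: special tokens are encoded as regular text,
    	// so "[CLS]" is split into "[", "CLS" and "]" by the model
    	tk, err = tokenizers.FromBytes(data, tokenizers.WithEncodeSpecialTokens())
    	if err != nil {
    		panic(err)
    	}
    	ids, _ = tk.Encode("[CLS]fox[SEP]", false)
    	fmt.Println(ids)
    	tk.Close()
    }

The flag is applied at construction time, mirroring set_encode_special_tokens on the Rust side, so switching modes means creating a new Tokenizer rather than toggling an existing one.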