diff --git a/src/lib.rs b/src/lib.rs index f0be9698..9825d3ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ pub struct Buffer { special_tokens_mask: *mut u32, attention_mask: *mut u32, tokens: *mut *mut libc::c_char, + offsets: *mut usize, len: usize, } @@ -68,6 +69,7 @@ pub struct EncodeOptions { return_tokens: bool, return_special_tokens_mask: bool, return_attention_mask: bool, + return_offsets: bool, } #[no_mangle] @@ -79,7 +81,7 @@ pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, o let message_cstr = unsafe { CStr::from_ptr(message) }; let message = message_cstr.to_str(); if message.is_err() { - return Buffer { ids: ptr::null_mut(), tokens: ptr::null_mut(), len: 0, type_ids: ptr::null_mut(), special_tokens_mask: ptr::null_mut(), attention_mask: ptr::null_mut() }; + return Buffer { ids: ptr::null_mut(), tokens: ptr::null_mut(), len: 0, type_ids: ptr::null_mut(), special_tokens_mask: ptr::null_mut(), attention_mask: ptr::null_mut() , offsets: ptr::null_mut()}; } let encoding = tokenizer.encode(message.unwrap(), options.add_special_tokens).expect("failed to encode input"); @@ -124,7 +126,20 @@ pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, o std::mem::forget(vec_attention_mask); } - Buffer { ids, type_ids, special_tokens_mask, attention_mask, tokens, len } + let mut offsets: *mut usize = ptr::null_mut(); + if options.return_offsets { + let vec_offsets_tuples = encoding.get_offsets().to_vec(); + let mut vec_offsets = Vec::with_capacity(vec_offsets_tuples.len() * 2); + for i in vec_offsets_tuples { + vec_offsets.push(i.0); + vec_offsets.push(i.1); + } + vec_offsets.shrink_to_fit(); + offsets = vec_offsets.as_mut_ptr(); + std::mem::forget(vec_offsets); + } + + Buffer { ids, type_ids, special_tokens_mask, attention_mask, tokens, offsets, len } } #[no_mangle] @@ -183,6 +198,11 @@ pub extern "C" fn free_buffer(buf: Buffer) { Vec::from_raw_parts(buf.attention_mask, buf.len, buf.len); } } + if !buf.offsets.is_null() { + unsafe { + Vec::from_raw_parts(buf.offsets, buf.len*2, buf.len*2); + } + } if !buf.tokens.is_null() { unsafe { let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len); diff --git a/tokenizer.go b/tokenizer.go index 87d55d48..80c0a17c 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -73,12 +73,15 @@ func (t *Tokenizer) Close() error { return nil } +type Offset [2]uint + type Encoding struct { IDs []uint32 TypeIDs []uint32 SpecialTokensMask []uint32 AttentionMask []uint32 Tokens []string + Offsets []Offset } type encodeOpts struct { @@ -88,6 +91,7 @@ type encodeOpts struct { ReturnTokens C.bool ReturnSpecialTokensMask C.bool ReturnAttentionMask C.bool + ReturnOffsets C.bool } type EncodeOption func(eo *encodeOpts) @@ -101,6 +105,18 @@ func uintVecToSlice(arrPtr *C.uint, len int) []uint32 { return slice } +func offsetVecToSlice(arrPtr *C.size_t, tokenLength int) []Offset { + arr := unsafe.Slice(arrPtr, tokenLength*2) + slice := make([]Offset, tokenLength) + counter := 0 + for i := 0; i < tokenLength; i++ { + offset := Offset{uint(arr[counter]), uint(arr[counter+1])} + slice[i] = offset + counter = counter + 2 + } + return slice +} + func (t *Tokenizer) Encode(str string, addSpecialTokens bool) ([]uint32, []string) { cStr := C.CString(str) defer C.free(unsafe.Pointer(cStr)) @@ -133,6 +149,7 @@ func WithReturnAllAttributes() EncodeOption { eo.ReturnSpecialTokensMask = C.bool(true) eo.ReturnAttentionMask = C.bool(true) eo.ReturnTokens = C.bool(true) + eo.ReturnOffsets = C.bool(true) } } @@ -160,6 +177,12 @@ func WithReturnAttentionMask() EncodeOption { } } +func WithReturnOffsets() EncodeOption { + return func(eo *encodeOpts) { + eo.ReturnOffsets = C.bool(true) + } +} + func (t *Tokenizer) EncodeWithOptions(str string, addSpecialTokens bool, opts ...EncodeOption) Encoding { cStr := C.CString(str) defer C.free(unsafe.Pointer(cStr)) @@ -201,6 +224,10 @@ func (t *Tokenizer) EncodeWithOptions(str string, addSpecialTokens bool, opts .. encoding.AttentionMask = uintVecToSlice(res.attention_mask, len) } + if encOptions.ReturnOffsets && res.offsets != nil { + encoding.Offsets = offsetVecToSlice(res.offsets, len) + } + return encoding } diff --git a/tokenizer_test.go b/tokenizer_test.go index 6e4b55ff..da2c1fda 100644 --- a/tokenizer_test.go +++ b/tokenizer_test.go @@ -35,6 +35,7 @@ func TestEmbeddingConfig(t *testing.T) { wantTokens []string wantSpecialTokensMask []uint32 wantAttentionMask []uint32 + wantOffsets []tokenizers.Offset }{ { name: "without special tokens", @@ -45,6 +46,7 @@ func TestEmbeddingConfig(t *testing.T) { wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"}, wantSpecialTokensMask: []uint32{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1}, + wantOffsets: []tokenizers.Offset{{0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}}, }, { name: "with special tokens", @@ -55,6 +57,7 @@ func TestEmbeddingConfig(t *testing.T) { wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "[SEP]"}, wantSpecialTokensMask: []uint32{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1}, + wantOffsets: []tokenizers.Offset{{0x0, 0x0}, {0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}, {0x0, 0x0}}, }, } for _, tt := range tests { @@ -65,6 +68,7 @@ func TestEmbeddingConfig(t *testing.T) { assert.Equal(t, tt.wantTokens, encoding.Tokens, "wrong tokens") assert.Equal(t, tt.wantSpecialTokensMask, encoding.SpecialTokensMask, "wrong special tokens mask") assert.Equal(t, tt.wantAttentionMask, encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, tt.wantOffsets, encoding.Offsets, "wrong offsets") ids, tokens := tk.Encode(tt.str, tt.addSpecial) assert.Equal(t, tt.wantIDs, ids, "wrong ids") @@ -86,6 +90,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) { wantTokens []string wantSpecialTokensMask []uint32 wantAttentionMask []uint32 + wantOffsets []tokenizers.Offset }{ { name: "without special tokens", @@ -96,6 +101,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) { wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"}, wantSpecialTokensMask: []uint32{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1}, + wantOffsets: []tokenizers.Offset{{0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}}, }, { name: "with special tokens", @@ -106,6 +112,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) { wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "[SEP]"}, wantSpecialTokensMask: []uint32{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}, wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1}, + wantOffsets: []tokenizers.Offset{{0x0, 0x0}, {0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}, {0x0, 0x0}}, }, { name: "empty string", @@ -121,6 +128,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) { wantAttentionMask: []uint32{0x1, 0x1}, wantIDs: []uint32{101, 102}, wantTokens: []string{"[CLS]", "[SEP]"}, + wantOffsets: []tokenizers.Offset{{0x0, 0x0}, {0x0, 0x0}}, }, { name: "invalid utf8 string", @@ -136,6 +144,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) { assert.Equal(t, tt.wantTokens, encoding.Tokens, "wrong tokens") assert.Equal(t, tt.wantSpecialTokensMask, encoding.SpecialTokensMask, "wrong special tokens mask") assert.Equal(t, tt.wantAttentionMask, encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, tt.wantOffsets, encoding.Offsets, "wrong offsets mask") ids, tokens := tk.Encode(tt.str, tt.addSpecial) assert.Equal(t, tt.wantIDs, ids, "wrong ids") @@ -174,6 +183,7 @@ func TestEncodeOptions(t *testing.T) { wantTokens []string wantSpecialTokensMask []uint32 wantAttentionMask []uint32 + wantOffsets []tokenizers.Offset }{ { name: "without special tokens", @@ -184,6 +194,7 @@ func TestEncodeOptions(t *testing.T) { wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"}, wantSpecialTokensMask: []uint32{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1}, + wantOffsets: []tokenizers.Offset{{0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}}, }, } for _, tt := range tests { @@ -194,6 +205,7 @@ func TestEncodeOptions(t *testing.T) { assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens") assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask") assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets") encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnTokens()) assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids") @@ -201,6 +213,7 @@ func TestEncodeOptions(t *testing.T) { assert.Equal(t, tt.wantTokens, encoding.Tokens, "wrong tokens") assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask") assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets") encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnTypeIDs()) assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids") @@ -208,6 +221,7 @@ func TestEncodeOptions(t *testing.T) { assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens") assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask") assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets") encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnSpecialTokensMask()) assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids") @@ -215,6 +229,7 @@ func TestEncodeOptions(t *testing.T) { assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens") assert.Equal(t, tt.wantSpecialTokensMask, encoding.SpecialTokensMask, "wrong special tokens mask") assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets") encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnAttentionMask()) assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids") @@ -222,6 +237,15 @@ func TestEncodeOptions(t *testing.T) { assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens") assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask") assert.Equal(t, tt.wantAttentionMask, encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets") + + encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnOffsets()) + assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids") + assert.Equal(t, []uint32(nil), encoding.TypeIDs, "wrong type ids") + assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens") + assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask") + assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, tt.wantOffsets, encoding.Offsets, "wrong offsets") }) } } @@ -300,6 +324,7 @@ func TestEncodeWithPadding(t *testing.T) { wantTokens []string wantSpecialTokensMask []uint32 wantAttentionMask []uint32 + wantOffsets []tokenizers.Offset }{ { name: "sentence with padding", @@ -310,6 +335,7 @@ func TestEncodeWithPadding(t *testing.T) { wantTokens: []string{"this", "short", "sentence", "[PAD]", "[PAD]", "[PAD]", "[PAD]", "[PAD]"}, wantSpecialTokensMask: []uint32{0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x1}, wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0}, + wantOffsets: []tokenizers.Offset{{0x0, 0x4}, {0x5, 0xa}, {0xb, 0x13}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}}, }, } for _, tt := range tests { @@ -320,6 +346,7 @@ func TestEncodeWithPadding(t *testing.T) { assert.Equal(t, tt.wantTokens, encoding.Tokens, "wrong tokens") assert.Equal(t, tt.wantSpecialTokensMask, encoding.SpecialTokensMask, "wrong special tokens mask") assert.Equal(t, tt.wantAttentionMask, encoding.AttentionMask, "wrong attention mask") + assert.Equal(t, tt.wantOffsets, encoding.Offsets, "wrong offsets") ids, tokens := tk.Encode(tt.str, tt.addSpecial) assert.Equal(t, tt.wantIDs, ids, "wrong ids") diff --git a/tokenizers.h b/tokenizers.h index f89f61d2..db881da3 100644 --- a/tokenizers.h +++ b/tokenizers.h @@ -7,6 +7,7 @@ struct EncodeOptions { bool return_tokens; bool return_special_tokens_mask; bool return_attention_mask; + bool return_offsets; }; struct TokenizerOptions { @@ -19,6 +20,7 @@ struct Buffer { uint32_t *special_tokens_mask; uint32_t *attention_mask; char *tokens; + size_t *offsets; uint32_t len; };