From 8ec40b9a6e0f431aec6541d8d43c3707d3733721 Mon Sep 17 00:00:00 2001 From: Daulet Zhanguzin Date: Mon, 4 Nov 2024 17:01:38 -0800 Subject: [PATCH] fix: rename new func to FromPretrained, improve example --- example/main.go | 94 ++++++++++++++++++++++++++++--------------------- tokenizer.go | 5 ++- 2 files changed, 56 insertions(+), 43 deletions(-) diff --git a/example/main.go b/example/main.go index 3c737322..ed462068 100644 --- a/example/main.go +++ b/example/main.go @@ -2,13 +2,15 @@ package main import ( "fmt" + "log" + "github.com/daulet/tokenizers" ) -func main() { +func simple() error { tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json") if err != nil { - panic(err) + return err } // release native resources defer tk.Close() @@ -22,49 +24,61 @@ func main() { fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true)) // brown fox jumps over the lazy dog - var encodeOptions []tokenizers.EncodeOption - encodeOptions = append(encodeOptions, tokenizers.WithReturnTypeIDs()) - encodeOptions = append(encodeOptions, tokenizers.WithReturnAttentionMask()) - encodeOptions = append(encodeOptions, tokenizers.WithReturnTokens()) - encodeOptions = append(encodeOptions, tokenizers.WithReturnOffsets()) - encodeOptions = append(encodeOptions, tokenizers.WithReturnSpecialTokensMask()) + return nil - // Or just basically - // encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes()) +} - encodingResponse := tk.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...) - fmt.Println(encodingResponse.IDs) - // [2829 4419 14523 2058 1996 13971 3899] - fmt.Println(encodingResponse.TypeIDs) - // [0 0 0 0 0 0 0] - fmt.Println(encodingResponse.SpecialTokensMask) - // [0 0 0 0 0 0 0] - fmt.Println(encodingResponse.AttentionMask) - // [1 1 1 1 1 1 1] - fmt.Println(encodingResponse.Tokens) - // [brown fox jumps over the lazy dog] - fmt.Println(encodingResponse.Offsets) - // [[0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33]] +func advanced() error { + // Load tokenizer from local config file + tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json") + if err != nil { + return err + } + defer tk.Close() + // Load pretrained tokenizer from HuggingFace tokenizerPath := "../huggingface-tokenizers/google-bert/bert-base-uncased" - tkFromHf, errHf := tokenizers.LoadTokenizerFromHuggingFace("google-bert/bert-base-uncased", &tokenizerPath, nil) - if errHf != nil { - panic(errHf) + tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", &tokenizerPath, nil) + if err != nil { + return err } - // release native resources defer tkFromHf.Close() - encodingResponseHf := tkFromHf.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...) - fmt.Println(encodingResponseHf.IDs) - // [101 2829 4419 14523 2058 1996 13971 3899 102] - fmt.Println(encodingResponseHf.TypeIDs) - // [0 0 0 0 0 0 0 0 0] - fmt.Println(encodingResponseHf.SpecialTokensMask) - // [1 0 0 0 0 0 0 0 1] - fmt.Println(encodingResponseHf.AttentionMask) - // [1 1 1 1 1 1 1 1 1] - fmt.Println(encodingResponseHf.Tokens) - // [[CLS] brown fox jumps over the lazy dog [SEP]] - fmt.Println(encodingResponseHf.Offsets) - // [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]] + // Encode with specific options + encodeOptions := []tokenizers.EncodeOption{ + tokenizers.WithReturnTypeIDs(), + tokenizers.WithReturnAttentionMask(), + tokenizers.WithReturnTokens(), + tokenizers.WithReturnOffsets(), + tokenizers.WithReturnSpecialTokensMask(), + } + // Or simply: + // encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes()) + + // regardless of how the tokenizer was initialized, the output is the same + for _, tkzr := range []*tokenizers.Tokenizer{tk, tkFromHf} { + encodingResponse := tkzr.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...) + fmt.Println(encodingResponse.IDs) + // [101 2829 4419 14523 2058 1996 13971 3899 102] + fmt.Println(encodingResponse.TypeIDs) + // [0 0 0 0 0 0 0 0 0] + fmt.Println(encodingResponse.SpecialTokensMask) + // [1 0 0 0 0 0 0 0 1] + fmt.Println(encodingResponse.AttentionMask) + // [1 1 1 1 1 1 1 1 1] + fmt.Println(encodingResponse.Tokens) + // [[CLS] brown fox jumps over the lazy dog [SEP]] + fmt.Println(encodingResponse.Offsets) + // [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]] + } + return nil +} + +func main() { + if err := simple(); err != nil { + log.Fatal(err) + } + if err := advanced(); err != nil { + log.Fatal(err) + } } diff --git a/tokenizer.go b/tokenizer.go index 415c2f72..5b36fbe7 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -88,13 +88,13 @@ func FromFile(path string) (*Tokenizer, error) { return &Tokenizer{tokenizer: tokenizer}, nil } -// LoadTokenizerFromHuggingFace downloads necessary files and initializes the tokenizer. +// FromPretrained downloads necessary files and initializes the tokenizer. // Parameters: // - modelID: The Hugging Face model identifier (e.g., "bert-base-uncased"). // - destination: Optional. If provided and not nil, files will be downloaded to this folder. // If nil, a temporary directory will be used. // - authToken: Optional. If provided and not nil, it will be used to authenticate requests. -func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string) (*Tokenizer, error) { +func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, error) { if strings.TrimSpace(modelID) == "" { return nil, fmt.Errorf("modelID cannot be empty") } @@ -170,7 +170,6 @@ func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string func downloadFile(url, destination string, authToken *string) error { // Check if the file already exists if _, err := os.Stat(destination); err == nil { - fmt.Printf("File %s already exists. Skipping download.\n", destination) return nil }