Skip to content

Commit

Permalink
fix: rename new func to FromPretrained, improve example
Browse files Browse the repository at this point in the history
  • Loading branch information
daulet committed Nov 5, 2024
1 parent 15395ec commit 8ec40b9
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 43 deletions.
94 changes: 54 additions & 40 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package main

import (
"fmt"
"log"

"github.com/daulet/tokenizers"
)

func main() {
func simple() error {
tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json")
if err != nil {
panic(err)
return err
}
// release native resources
defer tk.Close()
Expand All @@ -22,49 +24,61 @@ func main() {
fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
// brown fox jumps over the lazy dog

var encodeOptions []tokenizers.EncodeOption
encodeOptions = append(encodeOptions, tokenizers.WithReturnTypeIDs())
encodeOptions = append(encodeOptions, tokenizers.WithReturnAttentionMask())
encodeOptions = append(encodeOptions, tokenizers.WithReturnTokens())
encodeOptions = append(encodeOptions, tokenizers.WithReturnOffsets())
encodeOptions = append(encodeOptions, tokenizers.WithReturnSpecialTokensMask())
return nil

// Or just basically
// encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes())
}

encodingResponse := tk.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...)
fmt.Println(encodingResponse.IDs)
// [2829 4419 14523 2058 1996 13971 3899]
fmt.Println(encodingResponse.TypeIDs)
// [0 0 0 0 0 0 0]
fmt.Println(encodingResponse.SpecialTokensMask)
// [0 0 0 0 0 0 0]
fmt.Println(encodingResponse.AttentionMask)
// [1 1 1 1 1 1 1]
fmt.Println(encodingResponse.Tokens)
// [brown fox jumps over the lazy dog]
fmt.Println(encodingResponse.Offsets)
// [[0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33]]
func advanced() error {
// Load tokenizer from local config file
tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json")
if err != nil {
return err
}
defer tk.Close()

// Load pretrained tokenizer from HuggingFace
tokenizerPath := "../huggingface-tokenizers/google-bert/bert-base-uncased"
tkFromHf, errHf := tokenizers.LoadTokenizerFromHuggingFace("google-bert/bert-base-uncased", &tokenizerPath, nil)
if errHf != nil {
panic(errHf)
tkFromHf, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", &tokenizerPath, nil)
if err != nil {
return err
}
// release native resources
defer tkFromHf.Close()

encodingResponseHf := tkFromHf.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...)
fmt.Println(encodingResponseHf.IDs)
// [101 2829 4419 14523 2058 1996 13971 3899 102]
fmt.Println(encodingResponseHf.TypeIDs)
// [0 0 0 0 0 0 0 0 0]
fmt.Println(encodingResponseHf.SpecialTokensMask)
// [1 0 0 0 0 0 0 0 1]
fmt.Println(encodingResponseHf.AttentionMask)
// [1 1 1 1 1 1 1 1 1]
fmt.Println(encodingResponseHf.Tokens)
// [[CLS] brown fox jumps over the lazy dog [SEP]]
fmt.Println(encodingResponseHf.Offsets)
// [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]]
// Encode with specific options
encodeOptions := []tokenizers.EncodeOption{
tokenizers.WithReturnTypeIDs(),
tokenizers.WithReturnAttentionMask(),
tokenizers.WithReturnTokens(),
tokenizers.WithReturnOffsets(),
tokenizers.WithReturnSpecialTokensMask(),
}
// Or simply:
// encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes())

// regardless of how the tokenizer was initialized, the output is the same
for _, tkzr := range []*tokenizers.Tokenizer{tk, tkFromHf} {
encodingResponse := tkzr.EncodeWithOptions("brown fox jumps over the lazy dog", true, encodeOptions...)
fmt.Println(encodingResponse.IDs)
// [101 2829 4419 14523 2058 1996 13971 3899 102]
fmt.Println(encodingResponse.TypeIDs)
// [0 0 0 0 0 0 0 0 0]
fmt.Println(encodingResponse.SpecialTokensMask)
// [1 0 0 0 0 0 0 0 1]
fmt.Println(encodingResponse.AttentionMask)
// [1 1 1 1 1 1 1 1 1]
fmt.Println(encodingResponse.Tokens)
// [[CLS] brown fox jumps over the lazy dog [SEP]]
fmt.Println(encodingResponse.Offsets)
// [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]]
}
return nil
}

// main runs both example flows in order, aborting on the first failure.
func main() {
	// Run each example sequentially; log.Fatal exits on the first error.
	for _, example := range []func() error{simple, advanced} {
		if err := example(); err != nil {
			log.Fatal(err)
		}
	}
}
5 changes: 2 additions & 3 deletions tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ func FromFile(path string) (*Tokenizer, error) {
return &Tokenizer{tokenizer: tokenizer}, nil
}

// LoadTokenizerFromHuggingFace downloads necessary files and initializes the tokenizer.
// FromPretrained downloads necessary files and initializes the tokenizer.
// Parameters:
// - modelID: The Hugging Face model identifier (e.g., "bert-base-uncased").
// - destination: Optional. If provided and not nil, files will be downloaded to this folder.
// If nil, a temporary directory will be used.
// - authToken: Optional. If provided and not nil, it will be used to authenticate requests.
func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string) (*Tokenizer, error) {
func FromPretrained(modelID string, destination, authToken *string) (*Tokenizer, error) {
if strings.TrimSpace(modelID) == "" {
return nil, fmt.Errorf("modelID cannot be empty")
}
Expand Down Expand Up @@ -170,7 +170,6 @@ func LoadTokenizerFromHuggingFace(modelID string, destination, authToken *string
func downloadFile(url, destination string, authToken *string) error {
// Check if the file already exists
if _, err := os.Stat(destination); err == nil {
fmt.Printf("File %s already exists. Skipping download.\n", destination)
return nil
}

Expand Down

0 comments on commit 8ec40b9

Please sign in to comment.