Skip to content

Commit

Permalink
feat: FromPretrained to load tokenizer directly from HF (#27)
Browse files Browse the repository at this point in the history
* add LoadTokenizerFromHuggingFace function to load tokenizer directly from huggingface, update README.md

* using channels as unbuffered channel, update channel names and minimize some approaches

* fix: rename new func to FromPretrained, improve example

* fix: clean up downloadFile

* fix: concurrency issues in case of an error

* fix: make optional params optional

* fix: cache path has to be model specific

* add unit tests for `FromPretrained`

* migrate to table driven tests, unify/simplify test cases

* fix: clean up nits

---------

Co-authored-by: Resul Berkay Ersoy <resul.ersoy@trendyol.com>
Co-authored-by: Daulet Zhanguzin <daulet@zhanguzin.kz>
  • Loading branch information
3 people authored Nov 7, 2024
1 parent 0d469f8 commit 9c972d9
Show file tree
Hide file tree
Showing 4 changed files with 369 additions and 4 deletions.
44 changes: 43 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Go bindings for the [HuggingFace Tokenizers](https://github.com/huggingface/toke

## Installation

`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers.a'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers.a"` to avoid specifying it every time.
`make build` to build `libtokenizers.a` that you need to run your application that uses bindings. In addition, you need to inform the linker where to find that static library: `go run -ldflags="-extldflags '-L./path/to/libtokenizers/directory'" .` or just add it to the `CGO_LDFLAGS` environment variable: `CGO_LDFLAGS="-L./path/to/libtokenizers/directory"` to avoid specifying it every time.

### Using pre-built binaries

Expand All @@ -31,6 +31,20 @@ if err != nil {
defer tk.Close()
```

Load a tokenizer directly from the Hugging Face Hub:

```go
import "github.com/daulet/tokenizers"

tk, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers"))
if err != nil {
return err
}
// release native resources
defer tk.Close()
```

Encode text and decode tokens:

```go
Expand All @@ -44,6 +58,34 @@ fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true
// brown fox jumps over the lazy dog
```

Encode text with options:

```go
var encodeOptions []tokenizers.EncodeOption
encodeOptions = append(encodeOptions, tokenizers.WithReturnTypeIDs())
encodeOptions = append(encodeOptions, tokenizers.WithReturnAttentionMask())
encodeOptions = append(encodeOptions, tokenizers.WithReturnTokens())
encodeOptions = append(encodeOptions, tokenizers.WithReturnOffsets())
encodeOptions = append(encodeOptions, tokenizers.WithReturnSpecialTokensMask())

// Or, equivalently, request everything at once:
// encodeOptions = append(encodeOptions, tokenizers.WithReturnAllAttributes())

encodingResponse := tk.EncodeWithOptions("brown fox jumps over the lazy dog", false, encodeOptions...)
fmt.Println(encodingResponse.IDs)
// [2829 4419 14523 2058 1996 13971 3899]
fmt.Println(encodingResponse.TypeIDs)
// [0 0 0 0 0 0 0]
fmt.Println(encodingResponse.SpecialTokensMask)
// [0 0 0 0 0 0 0]
fmt.Println(encodingResponse.AttentionMask)
// [1 1 1 1 1 1 1]
fmt.Println(encodingResponse.Tokens)
// [brown fox jumps over the lazy dog]
fmt.Println(encodingResponse.Offsets)
// [[0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33]]
```

## Benchmarks

`go test . -run=^\$ -bench=. -benchmem -count=10 > test/benchmark/$(git rev-parse HEAD).txt`
Expand Down
61 changes: 59 additions & 2 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@ package main

import (
"fmt"
"log"

"github.com/daulet/tokenizers"
)

func main() {
func simple() error {
tk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json")
if err != nil {
panic(err)
return err
}
// release native resources
defer tk.Close()

fmt.Println("Vocab size:", tk.VocabSize())
// Vocab size: 30522
fmt.Println(tk.Encode("brown fox jumps over the lazy dog", false))
Expand All @@ -21,4 +23,59 @@ func main() {
// [101 2829 4419 14523 2058 1996 13971 3899 102] [[CLS] brown fox jumps over the lazy dog [SEP]]
fmt.Println(tk.Decode([]uint32{2829, 4419, 14523, 2058, 1996, 13971, 3899}, true))
// brown fox jumps over the lazy dog
return nil
}

// advanced demonstrates loading tokenizers from both a local config file and
// the Hugging Face Hub, then encoding the same input with explicit options to
// show that both initialization paths produce identical results.
func advanced() error {
	// Tokenizer backed by a local config file.
	localTk, err := tokenizers.FromFile("../test/data/bert-base-uncased.json")
	if err != nil {
		return err
	}
	defer localTk.Close()

	// Pretrained tokenizer downloaded from the Hugging Face Hub; files are
	// cached under ./.cache/tokenizers so repeat runs skip the download.
	hfTk, err := tokenizers.FromPretrained("google-bert/bert-base-uncased", tokenizers.WithCacheDir("./.cache/tokenizers"))
	if err != nil {
		return err
	}
	defer hfTk.Close()

	// Request each attribute explicitly.
	// tokenizers.WithReturnAllAttributes() is a one-call shorthand for the
	// same set of options.
	opts := []tokenizers.EncodeOption{
		tokenizers.WithReturnTypeIDs(),
		tokenizers.WithReturnAttentionMask(),
		tokenizers.WithReturnTokens(),
		tokenizers.WithReturnOffsets(),
		tokenizers.WithReturnSpecialTokensMask(),
	}

	// Regardless of how the tokenizer was initialized, the output is the same.
	for _, tok := range []*tokenizers.Tokenizer{localTk, hfTk} {
		res := tok.EncodeWithOptions("brown fox jumps over the lazy dog", true, opts...)
		fmt.Println(res.IDs)               // [101 2829 4419 14523 2058 1996 13971 3899 102]
		fmt.Println(res.TypeIDs)           // [0 0 0 0 0 0 0 0 0]
		fmt.Println(res.SpecialTokensMask) // [1 0 0 0 0 0 0 0 1]
		fmt.Println(res.AttentionMask)     // [1 1 1 1 1 1 1 1 1]
		fmt.Println(res.Tokens)            // [[CLS] brown fox jumps over the lazy dog [SEP]]
		fmt.Println(res.Offsets)           // [[0 0] [0 5] [6 9] [10 15] [16 20] [21 24] [25 29] [30 33] [0 0]]
	}
	return nil
}

// main runs each example in order, aborting on the first failure.
func main() {
	for _, example := range []func() error{simple, advanced} {
		if err := example(); err != nil {
			log.Fatal(err)
		}
	}
}
162 changes: 161 additions & 1 deletion tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,29 @@ import "C"
import (
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"unsafe"
)

const WANT_VERSION = "1.20.2"
const (
	// WANT_VERSION is the native tokenizers library version these bindings
	// expect; checked against C.version() in init.
	WANT_VERSION = "1.20.2"

	// baseURL is the Hugging Face host that FromPretrained downloads from.
	baseURL = "https://huggingface.co"
)

// List of necessary tokenizer files and their mandatory status.
// True means mandatory, false means optional: only a failure to download a
// mandatory file aborts FromPretrained.
var tokenizerFiles = map[string]bool{
	"tokenizer.json":          true,
	"vocab.txt":               false,
	"merges.txt":              false,
	"special_tokens_map.json": false,
	"added_tokens.json":       false,
}

func init() {
version := C.version()
Expand Down Expand Up @@ -78,6 +97,147 @@ func FromFile(path string) (*Tokenizer, error) {
return &Tokenizer{tokenizer: tokenizer}, nil
}

// tokenizerConfig holds the optional settings accepted by FromPretrained.
// Fields are pointers so "unset" is distinguishable from an explicit value.
type tokenizerConfig struct {
	cacheDir  *string
	authToken *string
}

// TokenizerConfigOption is a functional option that mutates a
// tokenizerConfig; pass instances to FromPretrained.
type TokenizerConfigOption func(cfg *tokenizerConfig)

// WithCacheDir sets the directory where downloaded tokenizer files are
// stored and reused across calls to FromPretrained.
func WithCacheDir(path string) TokenizerConfigOption {
	dir := path
	return func(cfg *tokenizerConfig) {
		cfg.cacheDir = &dir
	}
}

// WithAuthToken sets the bearer token used to authenticate download requests
// to Hugging Face, e.g. for private model repositories.
func WithAuthToken(token string) TokenizerConfigOption {
	tok := token
	return func(cfg *tokenizerConfig) {
		cfg.authToken = &tok
	}
}

// FromPretrained downloads the tokenizer files for the given model from
// Hugging Face and initializes a Tokenizer from them.
//
// Parameters:
//   - modelID: the Hugging Face model identifier (e.g. "bert-base-uncased").
//   - opts: optional settings. WithCacheDir selects a persistent,
//     model-specific download directory (a temporary directory is used
//     otherwise); WithAuthToken supplies a bearer token for private repos.
func FromPretrained(modelID string, opts ...TokenizerConfigOption) (*Tokenizer, error) {
	if strings.TrimSpace(modelID) == "" {
		return nil, fmt.Errorf("modelID cannot be empty")
	}
	cfg := &tokenizerConfig{}
	for _, opt := range opts {
		opt(cfg)
	}

	// Files are fetched from <baseURL>/<modelID>/resolve/main/<filename>.
	modelURL := fmt.Sprintf("%s/%s/resolve/main", baseURL, modelID)

	// Determine the download directory: a model-specific subdirectory of the
	// configured cache, or a fresh temporary directory.
	var downloadDir string
	if cfg.cacheDir != nil {
		downloadDir = filepath.Join(*cfg.cacheDir, modelID)
		// Create the destination directory if it doesn't exist.
		if err := os.MkdirAll(downloadDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create destination directory %s: %w", downloadDir, err)
		}
	} else {
		tmpDir, err := os.MkdirTemp("", "huggingface-tokenizer-*")
		if err != nil {
			return nil, fmt.Errorf("error creating temporary directory: %w", err)
		}
		downloadDir = tmpDir
	}

	// Download each tokenizer file concurrently. The channel is buffered to
	// the worker count so senders never block, which lets us wait and close
	// inline instead of spawning an extra closer goroutine.
	var wg sync.WaitGroup
	errCh := make(chan error, len(tokenizerFiles))
	for filename, isMandatory := range tokenizerFiles {
		wg.Add(1)
		go func(fn string, mandatory bool) {
			defer wg.Done()
			fileURL := fmt.Sprintf("%s/%s", modelURL, fn)
			destPath := filepath.Join(downloadDir, fn)
			if err := downloadFile(fileURL, destPath, cfg.authToken); err != nil && mandatory {
				// Only a missing mandatory file is an error; optional files
				// are allowed to be absent from the repository.
				errCh <- fmt.Errorf("failed to download mandatory file %s: %w", fn, err)
			}
		}(filename, isMandatory)
	}
	wg.Wait()
	close(errCh)

	var errs []error
	for err := range errCh {
		errs = append(errs, err)
	}
	if len(errs) > 0 {
		// Best-effort cleanup: the directory holds an incomplete download.
		if err := os.RemoveAll(downloadDir); err != nil {
			fmt.Printf("Warning: failed to clean up directory %s: %v\n", downloadDir, err)
		}
		return nil, errs[0]
	}

	return FromFile(filepath.Join(downloadDir, "tokenizer.json"))
}

// downloadFile downloads a file from the given URL and saves it to the specified destination.
// If authToken is provided (non-nil), it will be used for authorization.
// Returns an error if the download fails.
func downloadFile(url, destination string, authToken *string) error {
// Check if the file already exists
if _, err := os.Stat(destination); err == nil {
return nil
}

// Create a new HTTP request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return fmt.Errorf("failed to create request for %s: %w", url, err)
}

// If authToken is provided, set the Authorization header
if authToken != nil {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *authToken))
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to download from %s: %w", url, err)
}
defer resp.Body.Close()

// Check for successful response
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to download from %s: status code %d", url, resp.StatusCode)
}

// Create the destination file
out, err := os.Create(destination)
if err != nil {
return fmt.Errorf("failed to create file %s: %w", destination, err)
}
defer out.Close()

// Write the response body to the file
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("failed to write to file %s: %w", destination, err)
}

fmt.Printf("Successfully downloaded %s\n", destination)
return nil
}

func (t *Tokenizer) Close() error {
C.free_tokenizer(t.tokenizer)
t.tokenizer = nil
Expand Down
Loading

0 comments on commit 9c972d9

Please sign in to comment.