Merge pull request #19 from coreweave/rwang.fixhardcoding.07152024
Refactor GPT_BPE tokenizer file loading and initial processing
wbrown authored Jul 19, 2024
2 parents b0e1976 + d260b08 commit fccd9c0
Showing 3 changed files with 276 additions and 163 deletions.
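The thrust of the refactor: the special-token settings (add_bos_token, add_eos_token, pad_token) are no longer parsed ad hoc inside NewEncoder from tokenizer_config.json; instead they ride along on the hfConfig that resources.ResolveVocabId already returns, giving NewEncoder a single source of truth. A minimal sketch of the resulting config shape — the pointer fields named here are confirmed by the diff below, but the actual struct lives in the resources package and carries more fields:

package resources // illustrative placement; the actual definition may differ

// Sketch only: approximate shape of the HuggingFace-style config consumed
// by NewEncoder after this change. Pointer fields let nil mean "unset".
type HFConfig struct {
	AddBosToken *bool   // prepend BOS when encoding; nil if unspecified
	AddEosToken *bool   // append EOS when encoding; nil if unspecified
	BosTokenStr *string
	EosTokenStr *string
	PadTokenStr *string // pad token text; nil when the vocab defines none
}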
57 changes: 18 additions & 39 deletions gpt_bpe.go
@@ -147,8 +147,8 @@ func NewMistralEncoder() GPTEncoder {
 // Returns a GPTEncoder with the tokenizer data loaded for that vocabulary
 // id.
 func NewEncoder(vocabId string) (*GPTEncoder, error) {
-	hfConfig, resourcesPtr, vocabErr := resources.ResolveVocabId(vocabId,
-		"")
+	hfConfig, resourcesPtr, vocabErr := resources.ResolveVocabId(vocabId, "")
 
 	if vocabErr != nil {
 		return nil, vocabErr
 	}
@@ -176,32 +176,6 @@ func NewEncoder(vocabId string) (*GPTEncoder, error) {
 		}
 	}
 
-	tokenizerSpecialConfig := resources.TokenizerSpecialsConfig{
-		AddBosToken: false,
-		AddEosToken: false,
-		PadToken:    "",
-	}
-	altMistralSpecialsConfig := resources.MistralSpecialsConfig{
-		AddBosToken: false,
-		AddEosToken: false,
-		PadToken:    "",
-	}
-	if special, ok := (rsrcs)["tokenizer_config.json"]; ok {
-		if special.Data != nil {
-			err := json.Unmarshal(*special.Data, &tokenizerSpecialConfig)
-			if err != nil {
-				err = json.Unmarshal(*special.Data, &altMistralSpecialsConfig)
-				if err != nil {
-					log.Fatal("Error unmarshalling tokenizer_config.json")
-				}
-				//populate the tokenizerSpecialConfig from the altMistralSpecialsConfig
-				tokenizerSpecialConfig.AddBosToken = altMistralSpecialsConfig.AddBosToken
-				tokenizerSpecialConfig.AddEosToken = altMistralSpecialsConfig.AddEosToken
-				tokenizerSpecialConfig.PadToken = altMistralSpecialsConfig.PadToken
-			}
-		}
-	}
-
 	puncRunes := make([]rune, 0)
 	if specialConfig.PuncRunes != nil {
 		for _, r := range specialConfig.PuncRunes {
@@ -364,23 +338,28 @@ func NewEncoder(vocabId string) (*GPTEncoder, error) {
 	}
 
 	if specialConfig.EncloseEosBos {
-		tokenizerSpecialConfig.AddBosToken = true
-		tokenizerSpecialConfig.AddEosToken = true
+		bosBool := true
+		eosBool := true
+		hfConfig.AddBosToken = &bosBool
+		hfConfig.AddEosToken = &eosBool
 	}
 
 	// Add in default pad token if not already set
-	padTokenNotFound := (tokenizerSpecialConfig.PadToken == "" && hfConfig.PadTokenStr == nil)
+	padTokenNotFound := (hfConfig.PadTokenStr == nil)
 	if padTokenNotFound {
 		// Inject the pad token into the encoder to uintmax16,
 		// throw an error if vocab is larger than uintmax16
-		if len(encoderTokens) >= math.MaxInt16 {
-			log.Fatalf("Vocab size is larger than uint16 max, default pad token cannot be added." +
-				"Please specify a pad token in the vocab file.")
+		if len(encoderTokens) >= math.MaxUint16 {
+			log.Fatalf("Vocab size of %d is larger than uint16 max of %d. "+
+				"Please specify a pad token in the vocab file.",
+				len(encoderTokens), math.MaxUint16)
 		}
-		encoderTokens[defaultPadTokenString] = math.MaxUint16
-		tokenizerSpecialConfig.PadToken = defaultPadTokenString
-		hfConfig.PadTokenStr = &tokenizerSpecialConfig.PadToken
+		padToken := defaultPadTokenString
+		encoderTokens[padToken] = math.MaxUint16
+		hfConfig.PadTokenStr = &padToken
 	}
 
+	// Create the encoder
 	encoder := &GPTEncoder{
 		encoderTokens,
 		tokensEncoder,
@@ -403,8 +382,8 @@ func NewEncoder(vocabId string) (*GPTEncoder, error) {
 		encoderTokens[*hfConfig.EosTokenStr],
 		encoderTokens[*hfConfig.PadTokenStr],
 		specialConfig.EncloseEosBos,
-		tokenizerSpecialConfig.AddBosToken,
-		tokenizerSpecialConfig.AddEosToken,
+		*hfConfig.AddBosToken,
+		*hfConfig.AddEosToken,
 		specialConfig.PrefixSpace,
 		specialConfig.LowerCase,
 		specialConfig.EndOfWord,
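The core behavioral fix in the hunk above: the overflow guard now compares against math.MaxUint16 rather than the incorrect math.MaxInt16, and the injected pad token's string is owned by a local variable instead of a field on the deleted config struct. A self-contained sketch of the same fallback, assuming token ids are uint16 and a placeholder default pad string (the real constant is defined elsewhere in gpt_bpe):

package main

import (
	"fmt"
	"log"
	"math"
)

// Assumed placeholder; gpt_bpe defines its own default pad token string.
const defaultPadTokenString = "<|padding|>"

// ensurePadToken mirrors NewEncoder's fallback: if the resolved config has
// no pad token, claim id math.MaxUint16 for a default one, refusing when
// the vocabulary is already too large for that id to be free.
func ensurePadToken(encoderTokens map[string]uint16, padTokenStr *string) *string {
	if padTokenStr != nil {
		return padTokenStr // vocab already supplies a pad token
	}
	if len(encoderTokens) >= math.MaxUint16 {
		log.Fatalf("Vocab size of %d is larger than uint16 max of %d. "+
			"Please specify a pad token in the vocab file.",
			len(encoderTokens), math.MaxUint16)
	}
	padToken := defaultPadTokenString
	encoderTokens[padToken] = math.MaxUint16
	return &padToken
}

func main() {
	vocab := map[string]uint16{"hello": 0, "world": 1}
	pad := ensurePadToken(vocab, nil)
	fmt.Println(*pad, vocab[*pad]) // <|padding|> 65535
}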
23 changes: 23 additions & 0 deletions gpt_bpe_test.go
@@ -884,6 +884,29 @@ func TestReadTokenizerConfig(t *testing.T) {
 	fmt.Println("All Exists - Looks good.")
 }
 
+func TestPythiaRemoteDownloadTokenizer(t *testing.T) {
+	// Tests the ability to download a tokenizer from a remote model
+	// and use it to encode and decode strings
+	modelId := "EleutherAI/pythia-70m"
+	destPath := "./TestPythiaRemoteDownloadTokenizer"
+	defer os.RemoveAll(destPath)
+	encoderPythia, err := NewEncoder(modelId)
+	if err != nil {
+		t.Errorf("Error creating encoder: %v", err)
+	}
+
+	// Attempt to tokenize
+	testString := "The fox jumped over the hare.\nThe turtle is faster than the hare."
+
+	// Encode the string
+	encoded := encoderPythia.Encode(&testString)
+	// Check that the encoded string is the same as the expected - Reference from python's transformers lib
+	expected := Tokens{510, 30013, 16780, 689, 253, 419, 250, 15, 187, 510, 45993, 310, 7938, 685, 253, 419, 250, 15}
+	if !assert.Equal(t, expected, *encoded) {
+		t.Errorf("Expected: %v\nActual: %v", expected, *encoded)
+	}
+}
+
 func TestGPTDecoder_Decode(t *testing.T) {
 	// TBD
 }
[Diff of the third changed file not rendered on this page.]
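Taken together, the new test doubles as usage documentation for the remote-download path. A hedged sketch of the same flow as a standalone program — Encode taking a *string and returning *Tokens matches the test above, while the module path and the Decode signature are assumptions to verify against the package:

package main

import (
	"fmt"
	"log"

	"github.com/wbrown/gpt_bpe" // module path assumed; use the fork you build against
)

func main() {
	// Resolving a HuggingFace model id downloads the tokenizer files on
	// first use, as exercised by TestPythiaRemoteDownloadTokenizer.
	encoder, err := gpt_bpe.NewEncoder("EleutherAI/pythia-70m")
	if err != nil {
		log.Fatalf("creating encoder: %v", err)
	}

	text := "The fox jumped over the hare."
	tokens := encoder.Encode(&text)
	fmt.Println(*tokens)

	// Decode's exact signature is an assumption; check the package API.
	fmt.Println(encoder.Decode(tokens))
}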
