diff --git a/gpt_bpe_test.go b/gpt_bpe_test.go index 68dbfc1..77ce641 100644 --- a/gpt_bpe_test.go +++ b/gpt_bpe_test.go @@ -24,7 +24,8 @@ var clipEncoder GPTEncoder var gpt2Encoder GPTEncoder var pileEncoder GPTEncoder var nerdstashV2Encoder GPTEncoder -var Llama2Encoder GPTEncoder +var llama2Encoder GPTEncoder +var mistralEncoder GPTEncoder var corpus string var clipCorpus string @@ -33,7 +34,8 @@ var gpt2Encoded *Tokens var pileEncoded *Tokens var clipEncoded *Tokens var nerdstashEncoded *Tokens -var Llama2Encoded *Tokens +var llama2Encoded *Tokens +var mistralEncoded *Tokens var unicodeTrimTests []*Tokens const largeCorpusPath = "resources/wiki.train.raw" @@ -93,7 +95,8 @@ func init() { pileEncoder = NewPileEncoder() clipEncoder = NewCLIPEncoder() nerdstashV2Encoder = NewNerdstashV2Encoder() - Llama2Encoder = NewLlama2Encoder() + llama2Encoder = NewLlama2Encoder() + mistralEncoder = NewMistralEncoder() textBytes := handleRead("resources/frankenstein.txt") clipBytes := handleRead("resources/frankenstein_clip.txt") corpus = string(textBytes) @@ -787,25 +790,102 @@ func TestLlamaEncoder_Encode(t *testing.T) { func TestLlamaTwoEncoder_Encode(t *testing.T) { testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." - llamaTokens := Llama2Encoder.Encode(&testString) + llamaTokens := llama2Encoder.Encode(&testString) assert.Equal(t, llamaTokens, &Tokens{1576, 1701, 29916, 12500, 287, 975, 278, 447, 276, 29889, 13, 1576, 260, 4227, 280, 338, 8473, 1135, 278, 447, 276, 29889}) } func TestLlamaTwoTokenizerDecode(t *testing.T) { outputString := "The fox jumped over the hare.\nThe turtle is faster than the hare." 
llamaTokens := Tokens{1, 1576, 1701, 29916, 12500, 287, 975, 278, 447, 276, 29889, 13, 1576, 260, 4227, 280, 338, 8473, 1135, 278, 447, 276, 29889} - output := Llama2Encoder.Decode(&llamaTokens) + output := llama2Encoder.Decode(&llamaTokens) assert.Equal(t, outputString, output) } func TestLlamaTwoEncodeDecode(t *testing.T) { testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." outputString := "The fox jumped over the hare.\nThe turtle is faster than the hare." - llamaTokens := Llama2Encoder.Encode(&testString) - output := Llama2Encoder.Decode(llamaTokens) + llamaTokens := llama2Encoder.Encode(&testString) + output := llama2Encoder.Decode(llamaTokens) assert.Equal(t, outputString, output) } +func TestMistralEncoder_Encode(t *testing.T) { + testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." + mistralTokens := mistralEncoder.Encode(&testString) + fmt.Printf("mistralTokens: %v\n", mistralTokens) + assert.Equal(t, mistralTokens, &Tokens{1, 415, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723}) +} + +func TestMistralTokenizerDecode(t *testing.T) { + outputString := " The fox jumped over the hare.\nThe turtle is faster than the hare." + mistralTokens := Tokens{1, 415, 285, 1142, 14949, 754, 272, 295, 492, 28723, 13, 1014, 261, 3525, 291, 349, 9556, 821, 272, 295, 492, 28723} + output := mistralEncoder.Decode(&mistralTokens) + assert.Equal(t, outputString, output) +} + +func TestMistralEncodeDecode(t *testing.T) { + testString := "The fox jumped over the hare.\nThe turtle is faster than the hare." + outputString := " The fox jumped over the hare.\nThe turtle is faster than the hare." 
+ mistralTokens := mistralEncoder.Encode(&testString) + output := mistralEncoder.Decode(mistralTokens) + assert.Equal(t, outputString, output) +} + +func TestMistralEncodeDecodeFrankenstein(t *testing.T) { + frankensteinCorpus := "resources/frankenstein.txt" + frankensteinText, err := os.ReadFile(frankensteinCorpus) + if err != nil { + t.Errorf("Error reading Frankenstein corpus: %v", err) + } + frankensteinString := string(frankensteinText) + mistralTokens := mistralEncoder.Encode(&frankensteinString) + output := mistralEncoder.Decode(mistralTokens) + assert.Equal(t, " "+frankensteinString, output) +} + +func TestReadTokenizerConfig(t *testing.T) { + fmt.Println("Testing ReadTokenizerConfig") + // json with eos, bos, pad as strings + jsonStr := `{"eos_token": "TC", "bos_token": "TD", "pad_token": "TE"}` // corresponds to 6669, 10989, 5428 in pythia vocab + + //download filler model + modelId := "EleutherAI/pythia-70m" + destPath := "./TestReadTokenizerConfig" + destPathPTR := &destPath + defer os.RemoveAll(destPath) + var rsrcType resources.ResourceType + rsrcType = resources.RESOURCETYPE_TRANSFORMERS + hfApiToken := os.Getenv("HF_API_TOKEN") + os.MkdirAll(destPath, 0755) + _, rsrcErr := resources.ResolveResources(modelId, destPathPTR, + resources.RESOURCE_MODEL, rsrcType, hfApiToken) + if rsrcErr != nil { + os.RemoveAll(destPath) + t.Errorf("Error downloading model resources: %s", rsrcErr) + } + + // replace tokenizer_config.json with jsonStr + tokenizerConfigPath := destPath + "/tokenizer_config.json" + err := os.WriteFile(tokenizerConfigPath, []byte(jsonStr), 0644) + if err != nil { + t.Errorf("Error writing to tokenizer_config.json: %v", err) + } + + // read tokenizer config by creating an encoder from destPath + encoder, err := NewEncoder(destPath) + if err != nil { + t.Errorf("Error creating encoder: %v", err) + } + + // check that the tokens are correct + assert.Equal(t, encoder.EosToken, Token(6669)) + assert.Equal(t, encoder.BosToken, Token(10989)) + assert.Equal(t,
encoder.PadToken, Token(5428)) + + // Clean up by removing the downloaded folder + fmt.Println("All Exists - Looks good.") +} + func TestGPTDecoder_Decode(t *testing.T) { // TBD } @@ -823,10 +903,10 @@ func TestModelDownload(t *testing.T) { rsrcType = resources.RESOURCETYPE_TRANSFORMERS hfApiToken := os.Getenv("HF_API_TOKEN") os.MkdirAll(destPath, 0755) - defer os.RemoveAll(destPath) _, rsrcErr := resources.ResolveResources(modelId, destPathPTR, resources.RESOURCE_MODEL, rsrcType, hfApiToken) if rsrcErr != nil { + os.RemoveAll(destPath) t.Errorf("Error downloading model resources: %s", rsrcErr) } @@ -841,9 +921,11 @@ func TestModelDownload(t *testing.T) { fmt.Println("config.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("config.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for config.json") } @@ -853,9 +935,11 @@ func TestModelDownload(t *testing.T) { fmt.Println("pytorch_model.bin exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("pytorch_model.bin does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for pytorch_model.bin") } @@ -865,9 +949,11 @@ func TestModelDownload(t *testing.T) { fmt.Println("tokenizer.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("tokenizer.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for tokenizer.json") } @@ -877,13 +963,16 @@ func TestModelDownload(t *testing.T) { fmt.Println("vocab.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("vocab.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for vocab.json") } - // Finish Test - Deferred removal of the downloaded folder + // Clean up by removing the downloaded folder + os.RemoveAll(destPath) fmt.Println("All Exists - Looks good.") } @@ -900,10 +989,10 @@ func TestModelDownloadPythia(t *testing.T) { 
rsrcType = resources.RESOURCETYPE_TRANSFORMERS hfApiToken := os.Getenv("HF_API_TOKEN") os.MkdirAll(destPath, 0755) - defer os.RemoveAll(destPath) _, rsrcErr := resources.ResolveResources(modelId, destPathPTR, resources.RESOURCE_MODEL, rsrcType, hfApiToken) if rsrcErr != nil { + os.RemoveAll(destPath) t.Errorf("Error downloading model resources: %s", rsrcErr) } @@ -918,9 +1007,11 @@ func TestModelDownloadPythia(t *testing.T) { fmt.Println("config.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("config.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for config.json") } @@ -930,9 +1021,11 @@ func TestModelDownloadPythia(t *testing.T) { fmt.Println("pytorch_model.bin exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("pytorch_model.bin does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for pytorch_model.bin") } @@ -942,9 +1035,11 @@ func TestModelDownloadPythia(t *testing.T) { fmt.Println("tokenizer.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("tokenizer.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for tokenizer.json") } @@ -954,13 +1049,16 @@ func TestModelDownloadPythia(t *testing.T) { fmt.Println("vocab.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("vocab.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for vocab.json") } - // Finish Test - Deferred removal of the downloaded folder + // Clean up by removing the downloaded folder + os.RemoveAll(destPath) fmt.Println("All Exists - Looks good.") } @@ -976,10 +1074,10 @@ func TestModelDownloadPythiaSharded(t *testing.T) { rsrcType = resources.RESOURCETYPE_TRANSFORMERS hfApiToken := os.Getenv("HF_API_TOKEN") os.MkdirAll(destPath, 0755) - defer os.RemoveAll(destPath) _, rsrcErr := resources.ResolveResources(modelId, destPathPTR, 
resources.RESOURCE_MODEL, rsrcType, hfApiToken) if rsrcErr != nil { + os.RemoveAll(destPath) t.Errorf("Error downloading model resources: %s", rsrcErr) } @@ -994,9 +1092,11 @@ func TestModelDownloadPythiaSharded(t *testing.T) { fmt.Println("pytorch_model-00001-of-00002.bin exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("pytorch_model-00001-of-00002.bin does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for pytorch_model-00001-of-00002.bin") } @@ -1006,9 +1106,11 @@ func TestModelDownloadPythiaSharded(t *testing.T) { fmt.Println("pytorch_model-00002-of-00002.bin exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("pytorch_model-00002-of-00002.bin does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for pytorch_model-00002-of-00002.bin") } @@ -1018,13 +1120,16 @@ func TestModelDownloadPythiaSharded(t *testing.T) { fmt.Println("pytorch_model.bin.index.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("pytorch_model.bin.index.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for pytorch_model.bin.index.json") } - // Finish Test - Deferred removal of the downloaded folder + // Clean up by removing the downloaded folder + os.RemoveAll(destPath) fmt.Println("All Exists - Looks good.") } @@ -1137,10 +1242,10 @@ func TestModelDownloadFairseq(t *testing.T) { rsrcType = resources.RESOURCETYPE_TRANSFORMERS hfApiToken := os.Getenv("HF_API_TOKEN") os.MkdirAll(destPath, 0755) - defer os.RemoveAll(destPath) _, rsrcErr := resources.ResolveResources(modelId, destPathPTR, resources.RESOURCE_MODEL, rsrcType, hfApiToken) if rsrcErr != nil { + os.RemoveAll(destPath) t.Errorf("Error downloading model resources: %s", rsrcErr) } @@ -1154,9 +1259,11 @@ func TestModelDownloadFairseq(t *testing.T) { fmt.Println("config.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) 
t.Errorf("config.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for config.json") } @@ -1166,9 +1273,11 @@ func TestModelDownloadFairseq(t *testing.T) { fmt.Println("pytorch_model.bin exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("pytorch_model.bin does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for pytorch_model.bin") } @@ -1178,9 +1287,11 @@ func TestModelDownloadFairseq(t *testing.T) { fmt.Println("vocab.json exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("vocab.json does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for vocab.json") } @@ -1190,12 +1301,15 @@ func TestModelDownloadFairseq(t *testing.T) { fmt.Println("merges.txt exists") } else if errors.Is(err, os.ErrNotExist) { + os.RemoveAll(destPath) t.Errorf("merges.txt does not exist") } else { + os.RemoveAll(destPath) t.Errorf("Error checking for merges.txt") } - // Finish Test - Deferred removal of the downloaded folder + // Clean up by removing the downloaded folder + os.RemoveAll(destPath) fmt.Println("All Exists - Looks good (Fairseq Download).") }