From 475afba6f862f2c62ab1a668ddbf2433aed02ea5 Mon Sep 17 00:00:00 2001 From: hupe1980 Date: Mon, 20 May 2024 18:12:22 +0200 Subject: [PATCH] Fix chinese --- bpe.go | 23 +++++++++++++---------- claude_test.go | 5 +++-- encoding_test.go | 7 +++++++ go.mod | 1 + go.sum | 2 ++ 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/bpe.go b/bpe.go index 73c6add..3d18c26 100644 --- a/bpe.go +++ b/bpe.go @@ -85,6 +85,7 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, retTokens := []string{} textLength := len(text) + textRunes := []rune(text) start := 0 @@ -94,11 +95,11 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, startFind := start for { - temp := cutText(text, startFind, textLength) + temp := cut(textRunes, startFind, textLength) nextSpecial = findRegex2StringIndex(temp, specialRegex) if nextSpecial != nil { - token := cutText(text, startFind+nextSpecial[0], startFind+nextSpecial[1]) + token := cut(textRunes, startFind+nextSpecial[0], startFind+nextSpecial[1]) if _, ok := allowedSpecial[token]; ok { break } @@ -114,8 +115,9 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, end = start + nextSpecial[0] } - for _, mat := range findRegex2AllStringMatchIndex(cutText(text, start, end), regex) { - piece := cutText(text, start+mat[0], start+mat[1]) + for _, mat := range findRegex2AllStringMatchIndex(cut(textRunes, start, end), regex) { + piece := cut(textRunes, start+mat[0], start+mat[1]) + if id, ok := bpe.encoder[piece]; ok { retIDs = append(retIDs, id) retTokens = append(retTokens, piece) @@ -129,7 +131,7 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, } if nextSpecial != nil { - temp := cutText(text, start+nextSpecial[0], start+nextSpecial[1]) + temp := cut(textRunes, start+nextSpecial[0], start+nextSpecial[1]) id := bpe.specialTokensEncoder[temp] retIDs = append(retIDs, id) retTokens = append(retTokens, temp) @@ -148,9 +150,10 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint, func (bpe *coreBPE) EncodeOrdinary(text string) ([]uint, []string) { retIDs := []uint{} retTokens := []string{} + textRunes := []rune(text) for _, mat := range findRegex2AllStringMatchIndex(text, bpe.tlRegex) { - piece := cutText(text, mat[0], mat[1]) + piece := cut(textRunes, mat[0], mat[1]) if id, ok := bpe.encoder[piece]; ok { retIDs = append(retIDs, id) retTokens = append(retTokens, piece) @@ -300,14 +303,14 @@ func findRegex2AllStringMatchIndex(text string, reg *regexp2.Regexp) [][]int { return matches } -func cutText(text string, start, end int) string { +func cut(runes []rune, start, end int) string { if start < 0 { start = 0 } - if end > len(text) { - end = len(text) + if end > len(runes) { + end = len(runes) } - return text[start:end] + return string(runes[start:end]) } diff --git a/claude_test.go b/claude_test.go index e94024b..9eaa768 100644 --- a/claude_test.go +++ b/claude_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/text/unicode/norm" ) func TestClaude(t *testing.T) { @@ -25,10 +26,10 @@ func TestClaude(t *testing.T) { }) t.Run("text normalising", func(t *testing.T) { - idx, _ := encoding.EncodeOrdinary("™") + idx, _ := encoding.EncodeOrdinary(norm.NFKC.String("™")) assert.Equal(t, 1, len(idx)) - idx, _ = encoding.EncodeOrdinary("ϰ") + idx, _ = encoding.EncodeOrdinary(norm.NFKC.String("ϰ")) assert.Equal(t, 1, len(idx)) }) diff --git a/encoding_test.go b/encoding_test.go index db28df5..ade4601 100644 --- a/encoding_test.go +++ b/encoding_test.go @@ -78,6 +78,13 @@ func TestCL100kEncoding(t *testing.T) { assert.ElementsMatch(t, []uint{15339, 220, 100257}, ids) }) + t.Run("chinese", func(t *testing.T) { + text := "你好世界!" + ids, _, err := encoding.Encode(text, []string{"all"}, nil) + assert.NoError(t, err) + assert.ElementsMatch(t, []uint{57668, 53901, 3574, 244, 98220, 6447}, ids) + }) + t.Run("decode", func(t *testing.T) { assert.Equal(t, "hello world", string(encoding.Decode([]uint{15339, 1917}))) }) diff --git a/go.mod b/go.mod index dac6235..34e601f 100644 --- a/go.mod +++ b/go.mod @@ -10,5 +10,6 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/text v0.15.0 gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4d2eff4..3b70680 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=