Skip to content

Commit

Permalink
Merge pull request #3 from hupe1980/fix_chinese
Browse files (browse the repository at this point in the history)
Fix chinese #2
  • Loading branch information
hupe1980 authored May 20, 2024
2 parents d599702 + 475afba commit 9962b7e
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 12 deletions.
23 changes: 13 additions & 10 deletions bpe.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
retTokens := []string{}

textLength := len(text)
textRunes := []rune(text)

start := 0

Expand All @@ -94,11 +95,11 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
startFind := start

for {
temp := cutText(text, startFind, textLength)
temp := cut(textRunes, startFind, textLength)
nextSpecial = findRegex2StringIndex(temp, specialRegex)

if nextSpecial != nil {
token := cutText(text, startFind+nextSpecial[0], startFind+nextSpecial[1])
token := cut(textRunes, startFind+nextSpecial[0], startFind+nextSpecial[1])
if _, ok := allowedSpecial[token]; ok {
break
}
Expand All @@ -114,8 +115,9 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
end = start + nextSpecial[0]
}

for _, mat := range findRegex2AllStringMatchIndex(cutText(text, start, end), regex) {
piece := cutText(text, start+mat[0], start+mat[1])
for _, mat := range findRegex2AllStringMatchIndex(cut(textRunes, start, end), regex) {
piece := cut(textRunes, start+mat[0], start+mat[1])

if id, ok := bpe.encoder[piece]; ok {
retIDs = append(retIDs, id)
retTokens = append(retTokens, piece)
Expand All @@ -129,7 +131,7 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
}

if nextSpecial != nil {
temp := cutText(text, start+nextSpecial[0], start+nextSpecial[1])
temp := cut(textRunes, start+nextSpecial[0], start+nextSpecial[1])
id := bpe.specialTokensEncoder[temp]
retIDs = append(retIDs, id)
retTokens = append(retTokens, temp)
Expand All @@ -148,9 +150,10 @@ func (bpe *coreBPE) Encode(text string, allowedSpecial map[string]any) ([]uint,
func (bpe *coreBPE) EncodeOrdinary(text string) ([]uint, []string) {
retIDs := []uint{}
retTokens := []string{}
textRunes := []rune(text)

for _, mat := range findRegex2AllStringMatchIndex(text, bpe.tlRegex) {
piece := cutText(text, mat[0], mat[1])
piece := cut(textRunes, mat[0], mat[1])
if id, ok := bpe.encoder[piece]; ok {
retIDs = append(retIDs, id)
retTokens = append(retTokens, piece)
Expand Down Expand Up @@ -300,14 +303,14 @@ func findRegex2AllStringMatchIndex(text string, reg *regexp2.Regexp) [][]int {
return matches
}

func cutText(text string, start, end int) string {
func cut(runes []rune, start, end int) string {
if start < 0 {
start = 0
}

if end > len(text) {
end = len(text)
if end > len(runes) {
end = len(runes)
}

return text[start:end]
return string(runes[start:end])
}
5 changes: 3 additions & 2 deletions claude_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/text/unicode/norm"
)

func TestClaude(t *testing.T) {
Expand All @@ -25,10 +26,10 @@ func TestClaude(t *testing.T) {
})

t.Run("text normalising", func(t *testing.T) {
idx, _ := encoding.EncodeOrdinary("™")
idx, _ := encoding.EncodeOrdinary(norm.NFKC.String("™"))
assert.Equal(t, 1, len(idx))

idx, _ = encoding.EncodeOrdinary("ϰ")
idx, _ = encoding.EncodeOrdinary(norm.NFKC.String("ϰ"))
assert.Equal(t, 1, len(idx))
})

Expand Down
7 changes: 7 additions & 0 deletions encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ func TestCL100kEncoding(t *testing.T) {
assert.ElementsMatch(t, []uint{15339, 220, 100257}, ids)
})

t.Run("chinese", func(t *testing.T) {
text := "你好世界!"
ids, _, err := encoding.Encode(text, []string{"all"}, nil)
assert.NoError(t, err)
assert.ElementsMatch(t, []uint{57668, 53901, 3574, 244, 98220, 6447}, ids)
})

t.Run("decode", func(t *testing.T) {
assert.Equal(t, "hello world", string(encoding.Decode([]uint{15339, 1917})))
})
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ require (
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/text v0.15.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
Expand Down

0 comments on commit 9962b7e

Please sign in to comment.