-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
357 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
package main | ||
|
||
import ( | ||
"bufio" | ||
"encoding/json" | ||
"flag" | ||
"fmt" | ||
"io/ioutil" | ||
Check failure on line 8 in gomarkov/example/pokenamer/main.go GitHub Actions / build
|
||
"log" | ||
"os" | ||
"strings" | ||
|
||
"github.com/vitalvas/gokit/gomarkov/libmarkov" | ||
) | ||
|
||
func main() { | ||
train := flag.Bool("train", false, "Train the markov chain") | ||
order := flag.Int("order", 3, "Chain order to use") | ||
sourceFile := flag.String("file", "names.txt", "File name with source") | ||
|
||
flag.Parse() | ||
if *train { | ||
chain := buildModel(*order, *sourceFile) | ||
saveModel(chain) | ||
} else { | ||
chain, err := loadModel() | ||
if err != nil { | ||
fmt.Println(err) | ||
return | ||
} | ||
|
||
name := generatePokemon(chain) | ||
fmt.Println(name) | ||
} | ||
} | ||
|
||
func buildModel(order int, file string) *libmarkov.Chain { | ||
chain := libmarkov.NewChain(order) | ||
for _, data := range getDataset(file) { | ||
if len(data) > 0 { | ||
chain.Add(split(data)) | ||
} | ||
} | ||
return chain | ||
} | ||
|
||
// split breaks str into one string per UTF-8 sequence. An empty input yields
// an empty slice.
func split(str string) []string {
	// A negative count means "no limit", which is what strings.Split does.
	return strings.SplitN(str, "", -1)
}
|
||
// getDataset reads fileName line by line and returns the lines in order.
//
// The signature has no error result, so failures are logged instead: if the
// file cannot be opened nil is returned, and if scanning stops early the lines
// read so far are returned. The file is always closed before returning.
func getDataset(fileName string) []string {
	file, err := os.Open(fileName)
	if err != nil {
		log.Println(err)
		return nil
	}
	defer file.Close()

	var list []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		list = append(list, scanner.Text())
	}
	// Scan returns false both at EOF and on a read error; only Err tells
	// the two apart.
	if err := scanner.Err(); err != nil {
		log.Println(err)
	}
	return list
}
|
||
func saveModel(chain *libmarkov.Chain) { | ||
jsonObj, _ := json.Marshal(chain) | ||
err := ioutil.WriteFile("model.json", jsonObj, 0644) | ||
if err != nil { | ||
fmt.Println(err) | ||
} | ||
} | ||
|
||
func loadModel() (*libmarkov.Chain, error) { | ||
var chain libmarkov.Chain | ||
data, err := ioutil.ReadFile("model.json") | ||
if err != nil { | ||
return &chain, err | ||
} | ||
err = json.Unmarshal(data, &chain) | ||
if err != nil { | ||
return &chain, err | ||
} | ||
return &chain, nil | ||
} | ||
|
||
func generatePokemon(chain *libmarkov.Chain) string { | ||
order := chain.Order | ||
tokens := make([]string, 0) | ||
|
||
for i := 0; i < order; i++ { | ||
tokens = append(tokens, libmarkov.StartToken) | ||
} | ||
|
||
for tokens[len(tokens)-1] != libmarkov.EndToken { | ||
next, _ := chain.Generate(tokens[(len(tokens) - order):]) | ||
tokens = append(tokens, next) | ||
} | ||
|
||
return strings.Join(tokens[order:len(tokens)-1], "") | ||
} | ||
|
||
// newAppendModel appends name followed by a newline to names_auto.txt in the
// current directory, creating the file with mode 0644 if it does not exist.
//
// Failures to open or write the file are logged; nothing is returned.
func newAppendModel(name string) {
	f, err := os.OpenFile("names_auto.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		log.Println(err.Error())
		return
	}
	defer f.Close()

	// One write for the name and its terminator, so a failure between two
	// separate writes cannot leave an unterminated line behind.
	if _, err := f.WriteString(name + "\n"); err != nil {
		log.Println(err.Error())
	}
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
package libmarkov | ||
|
||
import "strings" | ||
|
||
type (
	// Pair couples a state (an n-gram of tokens) with the single token
	// that was observed immediately after it.
	Pair struct {
		CurrentState NGram
		NextState    string
	}

	// NGram is an ordered run of tokens treated as one chain state.
	NGram []string

	// sparseArray is an int-keyed table of int counts. Absent keys read
	// as zero.
	sparseArray map[int]int
)
|
||
func (ngram NGram) key() string { | ||
return strings.Join(ngram, ":") | ||
} | ||
|
||
func (s sparseArray) sum() int { | ||
sum := 0 | ||
for _, count := range s { | ||
sum += count | ||
} | ||
return sum | ||
} | ||
|
||
// array returns a slice of length count in which every element is value.
func array(value string, count int) []string {
	filled := make([]string, count)
	for i := 0; i < count; i++ {
		filled[i] = value
	}
	return filled
}
|
||
// MakePairs godoc | ||
func MakePairs(tokens []string, order int) []Pair { | ||
var pairs []Pair | ||
for i := 0; i < len(tokens)-order; i++ { | ||
pair := Pair{ | ||
CurrentState: tokens[i : i+order], | ||
NextState: tokens[i+order], | ||
} | ||
pairs = append(pairs, pair) | ||
} | ||
return pairs | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
package libmarkov | ||
|
||
import ( | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"math/rand" | ||
"strings" | ||
"sync" | ||
"time" | ||
) | ||
|
||
// Sentinel tokens used to pad token sequences.
const (
	// StartToken is the padding token placed before a sequence.
	StartToken = "^"

	// EndToken is the padding token placed after a sequence.
	EndToken = "$"
)
|
||
// Chain godoc | ||
type Chain struct { | ||
Order int | ||
statePool *spool | ||
frequencyMat map[int]sparseArray | ||
lock *sync.RWMutex | ||
} | ||
|
||
// ChainJSON godoc | ||
type ChainJSON struct { | ||
Order int `json:"int"` | ||
SpoolMap map[string]int `json:"spool_map"` | ||
FreqMat map[int]sparseArray `json:"freq_mat"` | ||
} | ||
|
||
// init seeds the package-level math/rand generator with the current time in
// nanoseconds since the Unix epoch.
//
// NOTE(review): CI reported a check failure on the rand.Seed line. rand.Seed
// is deprecated as of Go 1.20 — confirm the module's Go version before
// removing the call.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
|
||
// MarshalJSON godoc | ||
func (chain Chain) MarshalJSON() ([]byte, error) { | ||
obj := ChainJSON{ | ||
Order: chain.Order, | ||
SpoolMap: chain.statePool.stringMap, | ||
FreqMat: chain.frequencyMat, | ||
} | ||
return json.Marshal(obj) | ||
} | ||
|
||
// UnmarshalJSON godoc | ||
func (chain *Chain) UnmarshalJSON(b []byte) error { | ||
var obj ChainJSON | ||
err := json.Unmarshal(b, &obj) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
chain.Order = obj.Order | ||
intMap := make(map[int]string) | ||
|
||
for k, v := range obj.SpoolMap { | ||
intMap[v] = k | ||
} | ||
chain.statePool = &spool{ | ||
stringMap: obj.SpoolMap, | ||
intMap: intMap, | ||
} | ||
chain.frequencyMat = obj.FreqMat | ||
chain.lock = new(sync.RWMutex) | ||
return nil | ||
} | ||
|
||
// NewChain godoc | ||
func NewChain(order int) *Chain { | ||
chain := Chain{Order: order} | ||
chain.statePool = &spool{ | ||
stringMap: make(map[string]int), | ||
intMap: make(map[int]string), | ||
} | ||
chain.frequencyMat = make(map[int]sparseArray, 0) | ||
chain.lock = new(sync.RWMutex) | ||
return &chain | ||
} | ||
|
||
// RawAdd godoc | ||
func (chain *Chain) RawAdd(input string) { | ||
split := func(str string) []string { | ||
return strings.Split(str, "") | ||
} | ||
|
||
chain.Add(split(input)) | ||
} | ||
|
||
// Add godoc | ||
func (chain *Chain) Add(input []string) { | ||
startTokens := array(StartToken, chain.Order) | ||
endTokens := array(EndToken, chain.Order) | ||
tokens := make([]string, 0) | ||
tokens = append(tokens, startTokens...) | ||
tokens = append(tokens, input...) | ||
tokens = append(tokens, endTokens...) | ||
pairs := MakePairs(tokens, chain.Order) | ||
|
||
for i := 0; i < len(pairs); i++ { | ||
pair := pairs[i] | ||
currentIndex := chain.statePool.add(pair.CurrentState.key()) | ||
nextIndex := chain.statePool.add(pair.NextState) | ||
chain.lock.Lock() | ||
|
||
if chain.frequencyMat[currentIndex] == nil { | ||
chain.frequencyMat[currentIndex] = make(sparseArray, 0) | ||
} | ||
|
||
chain.frequencyMat[currentIndex][nextIndex]++ | ||
chain.lock.Unlock() | ||
} | ||
} | ||
|
||
// TransitionProbability godoc | ||
func (chain *Chain) TransitionProbability(next string, current NGram) (float64, error) { | ||
if len(current) != chain.Order { | ||
return 0, errors.New("N-gram length does not match chain order") | ||
} | ||
|
||
currentIndex, currentExists := chain.statePool.get(current.key()) | ||
nextIndex, nextExists := chain.statePool.get(next) | ||
if !currentExists || !nextExists { | ||
return 0, nil | ||
} | ||
|
||
arr := chain.frequencyMat[currentIndex] | ||
sum := float64(arr.sum()) | ||
freq := float64(arr[nextIndex]) | ||
|
||
return freq / sum, nil | ||
} | ||
|
||
// Generate godoc | ||
func (chain *Chain) Generate(current NGram) (string, error) { | ||
if len(current) != chain.Order { | ||
return "", errors.New("N-gram length does not match chain order") | ||
} | ||
|
||
currentIndex, currentExists := chain.statePool.get(current.key()) | ||
if !currentExists { | ||
return "", fmt.Errorf("Unknown ngram %v", current) | ||
} | ||
|
||
arr := chain.frequencyMat[currentIndex] | ||
sum := arr.sum() | ||
randN := rand.Intn(sum) | ||
|
||
for i, freq := range arr { | ||
randN -= freq | ||
if randN <= 0 { | ||
return chain.statePool.intMap[i], nil | ||
} | ||
} | ||
return "", nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package libmarkov | ||
|
||
import "sync" | ||
|
||
// spool assigns a stable integer index to each distinct string and keeps the
// mapping in both directions.
type spool struct {
	// The embedded lock is used by add when reading and updating the maps.
	sync.RWMutex

	// stringMap looks up the index assigned to a string.
	stringMap map[string]int

	// intMap looks up the string an index was assigned to.
	intMap map[int]string
}
|
||
func (s *spool) add(str string) int { | ||
s.RLock() | ||
index, ok := s.stringMap[str] | ||
s.RUnlock() | ||
|
||
if ok { | ||
return index | ||
} | ||
|
||
s.Lock() | ||
defer s.Unlock() | ||
|
||
index, ok = s.stringMap[str] | ||
if ok { | ||
return index | ||
} | ||
|
||
index = len(s.stringMap) | ||
s.stringMap[str] = index | ||
s.intMap[index] = str | ||
|
||
return index | ||
} | ||
|
||
func (s *spool) get(str string) (int, bool) { | ||
index, ok := s.stringMap[str] | ||
return index, ok | ||
} |