Skip to content

Commit

Permalink
feat: move gomarkov
Browse files Browse the repository at this point in the history
  • Loading branch information
vitalvas committed Aug 10, 2024
1 parent 65298b4 commit c844fc7
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 0 deletions.
114 changes: 114 additions & 0 deletions gomarkov/example/pokenamer/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package main

import (
"bufio"
"encoding/json"
"flag"
"fmt"
"io/ioutil"

Check failure on line 8 in gomarkov/example/pokenamer/main.go

View workflow job for this annotation

GitHub Actions / build

SA1019: "io/ioutil" has been deprecated since Go 1.19: As of Go 1.16, the same functionality is now provided by package io or package os, and those implementations should be preferred in new code. See the specific function documentation for details. (staticcheck)
"log"
"os"
"strings"

"github.com/vitalvas/gokit/gomarkov/libmarkov"
)

func main() {
train := flag.Bool("train", false, "Train the markov chain")
order := flag.Int("order", 3, "Chain order to use")
sourceFile := flag.String("file", "names.txt", "File name with source")

flag.Parse()
if *train {
chain := buildModel(*order, *sourceFile)
saveModel(chain)
} else {
chain, err := loadModel()
if err != nil {
fmt.Println(err)
return
}

name := generatePokemon(chain)
fmt.Println(name)
}
}

// buildModel creates a chain of the given order and feeds it every non-empty
// line of the dataset file, split into single-character tokens.
func buildModel(order int, file string) *libmarkov.Chain {
	model := libmarkov.NewChain(order)
	for _, line := range getDataset(file) {
		// Blank lines carry no name and are skipped.
		if line == "" {
			continue
		}
		model.Add(split(line))
	}
	return model
}

// split breaks str into its individual UTF-8 characters, one per element.
func split(str string) []string {
	// An empty separator with no limit (-1) yields one element per character.
	return strings.SplitN(str, "", -1)
}

// getDataset reads fileName and returns its contents as a slice of lines
// (without line terminators).
//
// If the file cannot be opened the error is logged and nil is returned. A
// read error encountered while scanning is logged and the lines read so far
// are returned.
func getDataset(fileName string) []string {
	file, err := os.Open(fileName)
	if err != nil {
		log.Println(err)
		return nil
	}
	// The file is only read, so a Close error carries no useful information.
	defer file.Close()

	var list []string
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		list = append(list, scanner.Text())
	}
	// Scan returns false both at EOF and on failure; Err tells them apart.
	if err := scanner.Err(); err != nil {
		log.Println(err)
	}
	return list
}

// saveModel serializes chain to JSON and writes it to "model.json" in the
// current directory. Failures are printed; nothing is returned.
//
// If encoding fails the existing model file is left untouched instead of
// being overwritten with empty content.
//
// NOTE(review): ioutil.WriteFile is deprecated in favour of os.WriteFile; it
// is kept here so the file's io/ioutil import stays in use.
func saveModel(chain *libmarkov.Chain) {
	jsonObj, err := json.Marshal(chain)
	if err != nil {
		fmt.Println(err)
		return
	}
	if err := ioutil.WriteFile("model.json", jsonObj, 0644); err != nil {
		fmt.Println(err)
	}
}

// loadModel reads "model.json" from the current directory and decodes it into
// a chain.
//
// On any failure it returns a nil chain together with the error, so callers
// never receive a half-initialised chain.
//
// NOTE(review): ioutil.ReadFile is deprecated in favour of os.ReadFile; it is
// kept here so the file's io/ioutil import stays in use.
func loadModel() (*libmarkov.Chain, error) {
	data, err := ioutil.ReadFile("model.json")
	if err != nil {
		return nil, err
	}

	chain := new(libmarkov.Chain)
	if err := json.Unmarshal(data, chain); err != nil {
		return nil, fmt.Errorf("decoding model.json: %w", err)
	}
	return chain, nil
}

// generatePokemon walks the chain from the start state and returns the
// generated name.
//
// Generation stops when the chain emits libmarkov.EndToken. If the chain
// reports an error (for example an n-gram it has never seen), the error is
// logged and the characters produced so far are returned, rather than looping
// forever. A chain whose order is not positive yields an empty name.
func generatePokemon(chain *libmarkov.Chain) string {
	order := chain.Order
	if order <= 0 {
		return ""
	}

	// Seed the sliding window with `order` start tokens.
	tokens := make([]string, 0, order+16)
	for i := 0; i < order; i++ {
		tokens = append(tokens, libmarkov.StartToken)
	}

	for {
		// The last `order` tokens form the current state.
		next, err := chain.Generate(tokens[len(tokens)-order:])
		if err != nil {
			log.Println(err)
			break
		}
		if next == libmarkov.EndToken {
			break
		}
		tokens = append(tokens, next)
	}

	// Drop the leading start tokens; the end token was never appended.
	return strings.Join(tokens[order:], "")
}

// newAppendModel appends name followed by a newline to "names_auto.txt" in
// the current directory, creating the file if it does not exist.
//
// Errors from opening, writing or closing the file are logged; nothing is
// returned.
func newAppendModel(name string) {
	f, err := os.OpenFile("names_auto.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		log.Println(err.Error())
		return
	}
	defer func() {
		// Close can report a deferred write failure on a file opened for writing.
		if err := f.Close(); err != nil {
			log.Println(err.Error())
		}
	}()

	// Write the name and its terminator in one call so a line is never left
	// half-written between two separate writes.
	if _, err := f.WriteString(name + "\n"); err != nil {
		log.Println(err.Error())
	}
}
47 changes: 47 additions & 0 deletions gomarkov/libmarkov/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package libmarkov

import "strings"

// Pair couples a chain state with the token observed immediately after it.
type Pair struct {
// CurrentState is the n-gram of tokens that precedes NextState.
CurrentState NGram
// NextState is the token that follows CurrentState.
NextState string
}

// NGram is an ordered sequence of tokens identifying one chain state.
type NGram []string

// sparseArray maps an integer index to a count.
type sparseArray map[int]int

// key flattens the n-gram into a single string by joining its tokens with ":".
func (ngram NGram) key() string {
	const sep = ":"
	return strings.Join([]string(ngram), sep)
}

// sum returns the total of all counts stored in the array.
func (s sparseArray) sum() int {
	total := 0
	for idx := range s {
		total += s[idx]
	}
	return total
}

// array returns a slice of length count in which every element is value.
func array(value string, count int) []string {
	filled := make([]string, count)
	for i := 0; i < count; i++ {
		filled[i] = value
	}
	return filled
}

// MakePairs slides a window of `order` tokens over tokens and returns, for
// every position, the window together with the token that follows it. The
// result is nil when tokens holds no more than `order` elements.
//
// Each CurrentState is a sub-slice of tokens and shares its backing array.
func MakePairs(tokens []string, order int) []Pair {
	var pairs []Pair
	// `end` is the index of the follower token; the window ends just before it.
	for end := order; end < len(tokens); end++ {
		pairs = append(pairs, Pair{
			CurrentState: tokens[end-order : end],
			NextState:    tokens[end],
		})
	}
	return pairs
}
158 changes: 158 additions & 0 deletions gomarkov/libmarkov/markov.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package libmarkov

import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"strings"
"sync"
"time"
)

// StartToken is the padding token that Add places before the input; it is
// repeated once per unit of chain order.
const StartToken = "^"

// EndToken is the padding token that Add places after the input; it is
// repeated once per unit of chain order.
const EndToken = "$"

// Chain is a Markov chain of a fixed order built from token sequences.
type Chain struct {
// Order is the number of tokens that make up one state (n-gram).
Order int
// statePool assigns integer indexes to state keys and to tokens.
statePool *spool
// frequencyMat maps a state index to the counts of the token indexes
// observed after that state.
frequencyMat map[int]sparseArray
// lock is held by Add while it updates frequencyMat.
lock *sync.RWMutex
}

// ChainJSON is the serialized form of Chain produced by MarshalJSON and
// consumed by UnmarshalJSON.
type ChainJSON struct {
// Order is the chain order. It is stored under the JSON key "int".
// NOTE(review): the key looks like a typo for "order"; renaming it would
// break previously saved models — confirm before changing.
Order int `json:"int"`
// SpoolMap is the state pool's string-to-index map.
SpoolMap map[string]int `json:"spool_map"`
// FreqMat is the chain's frequency matrix.
FreqMat map[int]sparseArray `json:"freq_mat"`
}

// init seeds the global math/rand source with the current time in nanoseconds.
//
// NOTE(review): rand.Seed is deprecated since Go 1.20 (staticcheck SA1019).
// Removing the call would also require dropping the "time" import and
// deciding whether the package should own a private random source.
func init() {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}

// MarshalJSON encodes the chain as a ChainJSON document holding the order,
// the state pool's string-to-index map and the frequency matrix.
func (chain Chain) MarshalJSON() ([]byte, error) {
	var out ChainJSON
	out.Order = chain.Order
	out.SpoolMap = chain.statePool.stringMap
	out.FreqMat = chain.frequencyMat
	return json.Marshal(out)
}

// UnmarshalJSON restores the chain from a ChainJSON document.
//
// The index-to-string map of the state pool is rebuilt by inverting the
// decoded string-to-index map. Missing or null "spool_map" / "freq_mat"
// entries are replaced with empty maps so that a later Add does not panic
// writing into a nil map. A decoding error is returned unchanged and leaves
// the chain unmodified.
func (chain *Chain) UnmarshalJSON(b []byte) error {
	var obj ChainJSON
	if err := json.Unmarshal(b, &obj); err != nil {
		return err
	}

	if obj.SpoolMap == nil {
		obj.SpoolMap = make(map[string]int)
	}
	if obj.FreqMat == nil {
		obj.FreqMat = make(map[int]sparseArray)
	}

	// Invert string -> index into index -> string.
	intMap := make(map[int]string, len(obj.SpoolMap))
	for k, v := range obj.SpoolMap {
		intMap[v] = k
	}

	chain.Order = obj.Order
	chain.statePool = &spool{
		stringMap: obj.SpoolMap,
		intMap:    intMap,
	}
	chain.frequencyMat = obj.FreqMat
	chain.lock = new(sync.RWMutex)
	return nil
}

// NewChain returns an empty chain of the given order with its state pool,
// frequency matrix and lock initialised.
func NewChain(order int) *Chain {
	return &Chain{
		Order: order,
		statePool: &spool{
			stringMap: make(map[string]int),
			intMap:    make(map[int]string),
		},
		frequencyMat: make(map[int]sparseArray),
		lock:         new(sync.RWMutex),
	}
}

// RawAdd splits input into its individual UTF-8 characters and adds the
// resulting token sequence to the chain.
func (chain *Chain) RawAdd(input string) {
	chain.Add(strings.Split(input, ""))
}

// Add trains the chain on one token sequence.
//
// The input is padded with Order start tokens in front and Order end tokens
// behind, then every (state, next token) pair is counted in the frequency
// matrix.
func (chain *Chain) Add(input []string) {
	prefix := array(StartToken, chain.Order)
	suffix := array(EndToken, chain.Order)

	tokens := make([]string, 0)
	tokens = append(tokens, prefix...)
	tokens = append(tokens, input...)
	tokens = append(tokens, suffix...)

	for _, pair := range MakePairs(tokens, chain.Order) {
		// Resolve both pool indexes before taking the chain lock.
		from := chain.statePool.add(pair.CurrentState.key())
		to := chain.statePool.add(pair.NextState)

		chain.lock.Lock()
		row := chain.frequencyMat[from]
		if row == nil {
			// First transition seen from this state.
			row = make(sparseArray)
			chain.frequencyMat[from] = row
		}
		row[to]++
		chain.lock.Unlock()
	}
}

// TransitionProbability godoc
func (chain *Chain) TransitionProbability(next string, current NGram) (float64, error) {
if len(current) != chain.Order {
return 0, errors.New("N-gram length does not match chain order")

Check failure on line 120 in gomarkov/libmarkov/markov.go

View workflow job for this annotation

GitHub Actions / build

ST1005: error strings should not be capitalized (stylecheck)
}

currentIndex, currentExists := chain.statePool.get(current.key())
nextIndex, nextExists := chain.statePool.get(next)
if !currentExists || !nextExists {
return 0, nil
}

arr := chain.frequencyMat[currentIndex]
sum := float64(arr.sum())
freq := float64(arr[nextIndex])

return freq / sum, nil
}

// Generate godoc
func (chain *Chain) Generate(current NGram) (string, error) {
if len(current) != chain.Order {
return "", errors.New("N-gram length does not match chain order")

Check failure on line 139 in gomarkov/libmarkov/markov.go

View workflow job for this annotation

GitHub Actions / build

ST1005: error strings should not be capitalized (stylecheck)
}

currentIndex, currentExists := chain.statePool.get(current.key())
if !currentExists {
return "", fmt.Errorf("Unknown ngram %v", current)

Check failure on line 144 in gomarkov/libmarkov/markov.go

View workflow job for this annotation

GitHub Actions / build

ST1005: error strings should not be capitalized (stylecheck)
}

arr := chain.frequencyMat[currentIndex]
sum := arr.sum()
randN := rand.Intn(sum)

for i, freq := range arr {
randN -= freq
if randN <= 0 {
return chain.statePool.intMap[i], nil
}
}
return "", nil
}
38 changes: 38 additions & 0 deletions gomarkov/libmarkov/spool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package libmarkov

import "sync"

// spool interns strings: each distinct string is assigned a stable integer
// index, and both directions of the mapping are kept.
type spool struct {
	// stringMap maps a string to its index.
	stringMap map[string]int
	// intMap maps an index back to its string.
	intMap map[int]string
	// The embedded RWMutex guards both maps.
	sync.RWMutex
}

// add returns the index of str, assigning the next free index if str has not
// been seen before.
func (s *spool) add(str string) int {
	// Fast path: the string is usually already present, so try under the
	// read lock first.
	if index, ok := s.get(str); ok {
		return index
	}

	s.Lock()
	defer s.Unlock()

	// Re-check under the write lock: another goroutine may have added str
	// between the read above and acquiring the lock.
	if index, ok := s.stringMap[str]; ok {
		return index
	}

	index := len(s.stringMap)
	s.stringMap[str] = index
	s.intMap[index] = str
	return index
}

// get returns the index of str and whether str is known to the pool.
func (s *spool) get(str string) (int, bool) {
	// add writes stringMap under the write lock, so reads must hold the read
	// lock to avoid a concurrent map read/write.
	s.RLock()
	defer s.RUnlock()

	index, ok := s.stringMap[str]
	return index, ok
}

0 comments on commit c844fc7

Please sign in to comment.