Skip to content

Commit

Permalink
{pyactr} Handle printing of numbers and variables (#65)
Browse files Browse the repository at this point in the history
Addresses more of #32.
  • Loading branch information
asmaloney authored Oct 11, 2021
1 parent 080f1df commit f3594b9
Showing 1 changed file with 94 additions and 37 deletions.
131 changes: 94 additions & 37 deletions framework/pyactr/pyactr.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ func (p *PyACTR) WriteModel(path, initialGoal string) (outputFileName string, er

p.outputAuthors()

if p.model.HasPrintStatement() {
// We use csv to parse the print text we are generating.
// This is just simpler than writing it ourselves (i.e. handling "foo, bar ", 66).
p.Writeln("import csv")
}

p.Writeln("import pyactr as actr")

if p.model.HasPrintStatement() {
Expand All @@ -137,20 +143,6 @@ func (p *PyACTR) WriteModel(path, initialGoal string) (outputFileName string, er

p.Write("%s = actr.ACTRModel(%s)\n\n", p.className, strings.Join(additionalInit, ", "))

if p.model.HasPrintStatement() {
p.Writeln(`
# Monkey patch Buffer to add a new method.
# Currently we can only output simple strings.
def print_text(*args):
text = ''.join(args[1:])
text = text.strip("'")
print(text)
Buffer.print_text = print_text
`)
}

// chunks
for _, chunk := range p.model.Chunks {
if chunk.IsInternal() {
Expand Down Expand Up @@ -218,13 +210,6 @@ Buffer.print_text = print_text
if production.DoStatements != nil {
for _, statement := range production.DoStatements {
if statement.Print != nil {
if !onlyStrings(statement.Print.Values) {
warning := fmt.Sprintf("Warning: ('%s') pyactr currently cannot print variables from productions", production.Name)
warnings = append(warnings, warning)

continue
}

numPrintStatements++
if numPrintStatements > 1 {
warning := fmt.Sprintf("Warning: ('%s') pyactr currently only supports one print statement per production", production.Name)
Expand All @@ -233,7 +218,7 @@ Buffer.print_text = print_text
}
}

p.outputStatement(statement)
p.outputStatement(production, statement)
}
}

Expand All @@ -250,12 +235,80 @@ Buffer.print_text = print_text

p.Writeln("")

if p.model.HasPrintStatement() {
// We add some python code to monkey patch the model and the buffer.
// This lets us output strings, numbers, and the contents of buffer slots.

p.Writeln(`
# Monkey patch ACTRModel to add a new method.
def get_buffer(self, buffer_name: str) -> Buffer:
if buffer_name == 'goal':
return self._ACTRModel__buffers['goal']
if buffer_name == 'retrieval':
return self._ACTRModel__buffers["retrieval"]
print('ERROR: Buffer \'' + buffer_name + '\' not found')
actr.ACTRModel.get_buffer = get_buffer
# Monkey patch Buffer to add a new methods.
def get_slot_contents(self, buffer_name: str, slot_name: str) -> str:
if self._data:
chunk = self._data.copy().pop()
else:
chunk = None
try:
return str(getattr(chunk, slot_name))
except AttributeError:
print('ERROR: no slot named \'' + slot_name +
'\' in buffer \'' + buffer_name + '\'')
return ''
def print_text(*args):
text = ''.join(args[1:]).strip('"')
output = '' # build up our output in this buffer
for itemlist in csv.reader([text]):
for item in itemlist:
item = item.strip(' ')
# Handle string
if item[0] == '\'' or item[0] == '"':
output += item[1:-1]
else:
# Handle number
try:
float(item)
output += item
except ValueError:
# If we are here, we should have a buffer.slotname
ids = item.split('.')
if len(ids) != 2:
print(
'ERROR: expected <buffer>.<slot_name>, found \'' + item + '\'')
else:
buffer = %s.get_buffer(ids[0])
output += buffer.get_slot_contents(ids[0], ids[1])
print(output)
Buffer.get_slot_contents = get_slot_contents
Buffer.print_text = print_text
`, p.className)
}

// ...add our code to run
p.Writeln("# Main")
p.Writeln("if __name__ == '__main__':")
p.Writeln("\tsim = %s.simulation()", p.className)
p.Writeln("\tsim.run()")
// TODO: Add some intelligent output when logging level is info or detail
p.Writeln("\tif goal.test_buffer('full') == True:")
p.Writeln("\tif goal.test_buffer('full') is True:")
p.Writeln("\t\tprint('final goal: ' + str(goal.pop()))")

return
Expand Down Expand Up @@ -335,7 +388,7 @@ func addPatternSlot(tabbedItems *framework.KeyValueList, slotName string, patter
}
}

func (p *PyACTR) outputStatement(s *actr.Statement) {
func (p *PyACTR) outputStatement(production *actr.Production, s *actr.Statement) {
if s.Set != nil {
buffer := s.Set.Buffer
bufferName := buffer.GetName()
Expand Down Expand Up @@ -367,27 +420,31 @@ func (p *PyACTR) outputStatement(s *actr.Statement) {
p.Writeln("\t+retrieval>")
p.outputPattern(s.Recall.Pattern, 2)
} else if s.Print != nil {
// Using "goal" here is arbitrary because of the way we monkey patch the python code.
// Our "print_text" statement handles its own formatting and lookup.
p.Writeln("\t!goal>")
values := framework.PythonValuesToStrings(s.Print.Values, true)
p.Writeln("\t\tprint_text %s", strings.Join(values, " "))

str := make([]string, len(*s.Print.Values))

for index, val := range *s.Print.Values {
if val.Var != nil {
varIndex := production.VarIndexMap[*val.Var]
str[index] = fmt.Sprintf("%s.%s", varIndex.Buffer.GetName(), varIndex.SlotName)
} else if val.Str != nil {
str[index] = fmt.Sprintf("'%s'", *val.Str)
} else if val.Number != nil {
str[index] = *val.Number
}
}

p.Writeln("\t\tprint_text \"%s\"", strings.Join(str, ", "))
} else if s.Clear != nil {
for _, name := range s.Clear.BufferNames {
p.Writeln("\t~%s>", name)
}
}
}

func onlyStrings(values *[]*actr.Value) bool {
for _, v := range *values {
if v.Var != nil {
return false
}
// v.ID should not be possible because of validation
}

return true
}

// removeWarning will remove the long warning whenever pyactr is run without tkinter.
func removeWarning(text []byte) []byte {
str := string(text)
Expand Down

0 comments on commit f3594b9

Please sign in to comment.