From 7cc60ce65387454bd44a635093f29b6f1d7d90af Mon Sep 17 00:00:00 2001 From: DavePearce Date: Mon, 9 Dec 2024 15:10:05 +1300 Subject: [PATCH] Fixes for identifiers & lisp generation --- pkg/corset/ast.go | 150 ++++++++++++++++++++++++++++++++++++++----- pkg/corset/parser.go | 12 ++-- 2 files changed, 140 insertions(+), 22 deletions(-) diff --git a/pkg/corset/ast.go b/pkg/corset/ast.go index c82a9a8..873ceb5 100644 --- a/pkg/corset/ast.go +++ b/pkg/corset/ast.go @@ -1,6 +1,7 @@ package corset import ( + "fmt" "math/big" "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" @@ -197,7 +198,14 @@ func (p *DefColumns) Definitions() util.Iterator[SymbolDefinition] { // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. func (p *DefColumns) Lisp() sexp.SExp { - panic("got here") + list := sexp.EmptyList() + list.Append(sexp.NewSymbol("defcolumns")) + // Add lisp for each individual column + for _, c := range p.Columns { + list.Append(c.Lisp()) + } + // Done + return list } // DefColumn packages together those piece relevant to declaring an individual @@ -260,7 +268,28 @@ func (e *DefColumn) MustProve() bool { // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. func (e *DefColumn) Lisp() sexp.SExp { - panic("got here") + list := sexp.EmptyList() + list.Append(sexp.NewSymbol(e.name)) + // + if e.binding.dataType != nil { + datatype := e.binding.dataType.String() + if e.binding.mustProve { + datatype = fmt.Sprintf("%s@prove", datatype) + } + + list.Append(sexp.NewSymbol(datatype)) + } + // + if e.binding.multiplier != 1 { + list.Append(sexp.NewSymbol(":multiplier")) + list.Append(sexp.NewSymbol(fmt.Sprintf("%d", e.binding.multiplier))) + } + // + if list.Len() == 1 { + return list.Get(0) + } + // + return list } // ============================================================================ @@ -330,7 +359,9 @@ func (e *DefConstUnit) Name() string { // //nolint:revive func (p *DefConstUnit) Lisp() sexp.SExp { - panic("got here") + return sexp.NewList([]sexp.SExp{ + sexp.NewSymbol(p.name), + p.binding.value.Lisp()}) } // ============================================================================ @@ -388,7 +419,25 @@ func (p *DefConstraint) Dependencies() util.Iterator[Symbol] { // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. func (p *DefConstraint) Lisp() sexp.SExp { - panic("got here") + modifiers := sexp.EmptyList() + // domain + if p.Domain != nil { + domain := fmt.Sprintf("{%d}", *p.Domain) + // + modifiers.Append(sexp.NewSymbol(":domain")) + modifiers.Append(sexp.NewSymbol(domain)) + } + // + if p.Guard != nil { + modifiers.Append(sexp.NewSymbol(":guard")) + modifiers.Append(p.Guard.Lisp()) + } + // + return sexp.NewList([]sexp.SExp{ + sexp.NewSymbol("defconstraint"), + sexp.NewSymbol(p.Handle), + modifiers, + p.Constraint.Lisp()}) } // ============================================================================ @@ -540,7 +589,30 @@ func (p *DefPermutation) Dependencies() util.Iterator[Symbol] { // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. func (p *DefPermutation) Lisp() sexp.SExp { - panic("got here") + targets := make([]sexp.SExp, len(p.Targets)) + sources := make([]sexp.SExp, len(p.Sources)) + // Targets + for i, t := range p.Targets { + targets[i] = t.Lisp() + } + // Sources + for i, t := range p.Sources { + var sign string + if p.Signs[i] { + sign = "+" + } else { + sign = "-" + } + // + sources[i] = sexp.NewList([]sexp.SExp{ + sexp.NewSymbol(sign), + t.Lisp()}) + } + // + return sexp.NewList([]sexp.SExp{ + sexp.NewSymbol("defpermutation"), + sexp.NewList(targets), + sexp.NewList(sources)}) } // ============================================================================ @@ -757,7 +829,7 @@ func (e *Add) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Add) Lisp() sexp.SExp { - panic("todo") + return ListOfExpressions(sexp.NewSymbol("+"), e.Args) } // Substitute all variables (such as for function parameters) arising in @@ -853,7 +925,10 @@ func (e *Exp) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Exp) Lisp() sexp.SExp { - panic("todo") + return sexp.NewList([]sexp.SExp{ + sexp.NewSymbol("^"), + e.Arg.Lisp(), + e.Pow.Lisp()}) } // Substitute all variables (such as for function parameters) arising in @@ -915,7 +990,16 @@ func (e *IfZero) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *IfZero) Lisp() sexp.SExp { - panic("todo") + if e.FalseBranch != nil { + return sexp.NewList([]sexp.SExp{ + sexp.NewSymbol("if"), + e.TrueBranch.Lisp(), + e.FalseBranch.Lisp()}) + } + // + return sexp.NewList([]sexp.SExp{ + sexp.NewSymbol("if"), + e.TrueBranch.Lisp()}) } // Substitute all variables (such as for function parameters) arising in @@ -1036,7 +1120,14 @@ func (e *Invoke) Multiplicity() uint { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Invoke) Lisp() sexp.SExp { - panic("todo") + var fn sexp.SExp + if e.module != nil { + fn = sexp.NewSymbol(fmt.Sprintf("%s.%s", *e.module, e.name)) + } else { + fn = sexp.NewSymbol(e.name) + } + + return ListOfExpressions(fn, e.args) } // Substitute all variables (such as for function parameters) arising in @@ -1083,7 +1174,7 @@ func (e *List) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *List) Lisp() sexp.SExp { - panic("todo") + return ListOfExpressions(sexp.NewSymbol("begin"), e.Args) } // Substitute all variables (such as for function parameters) arising in @@ -1127,7 +1218,7 @@ func (e *Mul) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Mul) Lisp() sexp.SExp { - panic("todo") + return ListOfExpressions(sexp.NewSymbol("*"), e.Args) } // Substitute all variables (such as for function parameters) arising in @@ -1152,7 +1243,8 @@ type Normalise struct{ Arg Expr } // AsConstant attempts to evaluate this expression as a constant (signed) value. // If this expression is not constant, then nil is returned. func (e *Normalise) AsConstant() *big.Int { - panic("todo") + // FIXME: we could do better here. + return nil } // Multiplicity determines the number of values that evaluating this expression @@ -1171,7 +1263,9 @@ func (e *Normalise) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Normalise) Lisp() sexp.SExp { - panic("todo") + return sexp.NewList([]sexp.SExp{ + sexp.NewSymbol("~"), + e.Arg.Lisp()}) } // Substitute all variables (such as for function parameters) arising in @@ -1215,7 +1309,7 @@ func (e *Sub) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Sub) Lisp() sexp.SExp { - panic("todo") + return ListOfExpressions(sexp.NewSymbol("-"), e.Args) } // Substitute all variables (such as for function parameters) arising in @@ -1269,7 +1363,10 @@ func (e *Shift) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Shift) Lisp() sexp.SExp { - panic("todo") + return sexp.NewList([]sexp.SExp{ + sexp.NewSymbol("shift"), + e.Arg.Lisp(), + e.Shift.Lisp()}) } // Substitute all variables (such as for function parameters) arising in @@ -1377,7 +1474,14 @@ func (e *VariableAccess) Context() Context { // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed.a func (e *VariableAccess) Lisp() sexp.SExp { - panic("todo") + var name string + if e.module != nil { + name = fmt.Sprintf("%s.%s", *e.module, e.name) + } else { + name = e.name + } + // + return sexp.NewSymbol(name) } // Substitute all variables (such as for function parameters) arising in @@ -1450,6 +1554,20 @@ func DependenciesOfExpressions(exprs []Expr) []Symbol { return deps } +// ListOfExpressions converts an array of one or more expressions into a list of +// corresponding lisp expressions. +func ListOfExpressions(head sexp.SExp, exprs []Expr) *sexp.List { + lisps := make([]sexp.SExp, len(exprs)+1) + // Assign head + lisps[0] = head + // + for i, e := range exprs { + lisps[i+1] = e.Lisp() + } + // + return sexp.NewList(lisps) +} + // AsConstantOfExpressions attempts to fold one or more expressions across a // given operation (e.g. add, subtract, etc) to produce a constant value. If // any of the expressions are not themselves constant, then neither is the diff --git a/pkg/corset/parser.go b/pkg/corset/parser.go index 50f5432..545a0c7 100644 --- a/pkg/corset/parser.go +++ b/pkg/corset/parser.go @@ -330,7 +330,7 @@ func (p *Parser) parseColumnDeclaration(module string, e sexp.SExp) (*DefColumn, } // Parse a constant declaration -func (p *Parser) parseDefConst(elements []sexp.SExp) (*DefConst, []SyntaxError) { +func (p *Parser) parseDefConst(elements []sexp.SExp) (Declaration, []SyntaxError) { var ( errors []SyntaxError constants []*DefConstUnit @@ -370,7 +370,7 @@ func (p *Parser) parseDefConstUnit(name string, value sexp.SExp) (*DefConstUnit, } // Parse a vanishing declaration -func (p *Parser) parseDefConstraint(elements []sexp.SExp) (*DefConstraint, []SyntaxError) { +func (p *Parser) parseDefConstraint(elements []sexp.SExp) (Declaration, []SyntaxError) { var errors []SyntaxError // Initial sanity checks if !isIdentifier(elements[1]) { @@ -398,7 +398,7 @@ func (p *Parser) parseDefConstraint(elements []sexp.SExp) (*DefConstraint, []Syn } // Parse a interleaved declaration -func (p *Parser) parseDefInterleaved(module string, elements []sexp.SExp) (*DefInterleaved, *SyntaxError) { +func (p *Parser) parseDefInterleaved(module string, elements []sexp.SExp) (Declaration, *SyntaxError) { // Initial sanity checks if !isIdentifier(elements[1]) { return nil, p.translator.SyntaxError(elements[1], "malformed target column") @@ -465,7 +465,7 @@ func (p *Parser) parseDefLookup(elements []sexp.SExp) (*DefLookup, *SyntaxError) } // Parse a permutation declaration -func (p *Parser) parseDefPermutation(module string, elements []sexp.SExp) (*DefPermutation, *SyntaxError) { +func (p *Parser) parseDefPermutation(module string, elements []sexp.SExp) (Declaration, *SyntaxError) { var err *SyntaxError // sexpTargets := elements[1].AsList() @@ -544,7 +544,7 @@ func (p *Parser) parsePermutedColumnSign(sign *sexp.Symbol) (bool, *SyntaxError) } // Parse a property assertion -func (p *Parser) parseDefProperty(elements []sexp.SExp) (*DefProperty, *SyntaxError) { +func (p *Parser) parseDefProperty(elements []sexp.SExp) (Declaration, *SyntaxError) { // Initial sanity checks if !isIdentifier(elements[1]) { return nil, p.translator.SyntaxError(elements[1], "expected constraint handle") @@ -880,5 +880,5 @@ func isIdentifierStart(c rune) bool { } func isIdentifierMiddle(c rune) bool { - return unicode.IsDigit(c) || isIdentifierStart(c) + return unicode.IsDigit(c) || isIdentifierStart(c) || c == '-' }