From b52abec9c809236365cdd360620e86d4ad796f6a Mon Sep 17 00:00:00 2001 From: DavePearce Date: Wed, 6 Nov 2024 17:14:25 +1300 Subject: [PATCH] Initial compiler outline This puts in place an initial outline of the compiler, though there remains quite a lot of work to be done. --- .golangci.yml | 2 - cmd/testgen/main.go | 6 +- pkg/cmd/check.go | 2 +- pkg/cmd/debug.go | 2 +- pkg/cmd/test.go | 2 +- pkg/cmd/util.go | 86 +++-- pkg/corset/ast.go | 367 ++++++++++++++++++ pkg/corset/compiler.go | 74 ++++ pkg/corset/parser.go | 481 +++++++++++++++++++++++ pkg/corset/resolver.go | 1 + pkg/corset/translator.go | 37 ++ pkg/hir/environment.go | 116 ------ pkg/hir/eval.go | 6 + pkg/hir/expr.go | 41 ++ pkg/hir/lisp.go | 15 + pkg/hir/lower.go | 7 + pkg/hir/macro.go | 18 + pkg/hir/parser.go | 745 ------------------------------------ pkg/hir/schema.go | 17 + pkg/sexp/error.go | 34 -- pkg/sexp/parser.go | 58 +-- pkg/sexp/sexp_test.go | 7 +- pkg/sexp/source_file.go | 103 +++++ pkg/sexp/source_map.go | 24 ++ pkg/sexp/translator.go | 86 +++-- pkg/test/ir_test.go | 25 +- pkg/util/maps.go | 11 + testdata/purefun_01.accepts | 6 + testdata/purefun_01.lisp | 3 + testdata/purefun_01.rejects | 5 + testdata/purefun_02.accepts | 15 + testdata/purefun_02.lisp | 3 + testdata/purefun_02.rejects | 35 ++ 33 files changed, 1422 insertions(+), 1018 deletions(-) create mode 100644 pkg/corset/ast.go create mode 100644 pkg/corset/compiler.go create mode 100644 pkg/corset/parser.go create mode 100644 pkg/corset/resolver.go create mode 100644 pkg/corset/translator.go delete mode 100644 pkg/hir/environment.go create mode 100644 pkg/hir/macro.go delete mode 100644 pkg/hir/parser.go delete mode 100644 pkg/sexp/error.go create mode 100644 pkg/sexp/source_file.go create mode 100644 pkg/util/maps.go create mode 100644 testdata/purefun_01.accepts create mode 100644 testdata/purefun_01.lisp create mode 100644 testdata/purefun_01.rejects create mode 100644 testdata/purefun_02.accepts create mode 100644 testdata/purefun_02.lisp create mode 100644 testdata/purefun_02.rejects diff --git a/.golangci.yml b/.golangci.yml index 972c326..16c2f6f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -54,8 +54,6 @@ linters-settings: severity: warning confidence: 0.8 rules: - - name: indent-error-flow - severity: warning - name: errorf severity: warning - name: context-as-argument diff --git a/cmd/testgen/main.go b/cmd/testgen/main.go index b469d8b..db4d362 100644 --- a/cmd/testgen/main.go +++ b/cmd/testgen/main.go @@ -8,8 +8,10 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" cmdutil "github.com/consensys/go-corset/pkg/cmd" + "github.com/consensys/go-corset/pkg/corset" "github.com/consensys/go-corset/pkg/hir" sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/sexp" tr "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/trace/json" "github.com/consensys/go-corset/pkg/util" @@ -185,8 +187,10 @@ func readSchemaFile(filename string) *hir.Schema { fmt.Println(err) os.Exit(1) } + // Package up as source file + srcfile := sexp.NewSourceFile(filename, bytes) // Attempt to parse schema - schema, err2 := hir.ParseSchemaString(string(bytes)) + schema, err2 := corset.CompileSourceFile(srcfile) // Check whether parsed successfully or not if err2 == nil { // Ok diff --git a/pkg/cmd/check.go b/pkg/cmd/check.go index ea2f4f2..6e609cb 100644 --- a/pkg/cmd/check.go +++ b/pkg/cmd/check.go @@ -57,7 +57,7 @@ var checkCmd = &cobra.Command{ // stats := util.NewPerfStats() // Parse constraints - hirSchema = readSchemaFile(args[1]) + hirSchema = readSchema(args[1:]) // stats.Log("Reading constraints file") // Parse trace file diff --git a/pkg/cmd/debug.go b/pkg/cmd/debug.go index 2e98c7f..cd8c4d3 100644 --- a/pkg/cmd/debug.go +++ b/pkg/cmd/debug.go @@ -29,7 +29,7 @@ var debugCmd = &cobra.Command{ air := GetFlag(cmd, "air") stats := GetFlag(cmd, "stats") // Parse constraints - hirSchema := readSchemaFile(args[0]) + hirSchema := readSchema(args) // Print constraints if stats { diff --git a/pkg/cmd/test.go b/pkg/cmd/test.go index 3bcd5ea..c483e8a 100644 --- a/pkg/cmd/test.go +++ b/pkg/cmd/test.go @@ -56,7 +56,7 @@ var testCmd = &cobra.Command{ // stats := util.NewPerfStats() // Parse constraints - hirSchema = readSchemaFile(args[0]) + hirSchema = readSchema(args) // stats.Log("Reading constraints file") // diff --git a/pkg/cmd/util.go b/pkg/cmd/util.go index 62697d2..8e6bb18 100644 --- a/pkg/cmd/util.go +++ b/pkg/cmd/util.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/consensys/go-corset/pkg/binfile" + "github.com/consensys/go-corset/pkg/corset" "github.com/consensys/go-corset/pkg/hir" "github.com/consensys/go-corset/pkg/sexp" "github.com/consensys/go-corset/pkg/trace" @@ -133,54 +134,79 @@ func readTraceFile(filename string) []trace.RawColumn { return nil } -// Parse a constraints schema file using a parser based on the extension of the -// filename. -func readSchemaFile(filename string) *hir.Schema { +func readSchema(filenames []string) *hir.Schema { + if len(filenames) == 0 { + fmt.Println("source or binary constraint(s) file required.") + os.Exit(5) + } else if len(filenames) == 1 && path.Ext(filenames[0]) == "bin" { + // Single (binary) file supplied + return readBinaryFile(filenames[0]) + } + // Must be source files + return readSourceFiles(filenames) +} + +// Read a "bin" file. +func readBinaryFile(filename string) *hir.Schema { var schema *hir.Schema // Read schema file bytes, err := os.ReadFile(filename) // Handle errors if err == nil { - // Check file extension - ext := path.Ext(filename) - // - switch ext { - case ".lisp": - // Parse bytes into an S-Expression - schema, err = hir.ParseSchemaString(string(bytes)) - if err == nil { - return schema - } - case ".bin": - schema, err = binfile.HirSchemaFromJson(bytes) - if err == nil { - return schema - } - default: - err = fmt.Errorf("Unknown schema file format: %s\n", ext) + // Read the binary file + schema, err = binfile.HirSchemaFromJson(bytes) + if err == nil { + return schema } } - // Handle error - if e, ok := err.(*sexp.SyntaxError); ok { - printSyntaxError(filename, e, string(bytes)) - } else { - fmt.Println(err) - } - + // Handle error & exit + fmt.Println(err) os.Exit(2) // unreachable return nil } +// Parse a set of source files and compile them into a single schema. This can +// result, for example, in a syntax error, etc. +func readSourceFiles(filenames []string) *hir.Schema { + srcfiles := make([]*sexp.SourceFile, len(filenames)) + // Read each file + for i, n := range filenames { + // Read source file + bytes, err := os.ReadFile(n) + // Sanity check for errors + if err != nil { + fmt.Println(err) + os.Exit(3) + } + // + srcfiles[i] = sexp.NewSourceFile(n, bytes) + } + // Parse and compile source files + schema, errs := corset.CompileSourceFiles(srcfiles) + // Check for any errors + if errs == nil { + return schema + } + // Report errors + for _, err := range errs { + printSyntaxError(&err) + } + // Fail + os.Exit(4) + // unreachable + return nil +} + // Print a syntax error with appropriate highlighting. -func printSyntaxError(filename string, err *sexp.SyntaxError, text string) { +func printSyntaxError(err *sexp.SyntaxError) { span := err.Span() // Construct empty source map in order to determine enclosing line. - srcmap := sexp.NewSourceMap[sexp.SExp]([]rune(text)) + srcmap := sexp.NewSourceMap[sexp.SExp](err.SourceFile().Contents()) // line := srcmap.FindFirstEnclosingLine(span) // Print error + line number - fmt.Printf("%s:%d: %s\n", filename, line.Number(), err.Message()) + fmt.Printf("%s:%d: %s\n", err.SourceFile().Filename(), line.Number(), err.Message()) // Print separator line fmt.Println() // Print line diff --git a/pkg/corset/ast.go b/pkg/corset/ast.go new file mode 100644 index 0000000..d0776a0 --- /dev/null +++ b/pkg/corset/ast.go @@ -0,0 +1,367 @@ +package corset + +import ( + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/sexp" + "github.com/consensys/go-corset/pkg/trace" +) + +// Circuit represents the root of the Abstract Syntax Tree. This is also +// referred to as the "prelude". All modules are contained within the root, and +// declarations can also be declared here as well. +type Circuit struct { + Modules []Module + Declarations []Declaration +} + +// Module represents a top-level module declaration. This corresponds to a +// table in the final constraint set. +type Module struct { + Name string + Declarations []Declaration +} + +// Node provides common functionality across all elements of the Abstract Syntax +// Tree. For example, it ensures every element can converted back into Lisp +// form for debugging. Furthermore, it provides a reference point for +// constructing a suitable source map for reporting syntax errors. +type Node interface { + // Convert this node into its lisp representation. This is primarily used + // for debugging purposes. + Lisp() sexp.SExp +} + +// Declaration represents a top-level declaration in a Corset source file (e.g. +// defconstraint, defcolumns, etc). +type Declaration interface { + Node + Resolve() +} + +// DefColumns captures a set of one or more columns being declared. +type DefColumns struct { + Columns []DefColumn +} + +// Resolve something. +func (p *DefColumns) Resolve() { + panic("got here") +} + +// Lisp converts this node into its lisp representation. This is primarily used +// for debugging purposes. +func (p *DefColumns) Lisp() sexp.SExp { + panic("got here") +} + +// DefColumn packages together those piece relevant to declaring an individual +// column, such its name and type. +type DefColumn struct { + Name string + DataType sc.Type +} + +// Lisp converts this node into its lisp representation. This is primarily used +// for debugging purposes. +func (p *DefColumn) Lisp() sexp.SExp { + panic("got here") +} + +// DefConstraint represents a vanishing constraint, which is either "local" or +// "global". A local constraint applies either to the first or last rows, +// whilst a global constraint applies to all rows. For a constraint to hold, +// its expression must evaluate to zero for the rows on which it is active. A +// constraint may also have a "guard" which is an expression that must evaluate +// to a non-zero value for the constraint to be considered active. The +// expression for a constraint must have a single context. That is, it can only +// be applied to columns within the same module (i.e. to ensure they have the +// same height). Furthermore, within a given module, we require that all +// columns accessed by the constraint have the same length multiplier. +type DefConstraint struct { + // Unique handle given to this constraint. This is primarily useful for + // debugging (i.e. so we know which constaint failed, etc). + Handle string + // Domain of this constraint, where nil indicates a global constraint. + // Otherwise, a given value indicates a single row on which this constraint + // should apply (where negative values are taken from the end, meaning that + // -1 represents the last row of a given module). + Domain *int + // A selector which determines for which rows this constraint is active. + // Specifically, when the expression evaluates to a non-zero value then the + // constraint is active; otherwiser, its inactive. Nil is permitted to + // indicate no guard is present. + Guard Expr + // The constraint itself which (when active) should evaluate to zero for the + // relevant set of rows. + Constraint Expr +} + +// Resolve something. +func (p *DefConstraint) Resolve() { + panic("got here") +} + +// Lisp converts this node into its lisp representation. This is primarily used +// for debugging purposes. +func (p *DefConstraint) Lisp() sexp.SExp { + panic("got here") +} + +// DefLookup represents a lookup constraint between a set N of source +// expressions and a set of N target expressions. The source expressions must +// have a single context (i.e. all be in the same module) and likewise for the +// target expressions (though the source and target contexts can differ). The +// constraint can be viewed as a "subset constraint". Let the set of "source +// tuples" be those obtained by evaluating the source expressions over all rows +// in the source context, and likewise the "target tuples" those for the target +// expressions in the target context. Then the lookup constraint holds if the +// set of source tuples is a subset of the target tuples. This does not need to +// be a strict subset, so the two sets can be identical. Furthermore, these are +// not treated as multi-sets, hence the number of occurrences of a given tuple +// is not relevant. +type DefLookup struct { +} + +// DefPermutation represents a (lexicographically sorted) permutation of a set +// of source columns in a given source context, manifested as an assignment to a +// corresponding set of target columns. The sort direction for each of the +// source columns can be specified as increasing or decreasing. +type DefPermutation struct { +} + +// DefFun represents defines a (possibly pure) "function" (which, in actuality, +// is more like a macro). Specifically, whenever an invocation of this function +// is encountered we can imagine that, in the final constraint set, the body of +// the function is inlined at the point of the call. A pure function is not +// permitted to access any columns in scope (i.e. it can only manipulate its +// parameters). In contrast, an impure function can access those columns +// defined within its enclosing context. +type DefFun struct { +} + +// Expr represents an arbitrary expression over the columns of a given context +// (or the parameters of an enclosing function). Such expressions are pitched +// at a higher-level than those of the underlying constraint system. For +// example, they can contain conditionals (i.e. if expressions) and +// normalisations, etc. During the lowering process down to the underlying +// constraints level (AIR), such expressions are "compiled out" using various +// techniques (such as introducing computed columns where necessary). +type Expr interface { + Node + // Resolve resolves this expression in a given scope and constructs a fully + // resolved HIR expression. + Resolve() +} + +// ============================================================================ +// Addition +// ============================================================================ + +// Add represents the sum over zero or more expressions. +type Add struct{ Args []Expr } + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *Add) Resolve() { + for _, arg := range e.Args { + arg.Resolve() + } +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *Add) Lisp() sexp.SExp { + panic("todo") +} + +// ============================================================================ +// Constants +// ============================================================================ + +// Constant represents a constant value within an expression. +type Constant struct{ Val fr.Element } + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *Constant) Resolve() { + // Nothing to resolve! +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *Constant) Lisp() sexp.SExp { + return sexp.NewSymbol(e.Val.String()) +} + +// ============================================================================ +// Exponentiation +// ============================================================================ + +// Exp represents the a given value taken to a power. +type Exp struct { + Arg Expr + Pow uint64 +} + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *Exp) Resolve() { + e.Arg.Resolve() +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *Exp) Lisp() sexp.SExp { + panic("todo") +} + +// ============================================================================ +// IfZero +// ============================================================================ + +// IfZero returns the (optional) true branch when the condition evaluates to zero, and +// the (optional false branch otherwise. +type IfZero struct { + // Elements contained within this list. + Condition Expr + // True branch (optional). + TrueBranch Expr + // False branch (optional). + FalseBranch Expr +} + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *IfZero) Resolve() { + e.Condition.Resolve() + e.TrueBranch.Resolve() + e.FalseBranch.Resolve() +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *IfZero) Lisp() sexp.SExp { + panic("todo") +} + +// ============================================================================ +// List +// ============================================================================ + +// List represents a block of zero or more expressions. +type List struct{ Args []Expr } + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *List) Resolve() { + for _, arg := range e.Args { + arg.Resolve() + } +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *List) Lisp() sexp.SExp { + panic("todo") +} + +// ============================================================================ +// Multiplication +// ============================================================================ + +// Mul represents the product over zero or more expressions. +type Mul struct{ Args []Expr } + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *Mul) Resolve() { + for _, arg := range e.Args { + arg.Resolve() + } +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *Mul) Lisp() sexp.SExp { + panic("todo") +} + +// ============================================================================ +// Normalise +// ============================================================================ + +// Normalise reduces the value of an expression to either zero (if it was zero) +// or one (otherwise). +type Normalise struct{ Arg Expr } + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *Normalise) Resolve() { + e.Arg.Resolve() +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *Normalise) Lisp() sexp.SExp { + panic("todo") +} + +// ============================================================================ +// Subtraction +// ============================================================================ + +// Sub represents the subtraction over zero or more expressions. +type Sub struct{ Args []Expr } + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *Sub) Resolve() { + for _, arg := range e.Args { + arg.Resolve() + } +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *Sub) Lisp() sexp.SExp { + panic("todo") +} + +// ============================================================================ +// VariableAccess +// ============================================================================ + +// VariableAccess represents reading the value of a given local variable (such +// as a function parameter). +type VariableAccess struct { + Module string + Name string + Shift int + Binding *Binder +} + +// Resolve accesses in this expression as either variable, column or macro +// accesses. +func (e *VariableAccess) Resolve() { + panic("todo") +} + +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *VariableAccess) Lisp() sexp.SExp { + panic("todo") +} + +// Binder provides additional information determined during the resolution +// phase. Specifically, it clarifies the meaning of a given variable name used +// within an expression (i.e. is it a column access, a local variable access, +// etc). +type Binder struct { + // Identifies whether this is a column access, or a variable access. + Column bool + // For a column access, this identifies the enclosing context. + Context trace.Context + // Identifies the variable or column index (as appropriate). + Index uint +} diff --git a/pkg/corset/compiler.go b/pkg/corset/compiler.go new file mode 100644 index 0000000..b2ef2e5 --- /dev/null +++ b/pkg/corset/compiler.go @@ -0,0 +1,74 @@ +package corset + +import ( + "github.com/consensys/go-corset/pkg/hir" + "github.com/consensys/go-corset/pkg/sexp" +) + +// SyntaxError defines the kind of errors that can be reported by this compiler. +// Syntax errors are always associated with some line in one of the original +// source files. For simplicity, we reuse existing notion of syntax error from +// the S-Expression library. +type SyntaxError = sexp.SyntaxError + +// CompileSourceFiles compiles one or more source files into a schema. This +// process can fail if the source files are mal-formed, or contain syntax errors +// or other forms of error (e.g. type errors). +func CompileSourceFiles(srcfiles []*sexp.SourceFile) (*hir.Schema, []SyntaxError) { + circuit, srcmap, errs := ParseSourceFiles(srcfiles) + // Check for parsing errors + if errs != nil { + return nil, errs + } + // Compile each module into the schema + return NewCompiler(circuit, srcmap).Compile() +} + +// CompileSourceFile compiles exactly one source file into a schema. This is +// really helper function for e.g. the testing environment. This process can +// fail if the source file is mal-formed, or contains syntax errors or other +// forms of error (e.g. type errors). +func CompileSourceFile(srcfile *sexp.SourceFile) (*hir.Schema, []SyntaxError) { + schema, errs := CompileSourceFiles([]*sexp.SourceFile{srcfile}) + // Check for errors + if errs != nil { + return nil, errs + } + // + return schema, nil +} + +// Compiler packages up everything needed to compile a given set of module +// definitions down into an HIR schema. Observe that the compiler may fail if +// the modules definitions are malformed in some way (e.g. fail type checking). +type Compiler struct { + // A high-level definition of a Corset circuit. + circuit Circuit + // Source maps nodes in the circuit back to the spans in their original + // source files. This is needed when reporting syntax errors to generate + // highlights of the relevant source line(s) in question. + srcmap *sexp.SourceMaps[Node] +} + +// NewCompiler constructs a new compiler for a given set of modules. +func NewCompiler(circuit Circuit, srcmaps *sexp.SourceMaps[Node]) *Compiler { + return &Compiler{circuit, srcmaps} +} + +// Compile is the top-level function for the corset compiler which actually +// compiles the given modules down into a schema. This can fail in a variety of +// ways if the given modules are malformed in some way. For example, if some +// expression refers to a non-existent module or column, or is not well-typed, +// etc. +func (p *Compiler) Compile() (*hir.Schema, []SyntaxError) { + schema := hir.EmptySchema() + // Allocate columns? + // + // Resolve variables (via nested scopes) + // Check constraint contexts (e.g. for constraints, lookups, etc) + // Type check constraints + // Finally, translate everything and add it to the schema. + errors := translateCircuit(&p.circuit, schema) + // Done + return schema, errors +} diff --git a/pkg/corset/parser.go b/pkg/corset/parser.go new file mode 100644 index 0000000..908ab7c --- /dev/null +++ b/pkg/corset/parser.go @@ -0,0 +1,481 @@ +package corset + +import ( + "errors" + "math/big" + "sort" + "strconv" + "strings" + "unicode" + + "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/sexp" +) + +// =================================================================== +// Public +// =================================================================== + +// ParseSourceFiles parses zero or more source files producing zero or more +// modules. Observe that, since a given module can be spread over multiple +// files, there can be far few modules created than there are source files. This +// function does more than just parse the individual files, because it +// additional combines all fragments of the same module together into one place. +// Thus, you should never expect to see duplicate module names in the returned +// array. +func ParseSourceFiles(files []*sexp.SourceFile) (Circuit, *sexp.SourceMaps[Node], []SyntaxError) { + var circuit Circuit + // (for now) at most one error per source file is supported. + var errors []SyntaxError = make([]SyntaxError, len(files)) + // Construct an initially empty source map + srcmaps := sexp.NewSourceMaps[Node]() + // num_errs counts the number of errors reported + var num_errs uint + // Contents map holds the combined fragments of each module. + contents := make(map[string]Module, 0) + // Names identifies the names of each unique module. + names := make([]string, 0) + // + for i, file := range files { + c, srcmap, err := ParseSourceFile(file) + // Handle errors + if err != nil { + num_errs++ + // Report any errors encountered + errors[i] = *err + } else { + // Combine source maps + srcmaps.Join(srcmap) + } + // Update top-level declarations + circuit.Declarations = append(circuit.Declarations, c.Declarations...) + // Allocate any module fragments + for _, m := range c.Modules { + if om, ok := contents[m.Name]; !ok { + contents[m.Name] = m + names = append(names, m.Name) + } else { + om.Declarations = append(om.Declarations, m.Declarations...) + contents[m.Name] = om + } + } + } + // Bring all fragmenmts together + circuit.Modules = make([]Module, len(names)) + // Sort module names to ensure that compilation is always deterministic. + sort.Strings(names) + // Finalise every module + for i, n := range names { + // Assume this cannot fail as every module in names has been assigned at + // least one fragment. + circuit.Modules[i] = contents[n] + } + // Done + if num_errs > 0 { + return circuit, srcmaps, errors + } + // no errors + return circuit, srcmaps, nil +} + +// ParseSourceFile parses the contents of a single lisp file into one or more +// modules. Observe that every lisp file starts in the "prelude" or "root" +// module, and may declare items for additional modules as necessary. +func ParseSourceFile(srcfile *sexp.SourceFile) (Circuit, *sexp.SourceMap[Node], *SyntaxError) { + var circuit Circuit + // Parse bytes into an S-Expression + terms, srcmap, err := srcfile.ParseAll() + // Check test file parsed ok + if err != nil { + return circuit, nil, err + } + // Construct parser for corset syntax + p := NewParser(srcfile, srcmap) + // Parse whatever is declared at the beginning of the file before the first + // module declaration. These declarations form part of the "prelude". + if circuit.Declarations, terms, err = p.parseModuleContents(terms); err != nil { + return circuit, nil, err + } + // Continue parsing string until nothing remains. + for len(terms) != 0 { + var ( + name string + decls []Declaration + ) + // Extract module name + if name, err = p.parseModuleStart(terms[0]); err != nil { + return circuit, nil, err + } + // Parse module contents + if decls, terms, err = p.parseModuleContents(terms[1:]); err != nil { + return circuit, nil, err + } else if len(decls) != 0 { + circuit.Modules = append(circuit.Modules, Module{name, decls}) + } + } + // Done + return circuit, p.nodemap, nil +} + +// Parser implements a simple parser for the Corset language. The parser itself +// is relatively simplistic and simply packages up the relevant lisp constructs +// into their corresponding AST forms. This can fail in various ways, such as +// e.g. a "defconstraint" not having exactly three arguments, etc. However, the +// parser does not attempt to perform more complex forms of validation (e.g. +// ensuring that expressions are well-typed, etc) --- that is left up to the +// compiler. +type Parser struct { + // Translator used for recursive expressions. + translator *sexp.Translator[Expr] + // Mapping from constructed S-Expressions to their spans in the original text. + nodemap *sexp.SourceMap[Node] +} + +// NewParser constructs a new parser using a given mapping from S-Expressions to +// spans in the underlying source file. +func NewParser(srcfile *sexp.SourceFile, srcmap *sexp.SourceMap[sexp.SExp]) *Parser { + p := sexp.NewTranslator[Expr](srcfile, srcmap) + // Construct (initially empty) node map + nodemap := sexp.NewSourceMap[Node](srcmap.Text()) + // Construct parser + parser := &Parser{p, nodemap} + // Configure expression translator + p.AddSymbolRule(constantParserRule) + p.AddSymbolRule(varAccessParserRule) + p.AddBinaryRule("shift", shiftParserRule) + p.AddRecursiveRule("+", addParserRule) + p.AddRecursiveRule("-", subParserRule) + p.AddRecursiveRule("*", mulParserRule) + p.AddRecursiveRule("~", normParserRule) + p.AddRecursiveRule("^", powParserRule) + p.AddRecursiveRule("if", ifParserRule) + p.AddRecursiveRule("begin", beginParserRule) + // + return parser +} + +// Extract all declarations associated with a given module and package them up. +func (p *Parser) parseModuleContents(terms []sexp.SExp) ([]Declaration, []sexp.SExp, *SyntaxError) { + // + decls := make([]Declaration, 0) + // + for i, s := range terms { + e, ok := s.(*sexp.List) + // Check for error + if !ok { + return nil, nil, p.translator.SyntaxError(s, "unexpected or malformed declaration") + } + // Check for end-of-module + if e.MatchSymbols(2, "module") { + return decls, terms[i:], nil + } + // Parse the declaration + if decl, err := p.parseDeclaration(e); err != nil { + return nil, nil, err + } else { + // Continue accumulating declarations for this module. + decls = append(decls, decl) + } + } + // End-of-file signals end-of-module. + return decls, make([]sexp.SExp, 0), nil +} + +// Parse a module declaration of the form "(module m1)" which indicates the +// start of module m1. +func (p *Parser) parseModuleStart(s sexp.SExp) (string, *SyntaxError) { + l, ok := s.(*sexp.List) + // Check for error + if !ok { + return "", p.translator.SyntaxError(s, "unexpected or malformed declaration") + } + // Sanity check declaration + if len(l.Elements) > 2 { + return "", p.translator.SyntaxError(l, "malformed module declaration") + } + // Extract column name + name := l.Elements[1].AsSymbol().Value + // + return name, nil +} + +func (p *Parser) parseDeclaration(s *sexp.List) (Declaration, *SyntaxError) { + if s.MatchSymbols(1, "defcolumns") { + return p.parseColumnDeclarations(s) + } else if s.Len() == 4 && s.MatchSymbols(2, "defconstraint") { + return p.parseConstraintDeclaration(s.Elements) + } + /* + else if e.Len() == 3 && e.MatchSymbols(2, "assert") { + return p.parseAssertionDeclaration(env, e.Elements) + } else if e.Len() == 3 && e.MatchSymbols(1, "defpermutation") { + return p.parsePermutationDeclaration(env, e) + } else if e.Len() == 4 && e.MatchSymbols(1, "deflookup") { + return p.parseLookupDeclaration(env, e) + } else if e.Len() == 3 && e.MatchSymbols(1, "definterleaved") { + return p.parseInterleavingDeclaration(env, e) + } else if e.Len() == 3 && e.MatchSymbols(1, "definrange") { + return p.parseRangeDeclaration(env, e) + } else if e.Len() == 3 && e.MatchSymbols(1, "defpurefun") { + return p.parsePureFunDeclaration(env, e) + } */ + return nil, p.translator.SyntaxError(s, "malformed declaration") +} + +// Parse a column declaration +func (p *Parser) parseColumnDeclarations(l *sexp.List) (*DefColumns, *SyntaxError) { + columns := make([]DefColumn, l.Len()-1) + // Sanity check declaration + if len(l.Elements) == 1 { + return nil, p.translator.SyntaxError(l, "malformed column declaration") + } + // Process column declarations one by one. + for i := 1; i < len(l.Elements); i++ { + decl, err := p.parseColumnDeclaration(l.Elements[i]) + // Extract column name + if err != nil { + return nil, err + } + // Assign the declaration + columns[i-1] = decl + } + // Done + return &DefColumns{columns}, nil +} + +func (p *Parser) parseColumnDeclaration(e sexp.SExp) (DefColumn, *SyntaxError) { + var defcolumn DefColumn + // Default to field type + defcolumn.DataType = &sc.FieldType{} + // Check whether extended declaration or not. + if l := e.AsList(); l != nil { + // Check at least the name provided. + if len(l.Elements) == 0 { + return defcolumn, p.translator.SyntaxError(l, "empty column declaration") + } + // Column name is always first + defcolumn.Name = l.Elements[0].String(false) + // Parse type (if applicable) + if len(l.Elements) == 2 { + var err *SyntaxError + if defcolumn.DataType, err = p.parseType(l.Elements[1]); err != nil { + return defcolumn, err + } + } else if len(l.Elements) > 2 { + // For now. + return defcolumn, p.translator.SyntaxError(l, "unknown column declaration attributes") + } + } else { + defcolumn.Name = e.String(false) + } + // + return defcolumn, nil +} + +// Parse a vanishing declaration +func (p *Parser) parseConstraintDeclaration(elements []sexp.SExp) (*DefConstraint, *SyntaxError) { + // + handle := elements[1].AsSymbol().Value + // Vanishing constraints do not have global scope, hence qualified column + // accesses are not permitted. + domain, guard, err := p.parseConstraintAttributes(elements[2]) + // Check for error + if err != nil { + return nil, err + } + // Translate expression + expr, err := p.translator.Translate(elements[3]) + if err != nil { + return nil, err + } + // Done + return &DefConstraint{handle, domain, guard, expr}, nil +} + +func (p *Parser) parseConstraintAttributes(attributes sexp.SExp) (domain *int, guard Expr, err *SyntaxError) { + // Check attribute list is a list + if attributes.AsList() == nil { + return nil, nil, p.translator.SyntaxError(attributes, "expected attribute list") + } + // Deconstruct as list + attrs := attributes.AsList() + // Process each attribute in turn + for i := 0; i < attrs.Len(); i++ { + ith := attrs.Get(i) + // Check start of attribute + if ith.AsSymbol() == nil { + return nil, nil, p.translator.SyntaxError(ith, "malformed attribute") + } + // Check what we've got + switch ith.AsSymbol().Value { + case ":domain": + i++ + if domain, err = p.parseDomainAttribute(attrs.Get(i)); err != nil { + return nil, nil, err + } + case ":guard": + i++ + if guard, err = p.translator.Translate(attrs.Get(i)); err != nil { + return nil, nil, err + } + default: + return nil, nil, p.translator.SyntaxError(ith, "unknown attribute") + } + } + // Done + return domain, guard, nil +} + +func (p *Parser) parseDomainAttribute(attribute sexp.SExp) (domain *int, err *SyntaxError) { + if attribute.AsSet() == nil { + return nil, p.translator.SyntaxError(attribute, "malformed domain set") + } + // Sanity check + set := attribute.AsSet() + // Check all domain elements well-formed. + for i := 0; i < set.Len(); i++ { + ith := set.Get(i) + if ith.AsSymbol() == nil { + return nil, p.translator.SyntaxError(ith, "malformed domain") + } + } + // Currently, only support domains of size 1. + if set.Len() == 1 { + first, err := strconv.Atoi(set.Get(0).AsSymbol().Value) + // Check for parse error + if err != nil { + return nil, p.translator.SyntaxError(set.Get(0), "malformed domain element") + } + // Done + return &first, nil + } + // Fail + return nil, p.translator.SyntaxError(attribute, "multiple values not supported") +} + +func (p *Parser) parseType(term sexp.SExp) (sc.Type, *SyntaxError) { + symbol := term.AsSymbol() + if symbol == nil { + return nil, p.translator.SyntaxError(term, "malformed column") + } + // Access string of symbol + str := symbol.Value + if strings.HasPrefix(str, ":u") { + n, err := strconv.Atoi(str[2:]) + if err != nil { + return nil, p.translator.SyntaxError(symbol, err.Error()) + } + // Done + return sc.NewUintType(uint(n)), nil + } + // Error + return nil, p.translator.SyntaxError(symbol, "unknown type") +} + +func beginParserRule(_ string, args []Expr) (Expr, error) { + return &List{args}, nil +} + +func constantParserRule(symbol string) (Expr, bool, error) { + if symbol[0] >= '0' && symbol[0] < '9' { + var num fr.Element + // Attempt to parse + _, err := num.SetString(symbol) + // Check for errors + if err != nil { + return nil, true, err + } + // Done + return &Constant{Val: num}, true, nil + } + // Not applicable + return nil, false, nil +} + +func varAccessParserRule(col string) (Expr, bool, error) { + // Sanity check what we have + if !unicode.IsLetter(rune(col[0])) { + return nil, false, nil + } + // Handle qualified accesses (where permitted) + // Attempt to split column name into module / column pair. + split := strings.Split(col, ".") + if len(split) == 2 { + return &VariableAccess{split[0], split[1], 0, nil}, true, nil + } else if len(split) > 2 { + return nil, true, errors.New("malformed column access") + } + // Done + return &VariableAccess{"", col, 0, nil}, true, nil +} + +func addParserRule(_ string, args []Expr) (Expr, error) { + return &Add{args}, nil +} + +func subParserRule(_ string, args []Expr) (Expr, error) { + return &Sub{args}, nil +} + +func mulParserRule(_ string, args []Expr) (Expr, error) { + return &Mul{args}, nil +} + +func ifParserRule(_ string, args []Expr) (Expr, error) { + if len(args) == 2 { + return &IfZero{args[0], args[1], nil}, nil + } else if len(args) == 3 { + return &IfZero{args[0], args[1], args[2]}, nil + } + + return nil, errors.New("incorrect number of arguments") +} + +func shiftParserRule(col string, amt string) (Expr, error) { + n, err := strconv.Atoi(amt) + + if err != nil { + return nil, err + } + // Sanity check what we have + if !unicode.IsLetter(rune(col[0])) { + return nil, nil + } + // Handle qualified accesses (where appropriate) + split := strings.Split(col, ".") + if len(split) == 2 { + return &VariableAccess{split[0], split[1], n, nil}, nil + } else if len(split) > 2 { + return nil, errors.New("malformed column access") + } + // Done + return &VariableAccess{"", col, n, nil}, nil +} + +func powParserRule(_ string, args []Expr) (Expr, error) { + var k big.Int + + if len(args) != 2 { + return nil, errors.New("incorrect number of arguments") + } + + c, ok := args[1].(*Constant) + if !ok { + return nil, errors.New("expected constant power") + } else if !c.Val.IsUint64() { + return nil, errors.New("constant power too large") + } + // Convert power to uint64 + c.Val.BigInt(&k) + // Done + return &Exp{Arg: args[0], Pow: k.Uint64()}, nil +} + +func normParserRule(_ string, args []Expr) (Expr, error) { + if len(args) != 1 { + return nil, errors.New("incorrect number of arguments") + } + + return &Normalise{Arg: args[0]}, nil +} diff --git a/pkg/corset/resolver.go b/pkg/corset/resolver.go new file mode 100644 index 0000000..9998c5c --- /dev/null +++ b/pkg/corset/resolver.go @@ -0,0 +1 @@ +package corset diff --git a/pkg/corset/translator.go b/pkg/corset/translator.go new file mode 100644 index 0000000..13de05b --- /dev/null +++ b/pkg/corset/translator.go @@ -0,0 +1,37 @@ +package corset + +import ( + "github.com/consensys/go-corset/pkg/hir" +) + +// Translate the components of a Corset circuit and add them to the schema. By +// the time we get to this point, all malformed source files should have been +// rejected already and the translation should go through easily. Thus, whilst +// syntax errors can be returned here, this should never happen. The mechanism +// is supported, however, to simplify development of new features, etc. +func translateCircuit(circuit *Circuit, schema *hir.Schema) []SyntaxError { + panic("todo") +} + +// Translate a Corset declaration and add it to the schema. By the time we get +// to this point, all malformed source files should have been rejected already +// and the translation should go through easily. Thus, whilst syntax errors can +// be returned here, this should never happen. The mechanism is supported, +// however, to simplify development of new features, etc. +func translateDeclaration(decl Declaration, schema *hir.Schema) []SyntaxError { + if d, ok := decl.(*DefColumns); ok { + translateDefColumns(d, schema) + } else if d, ok := decl.(*DefConstraint); ok { + translateDefConstraint(d, schema) + } + // Error handling + panic("unknown declaration") +} + +func translateDefColumns(decl *DefColumns, schema *hir.Schema) { + panic("TODO") +} + +func translateDefConstraint(decl *DefConstraint, schema *hir.Schema) { + panic("TODO") +} diff --git a/pkg/hir/environment.go b/pkg/hir/environment.go deleted file mode 100644 index bdd0968..0000000 --- a/pkg/hir/environment.go +++ /dev/null @@ -1,116 +0,0 @@ -package hir - -import ( - "fmt" - - "github.com/consensys/go-corset/pkg/schema" - sc "github.com/consensys/go-corset/pkg/schema" - "github.com/consensys/go-corset/pkg/trace" -) - -// =================================================================== -// Environment -// =================================================================== - -// Identifies a specific column within the environment. -type columnRef struct { - module uint - column string -} - -// Environment maps module and column names to their (respective) module and -// column indices. The environment also keeps trace of which modules / columns -// are declared so we can sanity check them when they are referred to (e.g. in a -// constraint). -type Environment struct { - // Maps module names to their module indices. - modules map[string]uint - // Maps column references to their column indices. - columns map[columnRef]uint - // Schema being constructed - schema *Schema -} - -// EmptyEnvironment constructs an empty environment. -func EmptyEnvironment() *Environment { - modules := make(map[string]uint) - columns := make(map[columnRef]uint) - schema := EmptySchema() - // - return &Environment{modules, columns, schema} -} - -// RegisterModule registers a new module within this environment. Observe that -// this will panic if the module already exists. -func (p *Environment) RegisterModule(module string) trace.Context { - if p.HasModule(module) { - panic(fmt.Sprintf("module %s already exists", module)) - } - // Update schema - mid := p.schema.AddModule(module) - // Update cache - p.modules[module] = mid - // Done - return trace.NewContext(mid, 1) -} - -// AddDataColumn registers a new column within a given module. Observe that -// this will panic if the column already exists. -func (p *Environment) AddDataColumn(context trace.Context, column string, datatype sc.Type) uint { - if p.HasColumn(context, column) { - panic(fmt.Sprintf("column %d:%s already exists", context.Module(), column)) - } - // Update schema - p.schema.AddDataColumn(context, column, datatype) - // Update cache - cid := uint(len(p.columns)) - cref := columnRef{context.Module(), column} - p.columns[cref] = cid - // Done - return cid -} - -// AddAssignment appends a new assignment (i.e. set of computed columns) to be -// used during trace expansion for this schema. Computed columns are introduced -// by the process of lowering from HIR / MIR to AIR. -func (p *Environment) AddAssignment(decl schema.Assignment) { - // Update schema - index := p.schema.AddAssignment(decl) - // Update cache - for i := decl.Columns(); i.HasNext(); { - ith := i.Next() - cref := columnRef{ith.Context().Module(), ith.Name()} - p.columns[cref] = index - index++ - } -} - -// LookupModule determines the module index for a given named module, or return -// false if no such module exists. -func (p *Environment) LookupModule(module string) (trace.Context, bool) { - mid, ok := p.modules[module] - return trace.NewContext(mid, 1), ok -} - -// LookupColumn determines the column index for a given named column in a given -// module, or return false if no such column exists. -func (p *Environment) LookupColumn(context trace.Context, column string) (uint, bool) { - cref := columnRef{context.Module(), column} - cid, ok := p.columns[cref] - - return cid, ok -} - -// HasModule checks whether a given module exists, or not. -func (p *Environment) HasModule(module string) bool { - _, ok := p.LookupModule(module) - // Discard column index - return ok -} - -// HasColumn checks whether a given module has a given column, or not. -func (p *Environment) HasColumn(context trace.Context, column string) bool { - _, ok := p.LookupColumn(context, column) - // Discard column index - return ok -} diff --git a/pkg/hir/eval.go b/pkg/hir/eval.go index 4012f64..ec59ff4 100644 --- a/pkg/hir/eval.go +++ b/pkg/hir/eval.go @@ -15,6 +15,12 @@ func (e *ColumnAccess) EvalAllAt(k int, tr trace.Trace) []fr.Element { return []fr.Element{val} } +// EvalAllAt attempts to evaluate a variable access at a given row in a trace. +// However, at this time, that does not make sense. +func (e *VariableAccess) EvalAllAt(k int, tr trace.Trace) []fr.Element { + panic("unsupported operation") +} + // EvalAllAt evaluates a constant at a given row in a trace, which simply returns // that constant. func (e *Constant) EvalAllAt(k int, tr trace.Trace) []fr.Element { diff --git a/pkg/hir/expr.go b/pkg/hir/expr.go index 279ac93..7a5a02a 100644 --- a/pkg/hir/expr.go +++ b/pkg/hir/expr.go @@ -415,6 +415,47 @@ func (p *Normalise) RequiredCells(row int, tr trace.Trace) *util.AnySortedSet[tr // does not perform any form of simplification to determine this. func (p *Normalise) AsConstant() *fr.Element { return nil } +// ============================================================================ +// VariableAccess +// ============================================================================ + +// VariableAccess represents reading the value of a given local variable (such +// as a function parameter). +type VariableAccess struct { + Name string + Shift int +} + +// Bounds returns max shift in either the negative (left) or positive +// direction (right). +func (p *VariableAccess) Bounds() util.Bounds { + panic("variable accesses do not have bounds") +} + +// Context determines the evaluation context (i.e. enclosing module) for this +// expression. +func (p *VariableAccess) Context(schema sc.Schema) trace.Context { + panic("variable accesses do not have a context") +} + +// RequiredColumns returns the set of columns on which this term depends. +// That is, columns whose values may be accessed when evaluating this term +// on a given trace. +func (p *VariableAccess) RequiredColumns() *util.SortedSet[uint] { + panic("unsupported operation") +} + +// RequiredCells returns the set of trace cells on which this term depends. +// In this case, that is the empty set. +func (p *VariableAccess) RequiredCells(row int, tr trace.Trace) *util.AnySortedSet[trace.CellRef] { + panic("unsupported operation") +} + +// AsConstant determines whether or not this is a constant expression. If +// so, the constant is returned; otherwise, nil is returned. NOTE: this +// does not perform any form of simplification to determine this. +func (p *VariableAccess) AsConstant() *fr.Element { return nil } + // ============================================================================ // ColumnAccess // ============================================================================ diff --git a/pkg/hir/lisp.go b/pkg/hir/lisp.go index 23756c9..cb55745 100644 --- a/pkg/hir/lisp.go +++ b/pkg/hir/lisp.go @@ -23,6 +23,21 @@ func (e *ColumnAccess) Lisp(schema sc.Schema) sexp.SExp { return sexp.NewList([]sexp.SExp{sexp.NewSymbol("shift"), access, shift}) } +// Lisp converts this schema element into a simple S-Expression, for example +// so it can be printed. +func (e *VariableAccess) Lisp(schema sc.Schema) sexp.SExp { + access := sexp.NewSymbol(e.Name) + // Check whether shifted (or not) + if e.Shift == 0 { + // Not shifted + return access + } + // Shifted + shift := sexp.NewSymbol(fmt.Sprintf("%d", e.Shift)) + + return sexp.NewList([]sexp.SExp{sexp.NewSymbol("shift"), access, shift}) +} + // Lisp converts this schema element into a simple S-Expression, for example // so it can be printed. func (e *Constant) Lisp(schema sc.Schema) sexp.SExp { diff --git a/pkg/hir/lower.go b/pkg/hir/lower.go index 8df3efa..949baf9 100644 --- a/pkg/hir/lower.go +++ b/pkg/hir/lower.go @@ -113,6 +113,13 @@ func (e *ColumnAccess) LowerTo(schema *mir.Schema) []mir.Expr { return lowerTo(e, schema) } +// LowerTo lowers a variable access to the MIR level. This requires expanding +// the arguments, then lowering them. Furthermore, conditionals are "lifted" to +// the top. +func (e *VariableAccess) LowerTo(schema *mir.Schema) []mir.Expr { + return lowerTo(e, schema) +} + // LowerTo lowers an exponent expression to the MIR level. This requires expanding // the argument andn lowering it. Furthermore, conditionals are "lifted" to // the top. diff --git a/pkg/hir/macro.go b/pkg/hir/macro.go new file mode 100644 index 0000000..f7ad5e6 --- /dev/null +++ b/pkg/hir/macro.go @@ -0,0 +1,18 @@ +package hir + +// MacroDefinition represents something which can be called, and that will be +// inlined at the point of call. +type MacroDefinition struct { + // Enclosing module + module uint + // Name of the macro + name string + // Parameters of the macro + params []string + // Body of the macro + body Expr + // Indicates whether or not this macro is "pure". More specifically, pure + // macros can only refer to parameters (i.e. cannot access enclosing columns + // directly). + pure bool +} diff --git a/pkg/hir/parser.go b/pkg/hir/parser.go deleted file mode 100644 index 37eb363..0000000 --- a/pkg/hir/parser.go +++ /dev/null @@ -1,745 +0,0 @@ -package hir - -import ( - "errors" - "fmt" - "math/big" - "strconv" - "strings" - "unicode" - - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - sc "github.com/consensys/go-corset/pkg/schema" - "github.com/consensys/go-corset/pkg/schema/assignment" - "github.com/consensys/go-corset/pkg/sexp" - "github.com/consensys/go-corset/pkg/trace" -) - -// =================================================================== -// Public -// =================================================================== - -// ParseSchemaString parses a sequence of zero or more HIR schema declarations -// represented as a string. Internally, this uses sexp.ParseAll and -// ParseSchemaSExp to do the work. -func ParseSchemaString(str string) (*Schema, error) { - parser := sexp.NewParser(str) - // Parse bytes into an S-Expression - terms, err := parser.ParseAll() - // Check test file parsed ok - if err != nil { - return nil, err - } - // Parse terms into an HIR schema - p := newHirParser(parser.SourceMap()) - // Continue parsing string until nothing remains. - for _, term := range terms { - // Process declaration - err2 := p.parseDeclaration(term) - if err2 != nil { - return nil, err2 - } - } - // Done - return p.env.schema, nil -} - -// =================================================================== -// Private -// =================================================================== - -type hirParser struct { - // Translator used for recursive expressions. - translator *sexp.Translator[Expr] - // Current module being parsed. - module trace.Context - // Environment used during parsing to resolve column names into column - // indices. - env *Environment - // Global is used exclusively when parsing expressions to signal whether or - // not qualified column accesses are permitted (i.e. which include a - // module). - global bool -} - -func newHirParser(srcmap *sexp.SourceMap[sexp.SExp]) *hirParser { - p := sexp.NewTranslator[Expr](srcmap) - // Initialise empty environment - env := EmptyEnvironment() - // Register top-level module (aka the prelude) - prelude := env.RegisterModule("") - // Construct parser - parser := &hirParser{p, prelude, env, false} - // Configure translator - p.AddSymbolRule(constantParserRule) - p.AddSymbolRule(columnAccessParserRule(parser)) - p.AddBinaryRule("shift", shiftParserRule(parser)) - p.AddRecursiveRule("+", addParserRule) - p.AddRecursiveRule("-", subParserRule) - p.AddRecursiveRule("*", mulParserRule) - p.AddRecursiveRule("~", normParserRule) - p.AddRecursiveRule("^", powParserRule) - p.AddRecursiveRule("if", ifParserRule) - p.AddRecursiveRule("ifnot", ifNotParserRule) - p.AddRecursiveRule("begin", beginParserRule) - // - return parser -} - -func (p *hirParser) parseDeclaration(s sexp.SExp) error { - if e, ok := s.(*sexp.List); ok { - if e.MatchSymbols(2, "module") { - return p.parseModuleDeclaration(e) - } else if e.MatchSymbols(1, "defcolumns") { - return p.parseColumnDeclarations(e) - } else if e.Len() == 4 && e.MatchSymbols(2, "defconstraint") { - return p.parseConstraintDeclaration(e.Elements) - } else if e.Len() == 3 && e.MatchSymbols(2, "assert") { - return p.parseAssertionDeclaration(e.Elements) - } else if e.Len() == 3 && e.MatchSymbols(1, "defpermutation") { - return p.parsePermutationDeclaration(e) - } else if e.Len() == 4 && e.MatchSymbols(1, "deflookup") { - return p.parseLookupDeclaration(e) - } else if e.Len() == 3 && e.MatchSymbols(1, "definterleaved") { - return p.parseInterleavingDeclaration(e) - } else if e.Len() == 3 && e.MatchSymbols(1, "definrange") { - return p.parseRangeDeclaration(e) - } - } - // Error - return p.translator.SyntaxError(s, "unexpected or malformed declaration") -} - -// Parse a column declaration -func (p *hirParser) parseModuleDeclaration(l *sexp.List) error { - // Sanity check declaration - if len(l.Elements) > 2 { - return p.translator.SyntaxError(l, "malformed module declaration") - } - // Extract column name - moduleName := l.Elements[1].AsSymbol().Value - // Sanity check doesn't already exist - if p.env.HasModule(moduleName) { - return p.translator.SyntaxError(l, "duplicate module declaration") - } - // Register module - mid := p.env.RegisterModule(moduleName) - // Set current module - p.module = mid - // - return nil -} - -// Parse a column declaration -func (p *hirParser) parseColumnDeclarations(l *sexp.List) error { - // Sanity check declaration - if len(l.Elements) == 1 { - return p.translator.SyntaxError(l, "malformed column declaration") - } - // Process column declarations one by one. - for i := 1; i < len(l.Elements); i++ { - // Extract column name - if err := p.parseColumnDeclaration(l.Elements[i]); err != nil { - return err - } - } - - return nil -} - -func (p *hirParser) parseColumnDeclaration(e sexp.SExp) error { - var columnName string - // Default to field type - var columnType sc.Type = &sc.FieldType{} - // Check whether extended declaration or not. - if l := e.AsList(); l != nil { - // Check at least the name provided. - if len(l.Elements) == 0 { - return p.translator.SyntaxError(l, "empty column declaration") - } - // Column name is always first - columnName = l.Elements[0].String(false) - // Parse type (if applicable) - if len(l.Elements) == 2 { - var err error - if columnType, err = p.parseType(l.Elements[1]); err != nil { - return err - } - } else if len(l.Elements) > 2 { - // For now. - return p.translator.SyntaxError(l, "unknown column declaration attributes") - } - } else { - columnName = e.String(false) - } - // Sanity check doesn't already exist - if p.env.HasColumn(p.module, columnName) { - return p.translator.SyntaxError(e, "duplicate column declaration") - } - // Register column - cid := p.env.AddDataColumn(p.module, columnName, columnType) - // Apply type constraint (if applicable) - if columnType.AsUint() != nil { - bound := columnType.AsUint().Bound() - p.env.schema.AddRangeConstraint(columnName, p.module, &ColumnAccess{cid, 0}, bound) - } - // - return nil -} - -// Parse a sorted permutation declaration -func (p *hirParser) parsePermutationDeclaration(l *sexp.List) error { - // Target columns are (sorted) permutations of source columns. - sexpTargets := l.Elements[1].AsList() - // Source columns. - sexpSources := l.Elements[2].AsList() - // Sanity check - if sexpTargets == nil { - return p.translator.SyntaxError(l.Elements[1], "malformed target columns") - } else if sexpSources == nil { - return p.translator.SyntaxError(l.Elements[2], "malformed source columns") - } - // Convert into appropriate form. - sources := make([]uint, sexpSources.Len()) - signs := make([]bool, sexpSources.Len()) - // - if sexpTargets.Len() != sexpSources.Len() { - return p.translator.SyntaxError(l, "sorted permutation requires matching number of source and target columns") - } - // initialise context - ctx := trace.VoidContext() - // - for i := 0; i < sexpSources.Len(); i++ { - sourceIndex, sourceSign, err := p.parsePermutationSource(sexpSources.Get(i)) - if err != nil { - return err - } - // Check source context - sourceCol := p.env.schema.Columns().Nth(sourceIndex) - ctx = ctx.Join(sourceCol.Context()) - // Sanity check we have a sensible type here. - if ctx.IsConflicted() { - return p.translator.SyntaxError(sexpSources.Get(i), "conflicting evaluation context") - } else if ctx.IsVoid() { - return p.translator.SyntaxError(sexpSources.Get(i), "empty evaluation context") - } - // Copy over column name - signs[i] = sourceSign - sources[i] = sourceIndex - } - // Parse targets - targets := make([]sc.Column, sexpTargets.Len()) - // Parse targets - for i := 0; i < sexpTargets.Len(); i++ { - targetName, err := p.parsePermutationTarget(sexpTargets.Get(i)) - // - if err != nil { - return err - } - // Lookup corresponding source - source := p.env.schema.Columns().Nth(sources[i]) - // Done - targets[i] = sc.NewColumn(ctx, targetName, source.Type()) - } - // - p.env.AddAssignment(assignment.NewSortedPermutation(ctx, targets, signs, sources)) - // - return nil -} - -func (p *hirParser) parsePermutationSource(source sexp.SExp) (uint, bool, error) { - var ( - name string - sign bool - err error - ) - - if source.AsList() != nil { - l := source.AsList() - // Check whether sort direction provided - if l.Len() != 2 || l.Get(0).AsSymbol() == nil || l.Get(1).AsSymbol() == nil { - return 0, false, p.translator.SyntaxError(source, "malformed column") - } - // Parser sorting direction - if sign, err = p.parseSortDirection(l.Get(0).AsSymbol()); err != nil { - return 0, false, err - } - // Extract column name - name = l.Get(1).AsSymbol().Value - } else { - name = source.AsSymbol().Value - sign = true // default - } - // Determine index for source column - index, ok := p.env.LookupColumn(p.module, name) - if !ok { - // Column doesn't exist! - return 0, false, p.translator.SyntaxError(source, "unknown column") - } - // Done - return index, sign, nil -} - -func (p *hirParser) parsePermutationTarget(target sexp.SExp) (string, error) { - if target.AsSymbol() == nil { - return "", p.translator.SyntaxError(target, "malformed target column") - } - // - targetName := target.AsSymbol().Value - // Sanity check that target column *doesn't* exist. - if p.env.HasColumn(p.module, targetName) { - // No, it doesn't. - return "", p.translator.SyntaxError(target, "duplicate column") - } - // Done - return targetName, nil -} - -func (p *hirParser) parseSortDirection(l *sexp.Symbol) (bool, error) { - switch l.Value { - case "+", "↓": - return true, nil - case "-", "↑": - return false, nil - } - // Unknown sort - return false, p.translator.SyntaxError(l, "malformed sort direction") -} - -// Parse a lookup declaration -func (p *hirParser) parseLookupDeclaration(l *sexp.List) error { - handle := l.Elements[1].AsSymbol().Value - // Target columns are (sorted) permutations of source columns. - sexpTargets := l.Elements[2].AsList() - // Source columns. - sexpSources := l.Elements[3].AsList() - // Sanity check number of target colunms matches number of source columns. - if sexpTargets.Len() != sexpSources.Len() { - return p.translator.SyntaxError(l, "lookup constraint requires matching number of source and target columns") - } - // Sanity check expressions have unitary form. - for i := 0; i < sexpTargets.Len(); i++ { - // Sanity check source and target expressions do not contain expression - // forms which are not permitted within a unitary expression. - if err := p.checkUnitExpr(sexpTargets.Get(i)); err != nil { - return err - } - - if err := p.checkUnitExpr(sexpSources.Get(i)); err != nil { - return err - } - } - // Proceed with translation - targets := make([]UnitExpr, sexpTargets.Len()) - sources := make([]UnitExpr, sexpSources.Len()) - // Lookup expressions are permitted to make fully qualified accesses. This - // is because inter-module lookups are supported. - p.global = true - // Parse source / target expressions - for i := 0; i < len(targets); i++ { - target, err1 := p.translator.Translate(sexpTargets.Get(i)) - source, err2 := p.translator.Translate(sexpSources.Get(i)) - - if err1 != nil { - return err1 - } else if err2 != nil { - return err2 - } - // Done - targets[i] = UnitExpr{target} - sources[i] = UnitExpr{source} - } - // Sanity check enclosing source and target modules - sourceCtx := sc.JoinContexts(sources, p.env.schema) - targetCtx := sc.JoinContexts(targets, p.env.schema) - // Propagate errors - if sourceCtx.IsConflicted() { - return p.translator.SyntaxError(sexpSources, "conflicting evaluation context") - } else if targetCtx.IsConflicted() { - return p.translator.SyntaxError(sexpTargets, "conflicting evaluation context") - } else if sourceCtx.IsVoid() { - return p.translator.SyntaxError(sexpSources, "empty evaluation context") - } else if targetCtx.IsVoid() { - return p.translator.SyntaxError(sexpTargets, "empty evaluation context") - } - // Finally add constraint - p.env.schema.AddLookupConstraint(handle, sourceCtx, targetCtx, sources, targets) - // Done - return nil -} - -// Parse am interleaving declaration -func (p *hirParser) parseInterleavingDeclaration(l *sexp.List) error { - // Target columns are (sorted) permutations of source columns. - sexpTarget := l.Elements[1].AsSymbol() - // Source columns. - sexpSources := l.Elements[2].AsList() - // Sanity checks. - if sexpTarget == nil { - return p.translator.SyntaxError(l, "column name expected") - } else if sexpSources == nil { - return p.translator.SyntaxError(l, "source column list expected") - } - // Construct and check source columns - sources := make([]uint, sexpSources.Len()) - ctx := trace.VoidContext() - - for i := 0; i < sexpSources.Len(); i++ { - ith := sexpSources.Get(i) - col := ith.AsSymbol() - // Sanity check a symbol was found - if col == nil { - return p.translator.SyntaxError(ith, "column name expected") - } - // Attempt to lookup the column - cid, ok := p.env.LookupColumn(p.module, col.Value) - // Check it exists - if !ok { - return p.translator.SyntaxError(ith, "unknown column") - } - // Check multiplier calculation - sourceCol := p.env.schema.Columns().Nth(cid) - ctx = ctx.Join(sourceCol.Context()) - // Sanity check we have a sensible context here. - if ctx.IsConflicted() { - return p.translator.SyntaxError(sexpSources.Get(i), "conflicting evaluation context") - } else if ctx.IsVoid() { - return p.translator.SyntaxError(sexpSources.Get(i), "empty evaluation context") - } - // Assign - sources[i] = cid - } - // Add assignment - p.env.AddAssignment(assignment.NewInterleaving(ctx, sexpTarget.Value, sources, &sc.FieldType{})) - // Done - return nil -} - -// Parse a range constraint -func (p *hirParser) parseRangeDeclaration(l *sexp.List) error { - var bound fr.Element - // Check bound - if l.Get(2).AsSymbol() == nil { - return p.translator.SyntaxError(l.Get(2), "malformed bound") - } - // Parse bound - if _, err := bound.SetString(l.Get(2).AsSymbol().Value); err != nil { - return p.translator.SyntaxError(l.Get(2), "malformed bound") - } - // Parse expression - expr, err := p.translator.Translate(l.Get(1)) - if err != nil { - return err - } - // Determine evaluation context of expression. - ctx := expr.Context(p.env.schema) - // Sanity check we have a sensible context here. - if ctx.IsConflicted() { - return p.translator.SyntaxError(l.Get(1), "conflicting evaluation context") - } else if ctx.IsVoid() { - return p.translator.SyntaxError(l.Get(1), "empty evaluation context") - } - // - handle := l.Get(1).String(true) - p.env.schema.AddRangeConstraint(handle, ctx, expr, bound) - // - return nil -} - -// Parse a property assertion -func (p *hirParser) parseAssertionDeclaration(elements []sexp.SExp) error { - handle := elements[1].AsSymbol().Value - // Property assertions do not have global scope, hence qualified column - // accesses are not permitted. - p.global = false - // Translate - expr, err := p.translator.Translate(elements[2]) - if err != nil { - return err - } - // Determine evaluation context of assertion. - ctx := expr.Context(p.env.schema) - // Add assertion. - p.env.schema.AddPropertyAssertion(handle, ctx, expr) - - return nil -} - -// Parse a vanishing declaration -func (p *hirParser) parseConstraintDeclaration(elements []sexp.SExp) error { - // - handle := elements[1].AsSymbol().Value - // Vanishing constraints do not have global scope, hence qualified column - // accesses are not permitted. - p.global = false - domain, guard, err := p.parseConstraintAttributes(elements[2]) - // Check for error - if err != nil { - return err - } - // Translate expression - expr, err := p.translator.Translate(elements[3]) - if err != nil { - return err - } else if guard != nil { - // if guard != 0 then expr - expr = &IfZero{guard, nil, expr} - } - // Determine evaluation context of expression. - ctx := expr.Context(p.env.schema) - // Sanity check we have a sensible context here. - if ctx.IsConflicted() { - return p.translator.SyntaxError(elements[3], "conflicting evaluation context") - } else if ctx.IsVoid() { - return p.translator.SyntaxError(elements[3], "empty evaluation context") - } - - p.env.schema.AddVanishingConstraint(handle, ctx, domain, expr) - - return nil -} - -func (p *hirParser) parseConstraintAttributes(attributes sexp.SExp) (domain *int, guard Expr, err error) { - // Check attribute list is a list - if attributes.AsList() == nil { - return nil, nil, p.translator.SyntaxError(attributes, "expected attribute list") - } - // Deconstruct as list - attrs := attributes.AsList() - // Process each attribute in turn - for i := 0; i < attrs.Len(); i++ { - ith := attrs.Get(i) - // Check start of attribute - if ith.AsSymbol() == nil { - return nil, nil, p.translator.SyntaxError(ith, "malformed attribute") - } - // Check what we've got - switch ith.AsSymbol().Value { - case ":domain": - i++ - if domain, err = p.parseDomainAttribute(attrs.Get(i)); err != nil { - return nil, nil, err - } - case ":guard": - i++ - if guard, err = p.translator.Translate(attrs.Get(i)); err != nil { - return nil, nil, err - } - default: - return nil, nil, p.translator.SyntaxError(ith, "unknown attribute") - } - } - // Done - return domain, guard, nil -} - -func (p *hirParser) parseDomainAttribute(attribute sexp.SExp) (domain *int, err error) { - if attribute.AsSet() == nil { - return nil, p.translator.SyntaxError(attribute, "malformed domain set") - } - // Sanity check - set := attribute.AsSet() - // Check all domain elements well-formed. - for i := 0; i < set.Len(); i++ { - ith := set.Get(i) - if ith.AsSymbol() == nil { - return nil, p.translator.SyntaxError(ith, "malformed domain") - } - } - // Currently, only support domains of size 1. - if set.Len() == 1 { - first, err := strconv.Atoi(set.Get(0).AsSymbol().Value) - // Check for parse error - if err != nil { - return nil, p.translator.SyntaxError(set.Get(0), "malformed domain element") - } - // Done - return &first, nil - } - // Fail - return nil, p.translator.SyntaxError(attribute, "multiple values not supported") -} - -func (p *hirParser) parseType(term sexp.SExp) (sc.Type, error) { - symbol := term.AsSymbol() - if symbol == nil { - return nil, p.translator.SyntaxError(term, "malformed column") - } - // Access string of symbol - str := symbol.Value - if strings.HasPrefix(str, ":u") { - n, err := strconv.Atoi(str[2:]) - if err != nil { - return nil, err - } - // Done - return sc.NewUintType(uint(n)), nil - } - // Error - return nil, p.translator.SyntaxError(symbol, "unknown type") -} - -// Check that a given expression conforms to the requirements of a unitary -// expression. That is, it cannot contain an "if", "ifnot" or "begin" -// expression form. -func (p *hirParser) checkUnitExpr(term sexp.SExp) error { - l := term.AsList() - - if l != nil && l.Len() > 0 { - if head := l.Get(0).AsSymbol(); head != nil { - switch head.Value { - case "if": - fallthrough - case "ifnot": - fallthrough - case "begin": - return p.translator.SyntaxError(term, "not permitted in lookup") - } - } - // Check arguments - for i := 0; i < l.Len(); i++ { - if err := p.checkUnitExpr(l.Get(i)); err != nil { - return err - } - } - } - - return nil -} - -func beginParserRule(args []Expr) (Expr, error) { - return &List{args}, nil -} - -func constantParserRule(symbol string) (Expr, bool, error) { - if symbol[0] >= '0' && symbol[0] < '9' { - var num fr.Element - // Attempt to parse - _, err := num.SetString(symbol) - // Check for errors - if err != nil { - return nil, true, err - } - // Done - return &Constant{Val: num}, true, nil - } - // Not applicable - return nil, false, nil -} - -func columnAccessParserRule(parser *hirParser) func(col string) (Expr, bool, error) { - // Returns a closure over the parser. - return func(col string) (Expr, bool, error) { - var ok bool - // Sanity check what we have - if !unicode.IsLetter(rune(col[0])) { - return nil, false, nil - } - // Handle qualified accesses (where permitted) - context := parser.module - colname := col - // Attempt to split column name into module / column pair. - split := strings.Split(col, ".") - if parser.global && len(split) == 2 { - // Lookup module - if context, ok = parser.env.LookupModule(split[0]); !ok { - return nil, true, errors.New("unknown module") - } - - colname = split[1] - } else if len(split) > 2 { - return nil, true, errors.New("malformed column access") - } else if len(split) == 2 { - return nil, true, errors.New("qualified column access not permitted here") - } - // Now lookup column in the appropriate module. - var cid uint - // Look up column in the environment using local scope. - cid, ok = parser.env.LookupColumn(context, colname) - // Check column was found - if !ok { - return nil, true, errors.New("unknown column") - } - // Done - return &ColumnAccess{cid, 0}, true, nil - } -} - -func addParserRule(args []Expr) (Expr, error) { - return &Add{args}, nil -} - -func subParserRule(args []Expr) (Expr, error) { - return &Sub{args}, nil -} - -func mulParserRule(args []Expr) (Expr, error) { - return &Mul{args}, nil -} - -func ifParserRule(args []Expr) (Expr, error) { - if len(args) == 2 { - return &IfZero{args[0], args[1], nil}, nil - } else if len(args) == 3 { - return &IfZero{args[0], args[1], args[2]}, nil - } - - return nil, errors.New("incorrect number of arguments") -} - -func ifNotParserRule(args []Expr) (Expr, error) { - if len(args) == 2 { - return &IfZero{args[0], nil, args[1]}, nil - } - - return nil, errors.New("incorrect number of arguments") -} - -func shiftParserRule(parser *hirParser) func(string, string) (Expr, error) { - // Returns a closure over the parser. - return func(col string, amt string) (Expr, error) { - n, err := strconv.Atoi(amt) - - if err != nil { - return nil, err - } - // Look up column in the environment - i, ok := parser.env.LookupColumn(parser.module, col) - // Check column was found - if !ok { - return nil, fmt.Errorf("unknown column %s", col) - } - // Done - return &ColumnAccess{ - Column: i, - Shift: n, - }, nil - } -} - -func powParserRule(args []Expr) (Expr, error) { - var k big.Int - - if len(args) != 2 { - return nil, errors.New("incorrect number of arguments") - } - - c, ok := args[1].(*Constant) - if !ok { - return nil, errors.New("expected constant power") - } else if !c.Val.IsUint64() { - return nil, errors.New("constant power too large") - } - // Convert power to uint64 - c.Val.BigInt(&k) - // Done - return &Exp{Arg: args[0], Pow: k.Uint64()}, nil -} - -func normParserRule(args []Expr) (Expr, error) { - if len(args) != 1 { - return nil, errors.New("incorrect number of arguments") - } - - return &Normalise{Arg: args[0]}, nil -} diff --git a/pkg/hir/schema.go b/pkg/hir/schema.go index 0d8d5a3..3373288 100644 --- a/pkg/hir/schema.go +++ b/pkg/hir/schema.go @@ -51,6 +51,9 @@ type Schema struct { assertions []PropertyAssertion // Cache list of columns declared in inputs and assignments. column_cache []sc.Column + // Macros determines the set of macros which can be called within + // expressions, etc. + macros []*MacroDefinition } // EmptySchema is used to construct a fresh schema onto which new columns and @@ -143,6 +146,15 @@ func (p *Schema) AddPropertyAssertion(handle string, context trace.Context, prop p.assertions = append(p.assertions, sc.NewPropertyAssertion[ZeroArrayTest](handle, context, ZeroArrayTest{property})) } +// AddMacroDefinition adds a definition for a macro (either pure or impure). +func (p *Schema) AddMacroDefinition(module uint, name string, params []string, body Expr, pure bool) uint { + index := p.Macros().Count() + macro := &MacroDefinition{module, name, params, body, pure} + p.macros = append(p.macros, macro) + + return index +} + // ============================================================================ // Schema Interface // ============================================================================ @@ -192,6 +204,11 @@ func (p *Schema) Declarations() util.Iterator[sc.Declaration] { return inputs.Append(ps) } +// Macros returns an array over the macro definitions available in this schema. +func (p *Schema) Macros() util.Iterator[*MacroDefinition] { + return util.NewArrayIterator(p.macros) +} + // Modules returns an iterator over the declared set of modules within this // schema. func (p *Schema) Modules() util.Iterator[sc.Module] { diff --git a/pkg/sexp/error.go b/pkg/sexp/error.go deleted file mode 100644 index da641f5..0000000 --- a/pkg/sexp/error.go +++ /dev/null @@ -1,34 +0,0 @@ -package sexp - -import ( - "fmt" -) - -// SyntaxError is a structured error which retains the index into the original -// string where an error occurred, along with an error message. -type SyntaxError struct { - // Byte index into string being parsed where error arose. - span Span - // Error message being reported - msg string -} - -// NewSyntaxError simply constructs a new syntax error. -func NewSyntaxError(span Span, msg string) *SyntaxError { - return &SyntaxError{span, msg} -} - -// Span returns the span of the original text on which this error is reported. -func (p *SyntaxError) Span() Span { - return p.span -} - -// Message returns the message to be reported. -func (p *SyntaxError) Message() string { - return p.msg -} - -// Error implements the error interface. -func (p *SyntaxError) Error() string { - return fmt.Sprintf("%d:%d:%s", p.span.Start(), p.span.End(), p.Message()) -} diff --git a/pkg/sexp/parser.go b/pkg/sexp/parser.go index 27535f6..ef8a23d 100644 --- a/pkg/sexp/parser.go +++ b/pkg/sexp/parser.go @@ -4,41 +4,27 @@ import ( "unicode" ) -// Parse a given string into an S-expression, or return an error if the string -// is malformed. -func Parse(s string) (SExp, error) { - p := NewParser(s) - // Parse the input - sExp, err := p.Parse() - // Sanity check everything was parsed - if err == nil && p.index != len(p.text) { - return nil, p.error("unexpected remainder") - } - - return sExp, err -} - // Parser represents a parser in the process of parsing a given string into one // or more S-expressions. type Parser struct { - // Text being parsed + // Source file being parsed + srcfile *SourceFile + // Cache (for simplicity) text []rune // Determine current position within text index int - // Mapping from construct S-Expressions to their spans in the original text. + // Mapping from constructed S-Expressions to their spans in the original text. srcmap *SourceMap[SExp] } // NewParser constructs a new instance of Parser -func NewParser(text string) *Parser { - // Convert string into array of runes. This is necessary to properly handle - // unicode. - runes := []rune(text) +func NewParser(srcfile *SourceFile) *Parser { // Construct initial parser. return &Parser{ - text: runes, - index: 0, - srcmap: NewSourceMap[SExp](runes), + srcfile: srcfile, + text: srcfile.Contents(), + index: 0, + srcmap: NewSourceMap[SExp](srcfile.Contents()), } } @@ -49,27 +35,13 @@ func (p *Parser) SourceMap() *SourceMap[SExp] { return p.srcmap } -// ParseAll parses the input string into zero or more S-expressions, whilst -// returning an error if the string is malformed. -func (p *Parser) ParseAll() ([]SExp, error) { - terms := make([]SExp, 0) - // Parse the input - for { - term, err := p.Parse() - // Sanity check everything was parsed - if err != nil { - return terms, err - } else if term == nil { - // EOF reached - return terms, nil - } - - terms = append(terms, term) - } +// Text returns the underlying text for this parser. +func (p *Parser) Text() []rune { + return p.text } // Parse a given string into an S-Expression, or produce an error. -func (p *Parser) Parse() (SExp, error) { +func (p *Parser) Parse() (SExp, *SyntaxError) { var term SExp // Skip over any whitespace. This is import to get the correct starting // point for this term. @@ -187,7 +159,7 @@ func (p *Parser) parseSymbol() []rune { return token } -func (p *Parser) parseSequence(terminator rune) ([]SExp, error) { +func (p *Parser) parseSequence(terminator rune) ([]SExp, *SyntaxError) { var elements []SExp for c := p.Lookahead(0); c == nil || *c != terminator; c = p.Lookahead(0) { @@ -211,5 +183,5 @@ func (p *Parser) parseSequence(terminator rune) ([]SExp, error) { // Construct a parser error at the current position in the input stream. func (p *Parser) error(msg string) *SyntaxError { span := NewSpan(p.index, p.index+1) - return &SyntaxError{span, msg} + return p.srcfile.SyntaxError(span, msg) } diff --git a/pkg/sexp/sexp_test.go b/pkg/sexp/sexp_test.go index ea2159a..cf705bd 100644 --- a/pkg/sexp/sexp_test.go +++ b/pkg/sexp/sexp_test.go @@ -141,7 +141,8 @@ func TestSexp_Err4(t *testing.T) { // ============================================================================ func CheckOk(t *testing.T, sexp1 SExp, input string) { - sexp2, err := Parse(input) + src := NewSourceFile("test", []byte(input)) + sexp2, _, err := src.Parse() // if err != nil { t.Error(err) @@ -151,7 +152,9 @@ func CheckOk(t *testing.T, sexp1 SExp, input string) { } func CheckErr(t *testing.T, input string) { - _, err := Parse(input) + src := NewSourceFile("test", []byte(input)) + _, _, err := src.Parse() + // if err == nil { t.Errorf("input should not have parsed!") diff --git a/pkg/sexp/source_file.go b/pkg/sexp/source_file.go new file mode 100644 index 0000000..901df37 --- /dev/null +++ b/pkg/sexp/source_file.go @@ -0,0 +1,103 @@ +package sexp + +import ( + "fmt" +) + +// SourceFile represents a given source file (typically stored on disk). +type SourceFile struct { + // File name for this source file. + filename string + // Contents of this file. + contents []rune +} + +// NewSourceFile constructs a new source file from a given byte array. +func NewSourceFile(filename string, bytes []byte) *SourceFile { + // Convert bytes into runes for easier parsing + contents := []rune(string(bytes)) + return &SourceFile{filename, contents} +} + +// Filename returns the filename associated with this source file. +func (s *SourceFile) Filename() string { + return s.filename +} + +// Contents returns the contents of this source file. +func (s *SourceFile) Contents() []rune { + return s.contents +} + +// Parse a given string into an S-expression, or return an error if the string +// is malformed. A source map is also returned for debugging purposes. +func (s *SourceFile) Parse() (SExp, *SourceMap[SExp], *SyntaxError) { + p := NewParser(s) + // Parse the input + sExp, err := p.Parse() + // Sanity check everything was parsed + if err == nil && p.index != len(p.text) { + return nil, nil, p.error("unexpected remainder") + } + // Done + return sExp, p.SourceMap(), err +} + +// ParseAll converts a given string into zero or more S-expressions, or returns +// an error if the string is malformed. A source map is also returned for +// debugging purposes. The key distinction from Parse is that this function +// continues parsing after the first S-expression is encountered. +func (s *SourceFile) ParseAll() ([]SExp, *SourceMap[SExp], *SyntaxError) { + p := NewParser(s) + // + terms := make([]SExp, 0) + // Parse the input + for { + term, err := p.Parse() + // Sanity check everything was parsed + if err != nil { + return terms, p.srcmap, err + } else if term == nil { + // EOF reached + return terms, p.srcmap, nil + } + + terms = append(terms, term) + } +} + +// SyntaxError constructs a syntax error over a given span of this file with a +// given message. +func (s *SourceFile) SyntaxError(span Span, msg string) *SyntaxError { + return &SyntaxError{s, span, msg} +} + +// SyntaxError is a structured error which retains the index into the original +// string where an error occurred, along with an error message. +type SyntaxError struct { + srcfile *SourceFile + // Byte index into string being parsed where error arose. + span Span + // Error message being reported + msg string +} + +// SourceFile returns the underlying source file that this syntax error covers. +func (p *SyntaxError) SourceFile() *SourceFile { + return p.srcfile +} + +// Span returns the span of the original text on which this error is reported. +func (p *SyntaxError) Span() Span { + return p.span +} + +// Message returns the message to be reported. +func (p *SyntaxError) Message() string { + return p.msg +} + +// Error implements the error interface. +func (p *SyntaxError) Error() string { + return fmt.Sprintf("%d:%d:%s", p.span.Start(), p.span.End(), p.Message()) +} diff --git a/pkg/sexp/source_map.go b/pkg/sexp/source_map.go index d3f611c..fb70a35 100644 --- a/pkg/sexp/source_map.go +++ b/pkg/sexp/source_map.go @@ -75,6 +75,25 @@ func (p *Line) Length() int { return p.span.Length() } +// SourceMaps provides a mechanism for mapping terms from an AST to multiple +// source files. +type SourceMaps[T comparable] struct { + // Arrray of known source maps. + maps []SourceMap[T] +} + +// NewSourceMaps constructs an (initially empty) set of source maps. The +// intention is that this is populated as each file is parsed. +func NewSourceMaps[T comparable]() *SourceMaps[T] { + return &SourceMaps[T]{[]SourceMap[T]{}} +} + +// Join a given source map into this set of source maps. The effect of this is +// that nodes recorded in the given source map can be accessed from this set. +func (p *SourceMaps[T]) Join(srcmap *SourceMap[T]) { + p.maps = append(p.maps, *srcmap) +} + // SourceMap maps terms from an AST to slices of their originating string. This // is important for error handling when we wish to highlight exactly where, in // the original source file, a given error has arisen. @@ -94,6 +113,11 @@ func NewSourceMap[T comparable](text []rune) *SourceMap[T] { return &SourceMap[T]{mapping, text} } +// Text returns underlying text of this source map. +func (p *SourceMap[T]) Text() []rune { + return p.text +} + // Put registers a new AST item with a given span. Note, if the item exists // already, then it will panic. func (p *SourceMap[T]) Put(item T, span Span) { diff --git a/pkg/sexp/translator.go b/pkg/sexp/translator.go index faf62c2..a3bb486 100644 --- a/pkg/sexp/translator.go +++ b/pkg/sexp/translator.go @@ -1,8 +1,6 @@ package sexp -import ( - "fmt" -) +import "fmt" // SymbolRule is a symbol generator is responsible for converting a terminating // expression (i.e. a symbol) into an expression type T. For @@ -13,7 +11,7 @@ type SymbolRule[T comparable] func(string) (T, bool, error) // sequence of zero or more arguments into an expression type T. // Observe that the arguments are already translated into the correct // form. -type ListRule[T comparable] func(*List) (T, error) +type ListRule[T comparable] func(*List) (T, *SyntaxError) // BinaryRule is a binary translator is a wrapper for translating lists which must // have exactly two symbol arguments. The wrapper takes care of @@ -23,7 +21,7 @@ type BinaryRule[T comparable] func(string, string) (T, error) // RecursiveRule is a recursive translator is a wrapper for translating lists whose // elements can be built by recursively reusing the enclosing // translator. -type RecursiveRule[T comparable] func([]T) (T, error) +type RecursiveRule[T comparable] func(string, []T) (T, error) // =================================================================== // Parser @@ -32,7 +30,12 @@ type RecursiveRule[T comparable] func([]T) (T, error) // Translator is a generic mechanism for translating S-Expressions into a structured // form. type Translator[T comparable] struct { - lists map[string]ListRule[T] + srcfile *SourceFile + // Rules for parsing lists + lists map[string]ListRule[T] + // Fallback rule for generic user-defined lists. + list_default ListRule[T] + // Rules for parsing symbols symbols []SymbolRule[T] // Maps S-Expressions to their spans in the original source file. This is // used to build the new source map. @@ -43,12 +46,14 @@ type Translator[T comparable] struct { } // NewTranslator constructs a new Translator instance. -func NewTranslator[T comparable](srcmap *SourceMap[SExp]) *Translator[T] { +func NewTranslator[T comparable](srcfile *SourceFile, srcmap *SourceMap[SExp]) *Translator[T] { return &Translator[T]{ - lists: make(map[string]ListRule[T]), - symbols: make([]SymbolRule[T], 0), - old_srcmap: srcmap, - new_srcmap: NewSourceMap[T](srcmap.text), + srcfile: srcfile, + lists: make(map[string]ListRule[T]), + list_default: nil, + symbols: make([]SymbolRule[T], 0), + old_srcmap: srcmap, + new_srcmap: NewSourceMap[T](srcmap.text), } } @@ -56,23 +61,9 @@ func NewTranslator[T comparable](srcmap *SourceMap[SExp]) *Translator[T] { // Public // =================================================================== -// ParseAndTranslate a given string into a given structured representation T -// using an appropriately configured. -func (p *Translator[T]) ParseAndTranslate(s string) (T, error) { - // Parse string into S-expression form - e, err := Parse(s) - if err != nil { - var empty T - return empty, err - } - - // Process S-expression into AIR expression. - return translateSExp(p, e) -} - // Translate a given string into a given structured representation T // using an appropriately configured. -func (p *Translator[T]) Translate(sexp SExp) (T, error) { +func (p *Translator[T]) Translate(sexp SExp) (T, *SyntaxError) { // Process S-expression into target expression return translateSExp(p, sexp) } @@ -80,14 +71,31 @@ func (p *Translator[T]) Translate(sexp SExp) (T, error) { // AddRecursiveRule adds a new list translator to this expression translator. func (p *Translator[T]) AddRecursiveRule(name string, t RecursiveRule[T]) { // Construct a recursive list translator as a wrapper around a generic list translator. - p.lists[name] = func(l *List) (T, error) { - var ( - empty T - err error - ) + p.lists[name] = p.createRecursiveRule(t) +} + +// AddDefaultRecursiveRule adds a default recursive rule to be applied when no +// other recursive rules apply. +func (p *Translator[T]) AddDefaultRecursiveRule(t RecursiveRule[T]) { + // Construct a recursive list translator as a wrapper around a generic list translator. + p.list_default = p.createRecursiveRule(t) +} + +func (p *Translator[T]) createRecursiveRule(t RecursiveRule[T]) ListRule[T] { + // Construct a recursive list translator as a wrapper around a generic list translator. + return func(l *List) (T, *SyntaxError) { + var empty T + // Extract the "head" of the list. + if len(l.Elements) == 0 || l.Elements[0].AsSymbol() == nil { + return empty, p.SyntaxError(l, "invalid list") + } + // Extract expression name + head := (l.Elements[0].(*Symbol)).Value // Translate arguments args := make([]T, len(l.Elements)-1) + // for i, s := range l.Elements[1:] { + var err *SyntaxError args[i], err = translateSExp(p, s) // Handle error if err != nil { @@ -95,7 +103,7 @@ func (p *Translator[T]) AddRecursiveRule(name string, t RecursiveRule[T]) { } } // Apply constructor - term, err := t(args) + term, err := t(head, args) // Check for error if err == nil { return term, nil @@ -109,7 +117,7 @@ func (p *Translator[T]) AddRecursiveRule(name string, t RecursiveRule[T]) { func (p *Translator[T]) AddBinaryRule(name string, t BinaryRule[T]) { var empty T // - p.lists[name] = func(l *List) (T, error) { + p.lists[name] = func(l *List) (T, *SyntaxError) { if len(l.Elements) != 3 { // Should be unreachable. return empty, p.SyntaxError(l, "Incorrect number of arguments") @@ -141,11 +149,11 @@ func (p *Translator[T]) AddSymbolRule(t SymbolRule[T]) { } // SyntaxError constructs a suitable syntax error for a given S-Expression. -func (p *Translator[T]) SyntaxError(s SExp, msg string) error { +func (p *Translator[T]) SyntaxError(s SExp, msg string) *SyntaxError { // Get span of enclosing list span := p.old_srcmap.Get(s) - // This should be unreachable. - return NewSyntaxError(span, msg) + // Construct syntax error + return p.srcfile.SyntaxError(span, msg) } // =================================================================== @@ -155,7 +163,7 @@ func (p *Translator[T]) SyntaxError(s SExp, msg string) error { // Translate an S-Expression into an IR expression. Observe that // this can still fail in the event that the given S-Expression does // not describe a well-formed IR expression. -func translateSExp[T comparable](p *Translator[T], s SExp) (T, error) { +func translateSExp[T comparable](p *Translator[T], s SExp) (T, *SyntaxError) { var empty T switch e := s.(type) { @@ -180,7 +188,7 @@ func translateSExp[T comparable](p *Translator[T], s SExp) (T, error) { // expression of some kind. This type of expression is determined by // the first element of the list. The remaining elements are treated // as arguments which are first recursively translated. -func translateSExpList[T comparable](p *Translator[T], l *List) (T, error) { +func translateSExpList[T comparable](p *Translator[T], l *List) (T, *SyntaxError) { var empty T // Sanity check this list makes sense if len(l.Elements) == 0 || l.Elements[0].AsSymbol() == nil { @@ -193,6 +201,8 @@ func translateSExpList[T comparable](p *Translator[T], l *List) (T, error) { // Check whether we found one. if t != nil { return (t)(l) + } else if p.list_default != nil { + return (p.list_default)(l) } // Default fall back return empty, p.SyntaxError(l, "unknown list encountered") diff --git a/pkg/test/ir_test.go b/pkg/test/ir_test.go index c6c2ae3..c9b4db8 100644 --- a/pkg/test/ir_test.go +++ b/pkg/test/ir_test.go @@ -9,8 +9,10 @@ import ( "strings" "testing" + "github.com/consensys/go-corset/pkg/corset" "github.com/consensys/go-corset/pkg/hir" sc "github.com/consensys/go-corset/pkg/schema" + "github.com/consensys/go-corset/pkg/sexp" "github.com/consensys/go-corset/pkg/trace" "github.com/consensys/go-corset/pkg/trace/json" ) @@ -470,6 +472,18 @@ func Test_Interleave_04(t *testing.T) { Check(t, "interleave_04") } +// =================================================================== +// Functions +// =================================================================== + +func Test_PureFun_01(t *testing.T) { + Check(t, "purefun_01") +} + +func Test_PureFun_02(t *testing.T) { + Check(t, "purefun_02") +} + // =================================================================== // Complex Tests // =================================================================== @@ -530,19 +544,22 @@ const MAX_PADDING uint = 7 // expect to be accepted are accepted, and all traces that we expect // to be rejected are rejected. func Check(t *testing.T, test string) { + filename := fmt.Sprintf("%s.lisp", test) // Enable testing each trace in parallel t.Parallel() // Read constraints file - bytes, err := os.ReadFile(fmt.Sprintf("%s/%s.lisp", TestDir, test)) + bytes, err := os.ReadFile(fmt.Sprintf("%s/%s", TestDir, filename)) // Check test file read ok if err != nil { t.Fatal(err) } + // Package up as source file + srcfile := sexp.NewSourceFile(filename, bytes) // Parse terms into an HIR schema - schema, err := hir.ParseSchemaString(string(bytes)) + schema, errs := corset.CompileSourceFile(srcfile) // Check terms parsed ok - if err != nil { - t.Fatalf("Error parsing %s.lisp: %s\n", test, err) + if len(errs) > 0 { + t.Fatalf("Error parsing %s: %s\n", filename, errs) } // Check valid traces are accepted accepts_file := fmt.Sprintf("%s.%s", test, "accepts") diff --git a/pkg/util/maps.go b/pkg/util/maps.go new file mode 100644 index 0000000..9cb08d7 --- /dev/null +++ b/pkg/util/maps.go @@ -0,0 +1,11 @@ +package util + +// ShallowCloneMap makes a shallow clone of a given map. +func ShallowCloneMap[K comparable, V any](orig map[K]V) map[K]V { + new := make(map[K]V, len(orig)) + for key, value := range orig { + new[key] = value + } + + return new +} diff --git a/testdata/purefun_01.accepts b/testdata/purefun_01.accepts new file mode 100644 index 0000000..945d190 --- /dev/null +++ b/testdata/purefun_01.accepts @@ -0,0 +1,6 @@ +{ "A": [0] } +{ "A": [0,0] } +{ "A": [0,0,0] } +{ "A": [0,0,0,0] } +{ "A": [0,0,0,0,0] } +{ "A": [0,0,0,0,0,0] } diff --git a/testdata/purefun_01.lisp b/testdata/purefun_01.lisp new file mode 100644 index 0000000..7aafa59 --- /dev/null +++ b/testdata/purefun_01.lisp @@ -0,0 +1,3 @@ +(defcolumns A) +(defpurefun (id x) x) +(defconstraint test () (id A)) diff --git a/testdata/purefun_01.rejects b/testdata/purefun_01.rejects new file mode 100644 index 0000000..a5f8491 --- /dev/null +++ b/testdata/purefun_01.rejects @@ -0,0 +1,5 @@ +{ "A": [1] } +{ "A": [2] } +{ "A": [1,1] } +{ "A": [1,1] } +{ "A": [2,1] } diff --git a/testdata/purefun_02.accepts b/testdata/purefun_02.accepts new file mode 100644 index 0000000..5f89401 --- /dev/null +++ b/testdata/purefun_02.accepts @@ -0,0 +1,15 @@ +{ "A": [], "B": [] } +{ "A": [0], "B": [0] } +{ "A": [1], "B": [1] } +{ "A": [2], "B": [2] } +{ "A": [3], "B": [3] } +{ "A": [4], "B": [4] } +;; +{ "A": [0,0], "B": [0,0] } +{ "A": [1,0], "B": [1,0] } +{ "A": [0,1], "B": [0,1] } +{ "A": [1,1], "B": [1,1] } +;; +{ "A": [125,0], "B": [125,0] } +{ "A": [0,125], "B": [0,125] } +{ "A": [125,125], "B": [125,125] } diff --git a/testdata/purefun_02.lisp b/testdata/purefun_02.lisp new file mode 100644 index 0000000..5b5f120 --- /dev/null +++ b/testdata/purefun_02.lisp @@ -0,0 +1,3 @@ +(defcolumns A B) +(defpurefun (eq x y) (- y x)) +(defconstraint test () (eq A B)) diff --git a/testdata/purefun_02.rejects b/testdata/purefun_02.rejects new file mode 100644 index 0000000..7ec90d0 --- /dev/null +++ b/testdata/purefun_02.rejects @@ -0,0 +1,35 @@ +{ "A": [0], "B": [1] } +{ "A": [1], "B": [0] } +{ "A": [0], "B": [1] } +{ "A": [0], "B": [2] } +{ "A": [0], "B": [3] } +{ "A": [1], "B": [0] } +{ "A": [2], "B": [0] } +{ "A": [3], "B": [0] } +;; +{ "A": [0,0], "B": [0,1] } +{ "A": [1,0], "B": [0,0] } +{ "A": [0,0], "B": [1,0] } +{ "A": [0,1], "B": [0,0] } +{ "A": [0,0], "B": [1,1] } +{ "A": [1,1], "B": [0,0] } +{ "A": [1,0], "B": [0,1] } +{ "A": [0,1], "B": [1,0] } +;; +{ "A": [0,0], "B": [0,125] } +{ "A": [125,0], "B": [0,0] } +{ "A": [0,0], "B": [125,0] } +{ "A": [0,125], "B": [0,0] } +{ "A": [0,0], "B": [125,125] } +{ "A": [125,125], "B": [0,0] } +{ "A": [125,0], "B": [0,125] } +{ "A": [0,125], "B": [125,0] } +;; +{ "A": [65,65], "B": [65,65573234] } +{ "A": [65573234,65], "B": [65,65] } +{ "A": [65,65], "B": [65573234,65] } +{ "A": [65,65573234], "B": [65,65] } +{ "A": [65,65], "B": [65573234,65573234] } +{ "A": [65573234,65573234], "B": [65,65] } +{ "A": [65573234,65], "B": [65,65573234] } +{ "A": [65,65573234], "B": [65573234,65] }