From 83af113d19c08bb0fcd444ab4540c27238ae5a23 Mon Sep 17 00:00:00 2001 From: DavePearce Date: Mon, 2 Dec 2024 16:51:16 +1300 Subject: [PATCH] Rework scope resolution This reworks the way in which scopes are resolved to avoid using the Environment until the very last moment in the compilation process. This just means we don't have to get too concerned about specific column identifiers until it really matters. It also means we can retain more detailed meta-information, allowing better analysis. --- pkg/air/expr.go | 2 +- pkg/binfile/computation.go | 2 +- pkg/corset/ast.go | 89 +++++--------- pkg/corset/binding.go | 75 +++++++---- pkg/corset/compiler.go | 6 +- pkg/corset/environment.go | 246 ++++++++++--------------------------- pkg/corset/parser.go | 37 +++--- pkg/corset/resolver.go | 199 ++++++++++++++---------------- pkg/corset/scope.go | 154 +++++++++++++++-------- pkg/corset/translator.go | 77 ++++++------ pkg/hir/expr.go | 2 +- pkg/mir/expr.go | 2 +- pkg/schema/schemas.go | 4 +- pkg/trace/context.go | 45 ++++--- 14 files changed, 438 insertions(+), 502 deletions(-) diff --git a/pkg/air/expr.go b/pkg/air/expr.go index cbbb691d..d35bfda9 100644 --- a/pkg/air/expr.go +++ b/pkg/air/expr.go @@ -213,7 +213,7 @@ func NewConst64(val uint64) Expr { // Context determines the evaluation context (i.e. enclosing module) for this // expression. func (p *Constant) Context(schema sc.Schema) trace.Context { - return trace.VoidContext() + return trace.VoidContext[uint]() } // RequiredColumns returns the set of columns on which this term depends. diff --git a/pkg/binfile/computation.go b/pkg/binfile/computation.go index 6a468c88..e3f44590 100644 --- a/pkg/binfile/computation.go +++ b/pkg/binfile/computation.go @@ -127,7 +127,7 @@ func sourceColumnsFromHandles(handles []string, columns []column, // Convert source refs into column indexes sources := make([]uint, len(sourceIDs)) // - ctx := trace.VoidContext() + ctx := trace.VoidContext[uint]() // for i, source_id := range sourceIDs { // Determine schema column index for ith source column. diff --git a/pkg/corset/ast.go b/pkg/corset/ast.go index 2027f2f0..a0f7ceb3 100644 --- a/pkg/corset/ast.go +++ b/pkg/corset/ast.go @@ -177,19 +177,6 @@ type DefInterleaved struct { Sources []*DefName } -// CanFinalise checks whether or not this interleaving is ready to be finalised. -// Specifically, it checks whether or not the source columns of this -// interleaving are themselves finalised. -func (p *DefInterleaved) CanFinalise(module uint, env *Environment) bool { - for _, col := range p.Sources { - if !env.IsColumnFinalised(module, col.Name) { - return false - } - } - // - return true -} - // IsDeclaration needed to signal declaration. func (p *DefInterleaved) IsDeclaration() {} @@ -238,46 +225,19 @@ func (p *DefLookup) Lisp() sexp.SExp { // source columns can be specified as increasing or decreasing. type DefPermutation struct { Targets []*DefColumn - Sources []*DefPermutedColumn + Sources []*DefName + Signs []bool } // IsDeclaration needed to signal declaration. func (p *DefPermutation) IsDeclaration() {} -// CanFinalise checks whether or not this permutation is ready to be finalised. -// Specifically, it checks whether or not the source columns of this permutation -// are themselves finalised. -func (p *DefPermutation) CanFinalise(module uint, env *Environment) bool { - for _, col := range p.Sources { - if !env.IsColumnFinalised(module, col.Name) { - return false - } - } - // - return true -} - // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. func (p *DefPermutation) Lisp() sexp.SExp { panic("got here") } -// DefPermutedColumn provides information about a column being permuted by a -// sorted permutation. -type DefPermutedColumn struct { - // Name of the column to be permuted - Name string - // Sign of the column - Sign bool -} - -// Lisp converts this node into its lisp representation. This is primarily used -// for debugging purposes. -func (p *DefPermutedColumn) Lisp() sexp.SExp { - panic("got here") -} - // DefProperty represents an assertion to be used only for debugging / testing / // verification. Unlike vanishing constraints, property assertions do not // represent something that the prover can enforce. Rather, they represent @@ -370,7 +330,7 @@ type Expr interface { Node // Multiplicity defines the number of values which will be returned when // evaluating this expression. Due to the nature of expressions in Corset, - // they can (perhaps) surprisingly return multiple values. For example, + // they can (perhaps surprisingly) return multiple values. For example, // lists return one value for each element in the list. Note, every // expression must return at least one value. Multiplicity() uint @@ -378,13 +338,16 @@ type Expr interface { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). - Context() tr.Context + Context() Context // Substitute all variables (such as for function parameters) arising in // this expression. Substitute(args []Expr) Expr } +// Context represents the evaluation context for a given expression. +type Context = tr.RawContext[string] + // ============================================================================ // Addition // ============================================================================ @@ -401,7 +364,7 @@ func (e *Add) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *Add) Context() tr.Context { +func (e *Add) Context() Context { return ContextOfExpressions(e.Args) } @@ -433,8 +396,8 @@ func (e *Constant) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *Constant) Context() tr.Context { - return tr.VoidContext() +func (e *Constant) Context() Context { + return tr.VoidContext[string]() } // Lisp converts this schema element into a simple S-Expression, for example @@ -468,7 +431,7 @@ func (e *Exp) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *Exp) Context() tr.Context { +func (e *Exp) Context() Context { return ContextOfExpressions([]Expr{e.Arg}) } @@ -508,7 +471,7 @@ func (e *IfZero) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *IfZero) Context() tr.Context { +func (e *IfZero) Context() Context { return ContextOfExpressions([]Expr{e.Condition, e.TrueBranch, e.FalseBranch}) } @@ -543,7 +506,7 @@ func (e *List) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *List) Context() tr.Context { +func (e *List) Context() Context { return ContextOfExpressions(e.Args) } @@ -575,7 +538,7 @@ func (e *Mul) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *Mul) Context() tr.Context { +func (e *Mul) Context() Context { return ContextOfExpressions(e.Args) } @@ -608,7 +571,7 @@ func (e *Normalise) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *Normalise) Context() tr.Context { +func (e *Normalise) Context() Context { return ContextOfExpressions([]Expr{e.Arg}) } @@ -640,7 +603,7 @@ func (e *Sub) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *Sub) Context() tr.Context { +func (e *Sub) Context() Context { return ContextOfExpressions(e.Args) } @@ -671,7 +634,7 @@ type Invoke struct { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *Invoke) Context() tr.Context { +func (e *Invoke) Context() Context { if e.Binding == nil { panic("unresolved expressions encountered whilst resolving context") } @@ -720,12 +683,16 @@ func (e *VariableAccess) Multiplicity() uint { // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). -func (e *VariableAccess) Context() tr.Context { - if e.Binding == nil { - panic("unresolved expressions encountered whilst resolving context") +func (e *VariableAccess) Context() Context { + binding, ok := e.Binding.(*ColumnBinding) + // + if ok { + return binding.Context() + } else if binding == nil { + panic("unresolved column access") } - // Extract saved context - return e.Binding.Context() + // + panic("invalid column access") } // Lisp converts this schema element into a simple S-Expression, for example @@ -758,8 +725,8 @@ func (e *VariableAccess) Substitute(args []Expr) Expr { // they are all constants) then the void context is returned. Likewise, if // there are expressions with different contexts then the conflicted context // will be returned. Otherwise, the one consistent context will be returned. -func ContextOfExpressions(exprs []Expr) tr.Context { - context := tr.VoidContext() +func ContextOfExpressions(exprs []Expr) Context { + context := tr.VoidContext[string]() // for _, e := range exprs { context = context.Join(e.Context()) diff --git a/pkg/corset/binding.go b/pkg/corset/binding.go index fc427e5c..e17c26fd 100644 --- a/pkg/corset/binding.go +++ b/pkg/corset/binding.go @@ -1,32 +1,71 @@ package corset import ( + "math" + + sc "github.com/consensys/go-corset/pkg/schema" tr "github.com/consensys/go-corset/pkg/trace" ) +// BindingId is an identifier is used to distinguish different forms of binding, +// as some forms are known from their use. Specifically, at the current time, +// only functions are distinguished from other categories (e.g. columns, +// parameters, etc). +type BindingId struct { + // Name of the binding + name string + // Indicates whether function binding or other. + fn bool +} + // Binding represents an association between a name, as found in a source file, // and concrete item (e.g. a column, function, etc). type Binding interface { // Returns the context associated with this binding. - Context() tr.Context + IsBinding() } // ColumnBinding represents something bound to a given column. type ColumnBinding struct { - // For a column access, this identifies the enclosing context. - context tr.Context - // Identifies the variable or column index (as appropriate). - index uint + // Column's allocated identifier + cid uint + // Column's enclosing module + module string + // Determines whether this is a computed column, or not. + computed bool + // Column's length multiplier + multiplier uint + // Column's datatype + datatype sc.Type +} + +// NewColumnBinding constructs a new column binding in a given module. +func NewColumnBinding(module string, computed bool, multiplier uint, datatype sc.Type) *ColumnBinding { + return &ColumnBinding{math.MaxUint, module, computed, multiplier, datatype} +} + +// IsBinding ensures this is an instance of Binding. +func (p *ColumnBinding) IsBinding() {} + +// Context returns the of this column. That is, the module in which this colunm +// was declared and also the length multiplier of that module it requires. +func (p *ColumnBinding) Context() Context { + return tr.NewContext(p.module, p.multiplier) } -// Context returns the enclosing context for this column access. -func (p *ColumnBinding) Context() tr.Context { - return p.context +// AllocateId allocates the column identifier for this column +func (p *ColumnBinding) AllocateId(cid uint) { + p.cid = cid } -// ColumnID returns the column identifier that this column access refers to. -func (p *ColumnBinding) ColumnID() uint { - return p.index +// ColumnId returns the allocated identifier for this column. NOTE: this will +// panic if this column has not yet been allocated an identifier. +func (p *ColumnBinding) ColumnId() uint { + if p.cid == math.MaxUint { + panic("column id not yet allocated") + } + // + return p.cid } // ParameterBinding represents something bound to a given column. @@ -35,11 +74,8 @@ type ParameterBinding struct { index uint } -// Context for a parameter is always void, as it does not correspond to a column -// in given module. -func (p *ParameterBinding) Context() tr.Context { - return tr.VoidContext() -} +// IsBinding ensures this is an instance of Binding. +func (p *ParameterBinding) IsBinding() {} // FunctionBinding represents the binding of a function application to its // physical definition. @@ -50,11 +86,8 @@ type FunctionBinding struct { body Expr } -// Context for a parameter is always void, as it does not correspond to a column -// in given module. -func (p *FunctionBinding) Context() tr.Context { - return tr.VoidContext() -} +// IsBinding ensures this is an instance of Binding. +func (p *FunctionBinding) IsBinding() {} // Apply a given set of arguments to this function binding. func (p *FunctionBinding) Apply(args []Expr) Expr { diff --git a/pkg/corset/compiler.go b/pkg/corset/compiler.go index 487f88ee..0906fd1e 100644 --- a/pkg/corset/compiler.go +++ b/pkg/corset/compiler.go @@ -64,15 +64,17 @@ func NewCompiler(circuit Circuit, srcmaps *sexp.SourceMaps[Node]) *Compiler { // etc. func (p *Compiler) Compile() (*hir.Schema, []SyntaxError) { // Resolve variables (via nested scopes) - env, errs := ResolveCircuit(p.srcmap, &p.circuit) + scope, errs := ResolveCircuit(p.srcmap, &p.circuit) // Check whether any errors were encountered. If so, terminate since we // cannot proceed with translation. if len(errs) != 0 { return nil, errs } + // Convert global scope into an environment by allocating all columns. + environment := scope.ToEnvironment() // Check constraint contexts (e.g. for constraints, lookups, etc) // Type check constraints fmt.Println("Translating Circuit...") // Finally, translate everything and add it to the schema. - return TranslateCircuit(env, p.srcmap, &p.circuit) + return TranslateCircuit(environment, p.srcmap, &p.circuit) } diff --git a/pkg/corset/environment.go b/pkg/corset/environment.go index 7906649c..a801e869 100644 --- a/pkg/corset/environment.go +++ b/pkg/corset/environment.go @@ -1,202 +1,80 @@ package corset import ( - "fmt" - - "github.com/consensys/go-corset/pkg/schema" - sc "github.com/consensys/go-corset/pkg/schema" - "github.com/consensys/go-corset/pkg/trace" tr "github.com/consensys/go-corset/pkg/trace" ) -// =================================================================== -// Environment -// =================================================================== - -// Identifies a specific column within the environment. -type colRef struct { - module uint - column string -} - -// ColumnInfo packages up information about a declared column (either input or -// assignment). -type ColumnInfo struct { - // Column index - cid uint - // Length multiplier - multiplier uint - // Datatype - datatype schema.Type -} - -// IsFinalised checks whether this column has been finalised already. -func (p ColumnInfo) IsFinalised() bool { - return p.multiplier != 0 -} - -// Environment maps module and column names to their (respective) module and -// column indices. The environment separates input columns from assignment -// columns because they are disjoint in the schema being constructed (i.e. input -// columns always have a lower index than assignments). -type Environment struct { - // Maps module names to their module indices. - modules map[string]uint - // Maps input columns to their column indices. - columns map[colRef]ColumnInfo -} - -// EmptyEnvironment constructs an empty environment. -func EmptyEnvironment() *Environment { - modules := make(map[string]uint) - columns := make(map[colRef]ColumnInfo) - // - return &Environment{modules, columns} -} - -// NewModuleScope creates a new evaluation scope. -func (p *Environment) NewModuleScope(module string) *ModuleScope { - mid := p.Module(module) - return &ModuleScope{mid, p, make(map[string]FunctionBinding)} -} - -// RegisterModule registers a new module within this environment. Observe that -// this will panic if the module already exists. Furthermore, the module -// identifier is always determined as the next available identifier. -func (p *Environment) RegisterModule(module string) trace.Context { - if p.HasModule(module) { - panic(fmt.Sprintf("module %s already exists", module)) +// Environment provides an interface into the global scope which can be used for +// simply resolving column identifiers. +type Environment interface { + // Module returns the module identifier for a given module, or panics if no + // such module exists. + Module(name string) *ModuleScope + // Column returns the column identifier for a given column in a given + // module, or panics if no such column exists. + Column(module string, name string) *ColumnBinding + // Convert a context from the high-level form into the lower level form + // suitable for HIR. + ToContext(from Context) tr.Context + // Construct a trace context from a given module and multiplier. + ContextFrom(module string, multiplier uint) tr.Context +} + +// GlobalEnvironment is a wrapper around a global scope. The point, really, is +// to signal the change between a global scope whose columns have yet to be +// allocated, from an environment whose columns are allocated. +type GlobalEnvironment struct { + scope *GlobalScope +} + +// NewGlobalEnvironment constructs a new global environment from a global scope +// by allocating appropriate identifiers to all columns. +func NewGlobalEnvironment(scope *GlobalScope) GlobalEnvironment { + columnId := uint(0) + // Allocate input columns first. + for _, m := range scope.modules { + for _, b := range m.bindings { + if binding, ok := b.(*ColumnBinding); ok && !binding.computed { + binding.AllocateId(columnId) + // Increase the column id + columnId++ + } + } } - // Update schema - mid := uint(len(p.modules)) - // Update cache - p.modules[module] = mid - // Done - return trace.NewContext(mid, 1) -} - -// RegisterColumn registers a new column within a given module. Observe that -// this will panic if the column already exists. Furthermore, the column -// identifier is always determined as the next available identifier. Hence, care -// must be taken when declaring columns to ensure they are allocated in the -// right order. -func (p *Environment) RegisterColumn(context trace.Context, column string, datatype schema.Type) uint { - if p.HasColumn(context.Module(), column) { - panic(fmt.Sprintf("column %d:%s already exists", context.Module(), column)) - } else if datatype == nil { - panic(fmt.Sprintf("column %d:%s cannot have nil type", context.Module(), column)) - } else if context.LengthMultiplier() == 0 { - panic(fmt.Sprintf("column %d:%s cannot have 0 length multiplier", context.Module(), column)) + // Allocate assignments second. + for _, m := range scope.modules { + for _, b := range m.bindings { + if binding, ok := b.(*ColumnBinding); ok && binding.computed { + binding.AllocateId(columnId) + // Increase the column id + columnId++ + } + } } - // Update cache - cid := uint(len(p.columns)) - cref := colRef{context.Module(), column} - p.columns[cref] = ColumnInfo{cid, context.LengthMultiplier(), datatype} // Done - return cid + return GlobalEnvironment{scope} } -// PreRegisterColumn makes an initial recording of the column and allocates a -// column identifier. A pre-registered column is a column who registration has -// not yet been finalised. More specifically the column is not considered -// finalised (i.e. ready for use) until FinaliseColumn is called. -func (p *Environment) PreRegisterColumn(module uint, column string) uint { - if p.HasColumn(module, column) { - panic(fmt.Sprintf("column %d:%s already exists", module, column)) - } - // Update cache - cid := uint(len(p.columns)) - cref := colRef{module, column} - p.columns[cref] = ColumnInfo{cid, 0, nil} - // Done - return cid -} - -// IsColumnFinalised determines whether a given column has been finalised yet, -// or not. Observe this will panic if the column has not at least been -// pre-registered. -func (p *Environment) IsColumnFinalised(module uint, column string) bool { - if !p.HasColumn(module, column) { - panic(fmt.Sprintf("column %d:%s does not exist", module, column)) - } - // - cref := colRef{module, column} - // Check information is finalised. - return p.columns[cref].IsFinalised() -} - -// FinaliseColumn finalises details of a columnm, specifically its length -// multiplier and type. After this has been called, IsColumnFinalised should -// return true for the column in question. Obserce this will panic if the -// column has not been preregistered, or if it is already finalised. -func (p *Environment) FinaliseColumn(context tr.Context, column string, datatype sc.Type) { - // Sanity check we are not finalising a column which has already been finalised. - if p.IsColumnFinalised(context.Module(), column) { - panic(fmt.Sprintf("Attempt to refinalise column %s", column)) - } - // - cref := colRef{context.Module(), column} - // Extract existing (incomplete) info - info := p.columns[cref] - // Update incomplete info - p.columns[cref] = ColumnInfo{info.cid, context.LengthMultiplier(), datatype} -} - -// LookupModule determines the module index for a given named module, or return -// false if no such module exists. -func (p *Environment) LookupModule(module string) (uint, bool) { - mid, ok := p.modules[module] - return mid, ok +// Module returns the identifier of the module with the given name. +func (p GlobalEnvironment) Module(name string) *ModuleScope { + return p.scope.Module(name) } -// LookupColumn determines the column index for a given named column in a given -// module, or return false if no such column exists. Observe this will return -// information even for columns which exist by are not yet finalised. -func (p *Environment) LookupColumn(module uint, column string) (ColumnInfo, bool) { - cref := colRef{module, column} - cinfo, ok := p.columns[cref] - - return cinfo, ok -} - -// Module determines the module index for a given module. This assumes the -// module exists, and will panic otherwise. -func (p *Environment) Module(module string) uint { - ctx, ok := p.LookupModule(module) - // Sanity check we found something - if !ok { - panic(fmt.Sprintf("unknown module %s", module)) - } - // Discard column index - return ctx -} - -// Column determines the column index for a given column declared in a given -// module. This assumes the column / module exist, and will panic otherwise. -// Furthermore, this assumes that the column is finalised and, otherwise, will -// panic. -func (p *Environment) Column(module uint, column string) ColumnInfo { - info, ok := p.LookupColumn(module, column) - // Sanity check we found something - if !ok { - panic(fmt.Sprintf("unknown column %s", column)) - } else if !info.IsFinalised() { - panic(fmt.Sprintf("column %s not yet finalised", column)) - } - // Done - return info +// Column returns the column identifier for a given column in a given +// module, or panics if no such column exists. +func (p GlobalEnvironment) Column(module string, name string) *ColumnBinding { + // Lookup the given binding, expecting that it is a column binding. If not, + // then this will fail. + return p.Module(module).Bind(nil, name, false).(*ColumnBinding) } -// HasModule checks whether a given module exists, or not. -func (p *Environment) HasModule(module string) bool { - _, ok := p.LookupModule(module) - // Discard column index - return ok +// ContextFrom constructs a trace context for a given module and length +// multiplier. +func (p GlobalEnvironment) ContextFrom(module string, multiplier uint) tr.Context { + return tr.NewContext(p.Module(module).mid, multiplier) } -// HasColumn checks whether a given module has a given column, or not. -func (p *Environment) HasColumn(module uint, column string) bool { - _, ok := p.LookupColumn(module, column) - // Discard column index - return ok +// ToContext constructs a trace context from a given corset context. +func (p GlobalEnvironment) ToContext(from Context) tr.Context { + return p.ContextFrom(from.Module(), from.LengthMultiplier()) } diff --git a/pkg/corset/parser.go b/pkg/corset/parser.go index 71e42fb6..af201d41 100644 --- a/pkg/corset/parser.go +++ b/pkg/corset/parser.go @@ -432,7 +432,8 @@ func (p *Parser) parseDefPermutation(elements []sexp.SExp) (*DefPermutation, *Sy } // targets := make([]*DefColumn, sexpTargets.Len()) - sources := make([]*DefPermutedColumn, sexpSources.Len()) + sources := make([]*DefName, sexpSources.Len()) + signs := make([]bool, sexpSources.Len()) // for i := 0; i < len(targets); i++ { // Parse target column @@ -440,43 +441,45 @@ func (p *Parser) parseDefPermutation(elements []sexp.SExp) (*DefPermutation, *Sy return nil, err } // Parse source column - if sources[i], err = p.parsePermutedColumnDeclaration(i == 0, sexpSources.Get(i)); err != nil { + if sources[i], signs[i], err = p.parsePermutedColumnDeclaration(i == 0, sexpSources.Get(i)); err != nil { return nil, err } } // - return &DefPermutation{targets, sources}, nil + return &DefPermutation{targets, sources, signs}, nil } -func (p *Parser) parsePermutedColumnDeclaration(signRequired bool, e sexp.SExp) (*DefPermutedColumn, *SyntaxError) { - var err *SyntaxError - // - defcolumn := &DefPermutedColumn{"", false} +func (p *Parser) parsePermutedColumnDeclaration(signRequired bool, e sexp.SExp) (*DefName, bool, *SyntaxError) { + var ( + err *SyntaxError + name DefName + sign bool + ) // Check whether extended declaration or not. if l := e.AsList(); l != nil { // Check at least the name provided. if len(l.Elements) == 0 { - return defcolumn, p.translator.SyntaxError(l, "empty permutation column") + return nil, false, p.translator.SyntaxError(l, "empty permutation column") } else if len(l.Elements) != 2 { - return defcolumn, p.translator.SyntaxError(l, "malformed permutation column") + return nil, false, p.translator.SyntaxError(l, "malformed permutation column") } else if l.Get(0).AsSymbol() == nil || l.Get(1).AsSymbol() == nil { - return defcolumn, p.translator.SyntaxError(l, "empty permutation column") + return nil, false, p.translator.SyntaxError(l, "empty permutation column") } // Parse sign - if defcolumn.Sign, err = p.parsePermutedColumnSign(l.Get(0).AsSymbol()); err != nil { - return nil, err + if sign, err = p.parsePermutedColumnSign(l.Get(0).AsSymbol()); err != nil { + return nil, false, err } // Parse column name - defcolumn.Name = l.Get(1).AsSymbol().Value + name.Name = l.Get(1).AsSymbol().Value } else if signRequired { - return nil, p.translator.SyntaxError(e, "missing sort direction") + return nil, false, p.translator.SyntaxError(e, "missing sort direction") } else { - defcolumn.Name = e.String(false) + name.Name = e.String(false) } // Update source mapping - p.mapSourceNode(e, defcolumn) + p.mapSourceNode(e, &name) // - return defcolumn, nil + return &name, sign, nil } func (p *Parser) parsePermutedColumnSign(sign *sexp.Symbol) (bool, *SyntaxError) { diff --git a/pkg/corset/resolver.go b/pkg/corset/resolver.go index e06ccc29..74180b05 100644 --- a/pkg/corset/resolver.go +++ b/pkg/corset/resolver.go @@ -5,7 +5,6 @@ import ( "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/sexp" - tr "github.com/consensys/go-corset/pkg/trace" ) // ResolveCircuit resolves all symbols declared and used within a circuit, @@ -14,63 +13,54 @@ import ( // a symbol (e.g. a column) is referred to which doesn't exist. Likewise, if // two modules or columns with identical names are declared in the same scope, // etc. -func ResolveCircuit(srcmap *sexp.SourceMaps[Node], circuit *Circuit) (*Environment, []SyntaxError) { - r := resolver{EmptyEnvironment(), srcmap} - // Allocate declared modules - r.resolveModules(circuit) +func ResolveCircuit(srcmap *sexp.SourceMaps[Node], circuit *Circuit) (*GlobalScope, []SyntaxError) { + // Construct top-level scope + scope := NewGlobalScope() + // Register the root module (which should always exist) + scope.DeclareModule("") + // Register other modules + for _, m := range circuit.Modules { + scope.DeclareModule(m.Name) + } + // Construct resolver + r := resolver{srcmap} // Allocate declared input columns - errs := r.resolveColumns(circuit) + errs := r.resolveColumns(scope, circuit) // Check expressions - errs = append(errs, r.resolveConstraints(circuit)...) + errs = append(errs, r.resolveConstraints(scope, circuit)...) // Done - return r.env, errs + return scope, errs } // Resolver packages up information necessary for resolving a circuit and // checking that everything makes sense. type resolver struct { - // Environment determines module and column indices, as needed for - // translating the various constructs found in a circuit. - env *Environment // Source maps nodes in the circuit back to the spans in their original // source files. This is needed when reporting syntax errors to generate // highlights of the relevant source line(s) in question. srcmap *sexp.SourceMaps[Node] } -// Process all module declarations, and allocating them into the environment. -// If any duplicates are found, one or more errors will be reported. Note: it -// is important that this traverses the modules in an identical order to the -// translator. This is to ensure that the relevant module identifiers line up. -func (r *resolver) resolveModules(circuit *Circuit) { - // Register the root module (which should always exist) - r.env.RegisterModule("") - // - for _, m := range circuit.Modules { - r.env.RegisterModule(m.Name) - } -} - // Process all input column or column assignment declarations. -func (r *resolver) resolveColumns(circuit *Circuit) []SyntaxError { +func (r *resolver) resolveColumns(scope *GlobalScope, circuit *Circuit) []SyntaxError { // Allocate input columns first. These must all be done before any // assignments are allocated, since the hir.Schema separates these out. - ierrs := r.resolveInputColumns(circuit) + ierrs := r.resolveInputColumns(scope, circuit) // Now we can resolve any assignments. - aerrs := r.resolveAssignments(circuit) + aerrs := r.resolveAssignments(scope, circuit) // return append(ierrs, aerrs...) } // Process all input column declarations. -func (r *resolver) resolveInputColumns(circuit *Circuit) []SyntaxError { +func (r *resolver) resolveInputColumns(scope *GlobalScope, circuit *Circuit) []SyntaxError { // Input columns must be allocated before assignemts, since the hir.Schema // separates these out. - errs := r.resolveInputColumnsInModule("", circuit.Declarations) + errs := r.resolveInputColumnsInModule(scope.Module(""), circuit.Declarations) // for _, m := range circuit.Modules { // Process all declarations in the module - merrs := r.resolveInputColumnsInModule(m.Name, m.Declarations) + merrs := r.resolveInputColumnsInModule(scope.Module(m.Name), m.Declarations) // Package up all errors errs = append(errs, merrs...) } @@ -79,21 +69,22 @@ func (r *resolver) resolveInputColumns(circuit *Circuit) []SyntaxError { } // Resolve all input columns in a given module. -func (r *resolver) resolveInputColumnsInModule(module string, decls []Declaration) []SyntaxError { +func (r *resolver) resolveInputColumnsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { errors := make([]SyntaxError, 0) - mid := r.env.Module(module) // for _, d := range decls { if dcols, ok := d.(*DefColumns); ok { // Found one. for _, col := range dcols.Columns { // Check whether column already exists - if _, ok := r.env.LookupColumn(mid, col.Name); ok { - err := r.srcmap.SyntaxError(col, fmt.Sprintf("column %s already declared in module %s", col.Name, module)) + if scope.Bind(nil, col.Name, false) != nil { + msg := fmt.Sprintf("symbol %s already declared in %s", col.Name, scope.EnclosingModule()) + err := r.srcmap.SyntaxError(col, msg) errors = append(errors, *err) } else { - context := tr.NewContext(mid, col.LengthMultiplier) - r.env.RegisterColumn(context, col.Name, col.DataType) + // Declare new column + scope.Declare(col.Name, false, NewColumnBinding(scope.EnclosingModule(), + false, col.LengthMultiplier, col.DataType)) } } } @@ -105,14 +96,14 @@ func (r *resolver) resolveInputColumnsInModule(module string, decls []Declaratio // Process all assignment column declarations. These are more complex than for // input columns, since there can be dependencies between them. Thus, we cannot // simply resolve them in one linear scan. -func (r *resolver) resolveAssignments(circuit *Circuit) []SyntaxError { +func (r *resolver) resolveAssignments(scope *GlobalScope, circuit *Circuit) []SyntaxError { // Input columns must be allocated before assignemts, since the hir.Schema // separates these out. - errs := r.resolveAssignmentsInModule("", circuit.Declarations) + errs := r.resolveAssignmentsInModule(scope.Module(""), circuit.Declarations) // for _, m := range circuit.Modules { // Process all declarations in the module - merrs := r.resolveAssignmentsInModule(m.Name, m.Declarations) + merrs := r.resolveAssignmentsInModule(scope.Module(m.Name), m.Declarations) // Package up all errors errs = append(errs, merrs...) } @@ -124,16 +115,16 @@ func (r *resolver) resolveAssignments(circuit *Circuit) []SyntaxError { // assignments can depend on the declaration of other columns. Hence, we have // to process all columns before we can sure that they are all declared // correctly. -func (r *resolver) resolveAssignmentsInModule(module string, decls []Declaration) []SyntaxError { - if errors := r.initialiseAssignmentsInModule(module, decls); len(errors) > 0 { +func (r *resolver) resolveAssignmentsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { + if errors := r.initialiseAssignmentsInModule(scope, decls); len(errors) > 0 { return errors } // Check assignments - if errors := r.checkAssignmentsInModule(module, decls); len(errors) > 0 { + if errors := r.checkAssignmentsInModule(scope, decls); len(errors) > 0 { return errors } // Iterate until all columns finalised - return r.finaliseAssignmentsInModule(module, decls) + return r.finaliseAssignmentsInModule(scope, decls) } // Initialise the column allocation from the available declarations, whilst @@ -141,27 +132,27 @@ func (r *resolver) resolveAssignmentsInModule(module string, decls []Declaration // the initial assignment is incomplete because information about dependent // columns may not be available. So, the goal of the subsequent phase is to // flesh out this missing information. -func (r *resolver) initialiseAssignmentsInModule(module string, decls []Declaration) []SyntaxError { +func (r *resolver) initialiseAssignmentsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { + module := scope.EnclosingModule() errors := make([]SyntaxError, 0) - mid := r.env.Module(module) // for _, d := range decls { if col, ok := d.(*DefInterleaved); ok { - if _, ok := r.env.LookupColumn(mid, col.Target); ok { - err := r.srcmap.SyntaxError(col, fmt.Sprintf("column %s already declared in module %s", col.Target, module)) + if binding := scope.Bind(nil, col.Target, false); binding != nil { + err := r.srcmap.SyntaxError(col, fmt.Sprintf("symbol %s already declared in %s", col.Target, module)) errors = append(errors, *err) } else { // Register incomplete (assignment) column. - r.env.PreRegisterColumn(mid, col.Target) + scope.Declare(col.Target, false, NewColumnBinding(module, true, 0, nil)) } } else if col, ok := d.(*DefPermutation); ok { for _, c := range col.Targets { - if _, ok := r.env.LookupColumn(mid, c.Name); ok { - err := r.srcmap.SyntaxError(col, fmt.Sprintf("column %s already declared in module %s", c.Name, module)) + if binding := scope.Bind(nil, c.Name, false); binding != nil { + err := r.srcmap.SyntaxError(col, fmt.Sprintf("symbol %s already declared in %s", c.Name, module)) errors = append(errors, *err) } else { // Register incomplete (assignment) column. - r.env.PreRegisterColumn(mid, c.Name) + scope.Declare(c.Name, false, NewColumnBinding(scope.EnclosingModule(), true, 0, nil)) } } } @@ -170,20 +161,19 @@ func (r *resolver) initialiseAssignmentsInModule(module string, decls []Declarat return errors } -func (r *resolver) checkAssignmentsInModule(module string, decls []Declaration) []SyntaxError { +func (r *resolver) checkAssignmentsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { errors := make([]SyntaxError, 0) - mid := r.env.Module(module) // for _, d := range decls { if col, ok := d.(*DefInterleaved); ok { for _, c := range col.Sources { - if !r.env.HasColumn(mid, c.Name) { + if scope.Bind(nil, c.Name, false) == nil { errors = append(errors, *r.srcmap.SyntaxError(c, "unknown source column")) } } } else if col, ok := d.(*DefPermutation); ok { for _, c := range col.Sources { - if !r.env.HasColumn(mid, c.Name) { + if scope.Bind(nil, c.Name, false) == nil { errors = append(errors, *r.srcmap.SyntaxError(c, "unknown source column")) } } @@ -194,8 +184,7 @@ func (r *resolver) checkAssignmentsInModule(module string, decls []Declaration) } // Iterate the column allocation to a fix point by iteratively fleshing out column information. -func (r *resolver) finaliseAssignmentsInModule(module string, decls []Declaration) []SyntaxError { - mid := r.env.Module(module) +func (r *resolver) finaliseAssignmentsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { // Changed indicates whether or not a new assignment was finalised during a // given iteration. This is important to know since, if the assignment is // not complete and we didn't finalise any more assignments --- then, we've @@ -219,9 +208,9 @@ func (r *resolver) finaliseAssignmentsInModule(module string, decls []Declaratio for _, d := range decls { if col, ok := d.(*DefInterleaved); ok { // Check whether dependencies are resolved or not. - if col.CanFinalise(mid, r.env) { + if r.columnsAreFinalised(scope, col.Sources) { // Finalise assignment and handle any errors - errs := r.finaliseInterleavedAssignment(mid, col) + errs := r.finaliseInterleavedAssignment(scope, col) errors = append(errors, errs...) // Record that a new assignment is available. changed = changed || len(errs) == 0 @@ -231,9 +220,9 @@ func (r *resolver) finaliseAssignmentsInModule(module string, decls []Declaratio } } else if col, ok := d.(*DefPermutation); ok { // Check whether dependencies are resolved or not. - if col.CanFinalise(mid, r.env) { + if r.columnsAreFinalised(scope, col.Sources) { // Finalise assignment and handle any errors - errs := r.finalisePermutationAssignment(mid, col) + errs := r.finalisePermutationAssignment(scope, col) errors = append(errors, errs...) // Record that a new assignment is available. changed = changed || len(errs) == 0 @@ -259,12 +248,29 @@ func (r *resolver) finaliseAssignmentsInModule(module string, decls []Declaratio return nil } +// Check that a given set of source columns have been finalised. This is +// important, since we cannot finalise an assignment until all of its +// dependencies have themselves been finalised. +func (r *resolver) columnsAreFinalised(scope *ModuleScope, columns []*DefName) bool { + for _, col := range columns { + // Look up information + info := scope.Bind(nil, col.Name, false).(*ColumnBinding) + // Check whether its finalised + if info.multiplier == 0 { + // Nope, not yet. + return false + } + } + // + return true +} + // Finalise an interleaving assignment. Since the assignment would already been // initialised, all we need to do is determine the appropriate type and length // multiplier for the interleaved column. This can still result in an error, // for example, if the multipliers between interleaved columns are incompatible, // etc. -func (r *resolver) finaliseInterleavedAssignment(module uint, decl *DefInterleaved) []SyntaxError { +func (r *resolver) finaliseInterleavedAssignment(scope *ModuleScope, decl *DefInterleaved) []SyntaxError { var ( // Length multiplier being determined length_multiplier uint @@ -276,7 +282,8 @@ func (r *resolver) finaliseInterleavedAssignment(module uint, decl *DefInterleav // Determine type and length multiplier for i, source := range decl.Sources { // Lookup info of column being interleaved. - info := r.env.Column(module, source.Name) + info := scope.Bind(nil, source.Name, false).(*ColumnBinding) + // if i == 0 { length_multiplier = info.multiplier datatype = info.datatype @@ -292,10 +299,11 @@ func (r *resolver) finaliseInterleavedAssignment(module uint, decl *DefInterleav if len(errors) == 0 { // Determine actual length multiplier length_multiplier *= uint(len(decl.Sources)) - // Construct context for this column - context := tr.NewContext(module, length_multiplier) - // Finalise column registration - r.env.FinaliseColumn(context, decl.Target, datatype) + // Lookup existing declaration + info := scope.Bind(nil, decl.Target, false).(*ColumnBinding) + // Update with completed information + info.multiplier = length_multiplier + info.datatype = datatype } // Done return errors @@ -303,7 +311,7 @@ func (r *resolver) finaliseInterleavedAssignment(module uint, decl *DefInterleav // Finalise a permutation assignment. Since the assignment would already been // initialised, this is actually quite easy to do. -func (r *resolver) finalisePermutationAssignment(module uint, decl *DefPermutation) []SyntaxError { +func (r *resolver) finalisePermutationAssignment(scope *ModuleScope, decl *DefPermutation) []SyntaxError { var ( multiplier uint = 0 errors []SyntaxError @@ -311,19 +319,22 @@ func (r *resolver) finalisePermutationAssignment(module uint, decl *DefPermutati // Finalise each column in turn for i := 0; i < len(decl.Sources); i++ { ith := decl.Sources[i] - src := r.env.Column(module, ith.Name) + // Lookup source of column being permuted + source := scope.Bind(nil, ith.Name, false).(*ColumnBinding) // Sanity check length multiplier - if i == 0 && src.datatype.AsUint() == nil { + if i == 0 && source.datatype.AsUint() == nil { errors = append(errors, *r.srcmap.SyntaxError(ith, "fixed-width type required")) } else if i == 0 { - multiplier = src.multiplier - } else if multiplier != src.multiplier { + multiplier = source.multiplier + } else if multiplier != source.multiplier { // Problem errors = append(errors, *r.srcmap.SyntaxError(ith, "incompatible length multiplier")) } - // All good, finalise column - context := tr.NewContext(module, src.multiplier) - r.env.FinaliseColumn(context, decl.Targets[i].Name, src.datatype) + // All good, finalise target column + target := scope.Bind(nil, decl.Targets[i].Name, false).(*ColumnBinding) + // Update with completed information + target.multiplier = source.multiplier + target.datatype = source.datatype } // Done return errors @@ -336,14 +347,12 @@ func (r *resolver) finalisePermutationAssignment(module uint, decl *DefPermutati // pass is to: firstly, check that every variable refers to something which was // declared; secondly, to determine what each variable represents (i.e. column // access, a constant, etc). -func (r *resolver) resolveConstraints(circuit *Circuit) []SyntaxError { - root := r.buildModuleScope("", circuit.Declarations) - errs := r.resolveConstraintsInModule(root, circuit.Declarations) +func (r *resolver) resolveConstraints(scope *GlobalScope, circuit *Circuit) []SyntaxError { + errs := r.resolveConstraintsInModule(scope.Module(""), circuit.Declarations) // for _, m := range circuit.Modules { - module := r.buildModuleScope(m.Name, circuit.Declarations) // Process all declarations in the module - merrs := r.resolveConstraintsInModule(module, m.Declarations) + merrs := r.resolveConstraintsInModule(scope.Module(m.Name), m.Declarations) // Package up all errors errs = append(errs, merrs...) } @@ -351,22 +360,6 @@ func (r *resolver) resolveConstraints(circuit *Circuit) []SyntaxError { return errs } -func (r *resolver) buildModuleScope(name string, decls []Declaration) Scope { - var ( - scope *ModuleScope = r.env.NewModuleScope(name) - ) - // - for _, d := range decls { - // Look for defcolumns decalarations only - if c, ok := d.(*DefFun); ok { - // TODO: sanity check if function already declared. - scope.DeclareFunction(c.Name.Name, uint(len(c.Parameters)), c.Body) - } - } - // - return scope -} - // Helper for resolve constraints which considers those constraints declared in // a particular module. func (r *resolver) resolveConstraintsInModule(enclosing Scope, decls []Declaration) []SyntaxError { @@ -543,22 +536,18 @@ func (r *resolver) resolveInvokeInModule(scope LocalScope, expr *Invoke) []Synta // permitted in a global context. func (r *resolver) resolveVariableInModule(scope LocalScope, expr *VariableAccess) []SyntaxError { - // Will identify module of variable - //var module string = scope.EnclosingModule() - var mid *uint // Check whether this is a qualified access, or not. if !scope.IsGlobal() && expr.Module != nil { return r.srcmap.SyntaxErrors(expr, "qualified access not permitted here") } else if expr.Module != nil && !scope.HasModule(*expr.Module) { return r.srcmap.SyntaxErrors(expr, fmt.Sprintf("unknown module %s", *expr.Module)) - } else if expr.Module != nil { - tmp := scope.Module(*expr.Module) - mid = &tmp } - // Attempt resolve as a column access in enclosing module - if expr.Binding = scope.Bind(mid, expr.Name, false); expr.Binding != nil { + // Attempt resolve this variable access, noting that it definitely does not + // refer to a function. + if expr.Binding = scope.Bind(expr.Module, expr.Name, false); expr.Binding != nil { // Update context - if !scope.FixContext(expr.Binding.Context()) { + binding, ok := expr.Binding.(*ColumnBinding) + if ok && !scope.FixContext(binding.Context()) { return r.srcmap.SyntaxErrors(expr, "conflicting context") } // Done diff --git a/pkg/corset/scope.go b/pkg/corset/scope.go index 79344907..984f6ea4 100644 --- a/pkg/corset/scope.go +++ b/pkg/corset/scope.go @@ -13,17 +13,83 @@ import ( type Scope interface { // Get the name of the enclosing module. This is generally useful for // reporting errors. - EnclosingModule() uint + EnclosingModule() string // HasModule checks whether a given module exists, or not. HasModule(string) bool - // Lookup the identifier for a given module. This assumes that the module - // exists, and will panic otherwise. - Module(string) uint // Lookup a given variable being referenced with an optional module // specifier. This variable could correspond to a column, a function, a // parameter, or a local variable. Furthermore, the returned binding will // be nil if this variable does not exist. - Bind(*uint, string, bool) Binding + Bind(*string, string, bool) Binding +} + +// ============================================================================= +// Global Scope +// ============================================================================= + +// GlobalScope represents the top-level scope in a Corset file, and is used to +// glue the scopes for modules together. For example, it enables one module to +// lookup columns in another. +type GlobalScope struct { + // Top-level mapping of modules to their scopes. + ids map[string]uint + // List of modules in declaration order + modules []ModuleScope +} + +// NewGlobalScope constructs an empty global scope. +func NewGlobalScope() *GlobalScope { + return &GlobalScope{make(map[string]uint), make([]ModuleScope, 0)} +} + +// DeclareModule declares an initialises a new module within this global scope. +// If a module by the same name already exists, then this will panic. +func (p *GlobalScope) DeclareModule(module string) { + // Sanity check module doesn't already exist + if _, ok := p.ids[module]; ok { + panic(fmt.Sprintf("duplicate module %s declared", module)) + } + // Register module + mid := uint(len(p.ids)) + scope := ModuleScope{module, mid, make(map[BindingId]uint), make([]Binding, 0), p} + p.modules = append(p.modules, scope) + p.ids[module] = mid +} + +// EnclosingModule returns the name of the enclosing module. For a global +// scope, this has no meaning. +func (p *GlobalScope) EnclosingModule() string { + panic("unreachable") +} + +// HasModule checks whether a given module exists, or not. +func (p *GlobalScope) HasModule(module string) bool { + // Attempt to lookup the module + _, ok := p.ids[module] + // Return what we found + return ok +} + +// Bind looks up a given variable being referenced within a given module. For a +// root context, this is either a column, an alias or a function declaration. +func (p *GlobalScope) Bind(module *string, name string, fn bool) Binding { + if module == nil { + panic("cannot bind unqualified symbol in the global scope") + } + // + return p.Module(*module).Bind(nil, name, fn) +} + +// Module returns the identifier of the module with the given name. +func (p *GlobalScope) Module(name string) *ModuleScope { + mid := p.ids[name] + return &p.modules[mid] +} + +// ToEnvironment converts this global scope into a concrete environment by +// allocating all columns within this scope. +func (p *GlobalScope) ToEnvironment() Environment { + return NewGlobalEnvironment(p) } // ============================================================================= @@ -32,59 +98,57 @@ type Scope interface { // ModuleScope represents the scope characterised by a module. type ModuleScope struct { - // Module ID - module uint - // Provides access to global environment - environment *Environment - // Maps function names to their contents. - functions map[string]FunctionBinding + // Module name + module string + // Module identifier + mid uint + // Mapping from binding identifiers to indices within the bindings array. + ids map[BindingId]uint + // The set of bindings in the order of declaration. + bindings []Binding + // Enclosing global scope + enclosing Scope } // EnclosingModule returns the name of the enclosing module. This is generally // useful for reporting errors. -func (p *ModuleScope) EnclosingModule() uint { +func (p *ModuleScope) EnclosingModule() string { return p.module } // HasModule checks whether a given module exists, or not. func (p *ModuleScope) HasModule(module string) bool { - return p.environment.HasModule(module) -} - -// Module determines the module index for a given module. This assumes the -// module exists, and will panic otherwise. -func (p *ModuleScope) Module(module string) uint { - return p.environment.Module(module) + return p.enclosing.HasModule(module) } // Bind looks up a given variable being referenced within a given module. For a // root context, this is either a column, an alias or a function declaration. -func (p *ModuleScope) Bind(module *uint, name string, fn bool) Binding { - var mid uint +func (p *ModuleScope) Bind(module *string, name string, fn bool) Binding { // Determine module for this lookup. if module != nil { - mid = *module - } else { - mid = p.module + // non-local lookup + return p.enclosing.Bind(module, name, fn) } - // Lookup function - if binding, ok := p.functions[name]; ok && module == nil { - return &binding - } else if info, ok := p.environment.LookupColumn(mid, name); ok && !fn { - ctx := tr.NewContext(mid, info.multiplier) - return &ColumnBinding{ctx, info.cid} + // construct binding identifier + if bid, ok := p.ids[BindingId{name, fn}]; ok { + return p.bindings[bid] } - // error + // failed return nil } -// DeclareFunction declares a given function within this module scope. -func (p *ModuleScope) DeclareFunction(name string, arity uint, body Expr) { - if _, ok := p.functions[name]; ok { - panic(fmt.Sprintf("attempt to redeclared function \"%s\"/%d", name, arity)) +// Declare declares a given binding within this module scope. +func (p *ModuleScope) Declare(name string, fn bool, binding Binding) { + // construct binding identifier + bid := BindingId{name, fn} + // Sanity check not already declared + if _, ok := p.ids[bid]; ok { + panic(fmt.Sprintf("attempt to redeclare binding for \"%s\"", name)) } - // - p.functions[name] = FunctionBinding{arity, body} + // Done + id := uint(len(p.bindings)) + p.bindings = append(p.bindings, binding) + p.ids[bid] = id } // ============================================================================= @@ -100,7 +164,7 @@ type LocalScope struct { // Represents the enclosing scope enclosing Scope // Context for this scope - context *tr.Context + context *Context // Maps inputs parameters to the declaration index. locals map[string]uint } @@ -110,7 +174,7 @@ type LocalScope struct { // also be "global" in the sense that accessing symbols from other modules is // permitted. func NewLocalScope(enclosing Scope, global bool) LocalScope { - context := tr.VoidContext() + context := tr.VoidContext[string]() locals := make(map[string]uint) // return LocalScope{global, enclosing, &context, locals} @@ -135,13 +199,13 @@ func (p LocalScope) IsGlobal() bool { // EnclosingModule returns the name of the enclosing module. This is generally // useful for reporting errors. -func (p LocalScope) EnclosingModule() uint { +func (p LocalScope) EnclosingModule() string { return p.enclosing.EnclosingModule() } // FixContext fixes the context for this scope. Since every scope requires // exactly one context, this fails if we fix it to incompatible contexts. -func (p LocalScope) FixContext(context tr.Context) bool { +func (p LocalScope) FixContext(context Context) bool { // Join contexts together *p.context = p.context.Join(context) // Check they were compatible @@ -153,15 +217,9 @@ func (p LocalScope) HasModule(module string) bool { return p.enclosing.HasModule(module) } -// Module determines the module index for a given module. This assumes the -// module exists, and will panic otherwise. -func (p LocalScope) Module(module string) uint { - return p.enclosing.Module(module) -} - // Bind looks up a given variable or function being referenced either within the // enclosing scope (module==nil) or within a specified module. -func (p LocalScope) Bind(module *uint, name string, fn bool) Binding { +func (p LocalScope) Bind(module *string, name string, fn bool) Binding { // Check whether this is a local variable access. if id, ok := p.locals[name]; ok && !fn && module == nil { // Yes, this is a local variable access. diff --git a/pkg/corset/translator.go b/pkg/corset/translator.go index e0493c5b..194c6b0f 100644 --- a/pkg/corset/translator.go +++ b/pkg/corset/translator.go @@ -16,7 +16,7 @@ import ( // easily. Thus, whilst syntax errors can be returned here, this should never // happen. The mechanism is supported, however, to simplify development of new // features, etc. -func TranslateCircuit(env *Environment, srcmap *sexp.SourceMaps[Node], circuit *Circuit) (*hir.Schema, []SyntaxError) { +func TranslateCircuit(env Environment, srcmap *sexp.SourceMaps[Node], circuit *Circuit) (*hir.Schema, []SyntaxError) { t := translator{env, srcmap, hir.EmptySchema()} // Allocate all modules into schema t.translateModules(circuit) @@ -35,9 +35,9 @@ func TranslateCircuit(env *Environment, srcmap *sexp.SourceMaps[Node], circuit * // Translator packages up information necessary for translating a circuit into // the schema form required for the HIR level. type translator struct { - // Environment determines module and column indices, as needed for - // translating the various constructs found in a circuit. - env *Environment + // Environment is needed for determining the identifiers for modules and + // columns. + env Environment // Source maps nodes in the circuit back to the spans in their original // source files. This is needed when reporting syntax errors to generate // highlights of the relevant source line(s) in question. @@ -52,7 +52,7 @@ func (t *translator) translateModules(circuit *Circuit) { // Add nested modules for _, m := range circuit.Modules { mid := t.schema.AddModule(m.Name) - aid := t.env.Module(m.Name) + aid := t.env.Module(m.Name).mid // Sanity check everything lines up. if aid != mid { panic(fmt.Sprintf("Invalid module identifier: %d vs %d", mid, aid)) @@ -75,12 +75,10 @@ func (t *translator) translateInputColumns(circuit *Circuit) []SyntaxError { // Translate all input column declarations occurring in a given module within the circuit. func (t *translator) translateInputColumnsInModule(module string, decls []Declaration) []SyntaxError { var errors []SyntaxError - // Construct context for enclosing module - context := t.env.Module(module) // for _, d := range decls { if dcols, ok := d.(*DefColumns); ok { - errs := t.translateDefColumns(dcols, context) + errs := t.translateDefColumns(dcols, module) errors = append(errors, errs...) } } @@ -89,11 +87,11 @@ func (t *translator) translateInputColumnsInModule(module string, decls []Declar } // Translate a "defcolumns" declaration. -func (t *translator) translateDefColumns(decl *DefColumns, module uint) []SyntaxError { +func (t *translator) translateDefColumns(decl *DefColumns, module string) []SyntaxError { var errors []SyntaxError // Add each column to schema for _, c := range decl.Columns { - context := tr.NewContext(module, c.LengthMultiplier) + context := t.env.ContextFrom(module, c.LengthMultiplier) cid := t.schema.AddDataColumn(context, c.Name, c.DataType) // Prove type (if requested) if c.MustProve { @@ -101,7 +99,7 @@ func (t *translator) translateDefColumns(decl *DefColumns, module uint) []Syntax t.schema.AddRangeConstraint(c.Name, context, &hir.ColumnAccess{Column: cid, Shift: 0}, bound) } // Sanity check column identifier - if info := t.env.Column(module, c.Name); info.cid != cid { + if info := t.env.Column(module, c.Name); info.ColumnId() != cid { errors = append(errors, *t.srcmap.SyntaxError(c, "invalid column identifier")) } } @@ -125,11 +123,9 @@ func (t *translator) translateAssignmentsAndConstraints(circuit *Circuit) []Synt // the circuit. func (t *translator) translateAssignmentsAndConstraintsInModule(module string, decls []Declaration) []SyntaxError { var errors []SyntaxError - // Construct context for enclosing module - context := t.env.Module(module) // for _, d := range decls { - errs := t.translateDeclaration(d, context) + errs := t.translateDeclaration(d, module) errors = append(errors, errs...) } // Done @@ -138,7 +134,7 @@ func (t *translator) translateAssignmentsAndConstraintsInModule(module string, d // Translate an assignment or constraint declarartion which occurs within a // given module. -func (t *translator) translateDeclaration(decl Declaration, module uint) []SyntaxError { +func (t *translator) translateDeclaration(decl Declaration, module string) []SyntaxError { var errors []SyntaxError // if _, ok := decl.(*DefColumns); ok { @@ -167,7 +163,7 @@ func (t *translator) translateDeclaration(decl Declaration, module uint) []Synta } // Translate a "defconstraint" declaration. -func (t *translator) translateDefConstraint(decl *DefConstraint, module uint) []SyntaxError { +func (t *translator) translateDefConstraint(decl *DefConstraint, module string) []SyntaxError { // Translate constraint body constraint, errors := t.translateExpressionInModule(decl.Constraint, module) // Translate (optional) guard @@ -182,7 +178,7 @@ func (t *translator) translateDefConstraint(decl *DefConstraint, module uint) [] if len(errors) == 0 { context := constraint.Context(t.schema) // - if context.Module() != module { + if context.Module() != t.env.Module(module).mid { return t.srcmap.SyntaxErrors(decl, "invalid context inferred") } // Add translated constraint @@ -193,7 +189,9 @@ func (t *translator) translateDefConstraint(decl *DefConstraint, module uint) [] } // Translate a "deflookup" declaration. -func (t *translator) translateDefLookup(decl *DefLookup, module uint) []SyntaxError { +// +//nolint:staticcheck +func (t *translator) translateDefLookup(decl *DefLookup, module string) []SyntaxError { // Translate source expressions sources, src_errs := t.translateUnitExpressionsInModule(decl.Sources, module) targets, tgt_errs := t.translateUnitExpressionsInModule(decl.Targets, module) @@ -201,8 +199,8 @@ func (t *translator) translateDefLookup(decl *DefLookup, module uint) []SyntaxEr errors := append(src_errs, tgt_errs...) // if len(errors) == 0 { - src_context := ContextOfExpressions(decl.Sources) - target_context := ContextOfExpressions(decl.Targets) + src_context := t.env.ToContext(ContextOfExpressions(decl.Sources)) + target_context := t.env.ToContext(ContextOfExpressions(decl.Targets)) // Add translated constraint t.schema.AddLookupConstraint(decl.Handle, src_context, target_context, sources, targets) } @@ -211,12 +209,12 @@ func (t *translator) translateDefLookup(decl *DefLookup, module uint) []SyntaxEr } // Translate a "definrange" declaration. -func (t *translator) translateDefInRange(decl *DefInRange, module uint) []SyntaxError { +func (t *translator) translateDefInRange(decl *DefInRange, module string) []SyntaxError { // Translate constraint body expr, errors := t.translateExpressionInModule(decl.Expr, module) // if len(errors) == 0 { - context := tr.NewContext(module, 1) + context := t.env.ContextFrom(module, 1) // Add translated constraint t.schema.AddRangeConstraint("", context, expr, decl.Bound) } @@ -225,7 +223,7 @@ func (t *translator) translateDefInRange(decl *DefInRange, module uint) []Syntax } // Translate a "definterleaved" declaration. -func (t *translator) translateDefInterleaved(decl *DefInterleaved, module uint) []SyntaxError { +func (t *translator) translateDefInterleaved(decl *DefInterleaved, module string) []SyntaxError { var errors []SyntaxError // sources := make([]uint, len(decl.Sources)) @@ -233,14 +231,14 @@ func (t *translator) translateDefInterleaved(decl *DefInterleaved, module uint) info := t.env.Column(module, decl.Target) // Determine source column identifiers for i, source := range decl.Sources { - sources[i] = t.env.Column(module, source.Name).cid + sources[i] = t.env.Column(module, source.Name).ColumnId() } // Construct context for this assignment - context := tr.NewContext(module, info.multiplier) + context := t.env.ContextFrom(module, info.multiplier) // Register assignment cid := t.schema.AddAssignment(assignment.NewInterleaving(context, decl.Target, sources, info.datatype)) // Sanity check column identifiers align. - if cid != info.cid { + if cid != info.ColumnId() { errors = append(errors, *t.srcmap.SyntaxError(decl, "invalid column identifier")) } // Done @@ -248,7 +246,7 @@ func (t *translator) translateDefInterleaved(decl *DefInterleaved, module uint) } // Translate a "defpermutation" declaration. -func (t *translator) translateDefPermutation(decl *DefPermutation, module uint) []SyntaxError { +func (t *translator) translateDefPermutation(decl *DefPermutation, module string) []SyntaxError { var ( errors []SyntaxError context tr.Context @@ -261,13 +259,13 @@ func (t *translator) translateDefPermutation(decl *DefPermutation, module uint) // for i := 0; i < len(decl.Sources); i++ { target := t.env.Column(module, decl.Targets[i].Name) - context = tr.NewContext(module, target.multiplier) + context = t.env.ContextFrom(module, target.multiplier) targets[i] = sc.NewColumn(context, decl.Targets[i].Name, target.datatype) - sources[i] = t.env.Column(module, decl.Sources[i].Name).cid - signs[i] = decl.Sources[i].Sign + sources[i] = t.env.Column(module, decl.Sources[i].Name).ColumnId() + signs[i] = decl.Signs[i] // Record first CID if i == 0 { - firstCid = target.cid + firstCid = target.ColumnId() } } // Add the assignment and check the first identifier. @@ -281,12 +279,12 @@ func (t *translator) translateDefPermutation(decl *DefPermutation, module uint) } // Translate a "defproperty" declaration. -func (t *translator) translateDefProperty(decl *DefProperty, module uint) []SyntaxError { +func (t *translator) translateDefProperty(decl *DefProperty, module string) []SyntaxError { // Translate constraint body assertion, errors := t.translateExpressionInModule(decl.Assertion, module) // if len(errors) == 0 { - context := tr.NewContext(module, 1) + context := t.env.ContextFrom(module, 1) // Add translated constraint t.schema.AddPropertyAssertion(decl.Handle, context, assertion) } @@ -297,7 +295,7 @@ func (t *translator) translateDefProperty(decl *DefProperty, module uint) []Synt // Translate an optional expression in a given context. That is an expression // which maybe nil (i.e. doesn't exist). In such case, nil is returned (i.e. // without any errors). -func (t *translator) translateOptionalExpressionInModule(expr Expr, module uint) (hir.Expr, []SyntaxError) { +func (t *translator) translateOptionalExpressionInModule(expr Expr, module string) (hir.Expr, []SyntaxError) { if expr != nil { return t.translateExpressionInModule(expr, module) } @@ -308,7 +306,7 @@ func (t *translator) translateOptionalExpressionInModule(expr Expr, module uint) // Translate an optional expression in a given context. That is an expression // which maybe nil (i.e. doesn't exist). In such case, nil is returned (i.e. // without any errors). -func (t *translator) translateUnitExpressionsInModule(exprs []Expr, module uint) ([]hir.UnitExpr, []SyntaxError) { +func (t *translator) translateUnitExpressionsInModule(exprs []Expr, module string) ([]hir.UnitExpr, []SyntaxError) { errors := []SyntaxError{} hirExprs := make([]hir.UnitExpr, len(exprs)) // Iterate each expression in turn @@ -325,7 +323,7 @@ func (t *translator) translateUnitExpressionsInModule(exprs []Expr, module uint) } // Translate a sequence of zero or more expressions enclosed in a given module. -func (t *translator) translateExpressionsInModule(exprs []Expr, module uint) ([]hir.Expr, []SyntaxError) { +func (t *translator) translateExpressionsInModule(exprs []Expr, module string) ([]hir.Expr, []SyntaxError) { errors := []SyntaxError{} hirExprs := make([]hir.Expr, len(exprs)) // Iterate each expression in turn @@ -343,7 +341,7 @@ func (t *translator) translateExpressionsInModule(exprs []Expr, module uint) ([] // Translate an expression situated in a given context. The context is // necessary to resolve unqualified names (e.g. for column access, function // invocations, etc). -func (t *translator) translateExpressionInModule(expr Expr, module uint) (hir.Expr, []SyntaxError) { +func (t *translator) translateExpressionInModule(expr Expr, module string) (hir.Expr, []SyntaxError) { if e, ok := expr.(*Constant); ok { return &hir.Constant{Val: e.Val}, nil } else if v, ok := expr.(*Add); ok { @@ -379,7 +377,10 @@ func (t *translator) translateExpressionInModule(expr Expr, module uint) (hir.Ex return &hir.Sub{Args: args}, errs } else if e, ok := expr.(*VariableAccess); ok { if binding, ok := e.Binding.(*ColumnBinding); ok { - return &hir.ColumnAccess{Column: binding.ColumnID(), Shift: e.Shift}, nil + // Lookup column binding + cinfo := t.env.Column(binding.module, e.Name) + // Done + return &hir.ColumnAccess{Column: cinfo.ColumnId(), Shift: e.Shift}, nil } // error return nil, t.srcmap.SyntaxErrors(expr, "unbound variable") diff --git a/pkg/hir/expr.go b/pkg/hir/expr.go index 279ac932..dd6e61a6 100644 --- a/pkg/hir/expr.go +++ b/pkg/hir/expr.go @@ -265,7 +265,7 @@ func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND } // Context determines the evaluation context (i.e. enclosing module) for this // expression. func (p *Constant) Context(schema sc.Schema) trace.Context { - return trace.VoidContext() + return trace.VoidContext[uint]() } // RequiredColumns returns the set of columns on which this term depends. diff --git a/pkg/mir/expr.go b/pkg/mir/expr.go index 2aa8e4da..896ccf2a 100644 --- a/pkg/mir/expr.go +++ b/pkg/mir/expr.go @@ -170,7 +170,7 @@ func (p *Constant) Bounds() util.Bounds { return util.EMPTY_BOUND } // Context determines the evaluation context (i.e. enclosing module) for this // expression. func (p *Constant) Context(schema sc.Schema) trace.Context { - return trace.VoidContext() + return trace.VoidContext[uint]() } // RequiredColumns returns the set of columns on which this term depends. diff --git a/pkg/schema/schemas.go b/pkg/schema/schemas.go index 82d76bad..d627f1f0 100644 --- a/pkg/schema/schemas.go +++ b/pkg/schema/schemas.go @@ -42,7 +42,7 @@ func QualifiedName(schema Schema, column uint) string { // context is returned. Otherwise, the common context to all expressions is // returned. func JoinContexts[E Contextual](args []E, schema Schema) tr.Context { - ctx := tr.VoidContext() + ctx := tr.VoidContext[uint]() // for _, e := range args { ctx = ctx.Join(e.Context(schema)) @@ -58,7 +58,7 @@ func JoinContexts[E Contextual](args []E, schema Schema) tr.Context { // conflicting context is returned. Otherwise, the common context to all // columns is returned. func ContextOfColumns(cols []uint, schema Schema) tr.Context { - ctx := tr.VoidContext() + ctx := tr.VoidContext[uint]() // for i := 0; i < len(cols); i++ { col := schema.Columns().Nth(cols[i]) diff --git a/pkg/trace/context.go b/pkg/trace/context.go index 8fc90513..611cfcc1 100644 --- a/pkg/trace/context.go +++ b/pkg/trace/context.go @@ -5,7 +5,10 @@ import ( "math" ) -// Context represents the evaluation context in which an expression can be +// Context is an instance of RawContext where the module identifier is a uint. +type Context = RawContext[uint] + +// RawContext represents the evaluation context in which an expression can be // evaluated. Firstly, every expression must have a single enclosing module // (i.e. all columns accessed by the expression are in that module); secondly, // the length multiplier for all columns accessed by the expression must be the @@ -22,10 +25,10 @@ import ( // multipliers must be powers of 2. Likewise, non-normal expressions (i.e those // with a multipler > 1) can only be used in a fairly limited number of // situtions (e.g. lookups). -type Context struct { +type RawContext[T comparable] struct { // Identifies the module in which this evaluation context exists. The empty // module is given by the maximum index (math.MaxUint). - module uint + module T // Identifies the length multiplier required to complete this context. In // essence, length multiplies divide up a given module into several disjoint // "subregions", such than every expression exists only in one of them. @@ -35,28 +38,30 @@ type Context struct { // VoidContext returns the void (or empty) context. This is the bottom type in // the lattice, and is the context contained in all other contexts. It is // needed, for example, as the context for constant expressions. -func VoidContext() Context { - return Context{math.MaxUint, 0} +func VoidContext[T comparable]() RawContext[T] { + var empty T + return RawContext[T]{empty, 0} } // ConflictingContext represents the case where multiple different contexts have // been joined together. For example, when determining the context of an // expression, the conflicting context is returned when no single context can be // deteremed. This value is generally considered to indicate an error. -func ConflictingContext() Context { - return Context{math.MaxUint - 1, 0} +func ConflictingContext[T comparable]() RawContext[T] { + var empty T + return RawContext[T]{empty, math.MaxUint - 1} } // NewContext returns a context representing the given module with the given // length multiplier. -func NewContext(module uint, multiplier uint) Context { - return Context{module, multiplier} +func NewContext[T comparable](module T, multiplier uint) RawContext[T] { + return RawContext[T]{module, multiplier} } // Module returns the module for this context. Note, however, that this is // nonsensical in the case of either the void or the conflicted context. In // this cases, this method will panic. -func (p Context) Module() uint { +func (p RawContext[T]) Module() T { if !p.IsVoid() && !p.IsConflicted() { return p.module } else if p.IsVoid() { @@ -69,7 +74,7 @@ func (p Context) Module() uint { // LengthMultiplier returns the length multiplier for this context. Note, // however, that this is nonsensical in the case of either the void or the // conflicted context. In this cases, this method will panic. -func (p Context) LengthMultiplier() uint { +func (p RawContext[T]) LengthMultiplier() uint { if !p.IsVoid() && !p.IsConflicted() { return p.multiplier } else if p.IsVoid() { @@ -81,38 +86,38 @@ func (p Context) LengthMultiplier() uint { // IsVoid checks whether this context is the void context (or not). This is the // bottom element in the lattice. -func (p Context) IsVoid() bool { - return p.module == math.MaxUint +func (p RawContext[T]) IsVoid() bool { + return p.multiplier == 0 } // IsConflicted checks whether this context represents the conflicted context. // This is the top element in the lattice, and is used to represent the case // where e.g. an expression has multiple conflicting contexts. -func (p Context) IsConflicted() bool { - return p.module == math.MaxUint-1 +func (p RawContext[T]) IsConflicted() bool { + return p.multiplier == math.MaxUint } // Multiply updates the length multiplier by multiplying it by a given factor, // producing the updated context. -func (p Context) Multiply(factor uint) Context { +func (p RawContext[T]) Multiply(factor uint) RawContext[T] { return NewContext(p.module, p.multiplier*factor) } // Join returns the least upper bound of the two contexts, or false if this does // not exist (i.e. the two context's are in conflict). -func (p Context) Join(other Context) Context { +func (p RawContext[T]) Join(other RawContext[T]) RawContext[T] { if p.IsVoid() { return other } else if other.IsVoid() { return p } else if p != other || p.IsConflicted() || other.IsConflicted() { // Conflicting contexts - return ConflictingContext() + return ConflictingContext[T]() } // Matching contexts return p } -func (p Context) String() string { - return fmt.Sprintf("%d*%d", p.module, p.multiplier) +func (p RawContext[T]) String() string { + return fmt.Sprintf("%v*%d", p.module, p.multiplier) }