diff --git a/pkg/corset/ast.go b/pkg/corset/ast.go index a0f7ceb..c1970a8 100644 --- a/pkg/corset/ast.go +++ b/pkg/corset/ast.go @@ -5,6 +5,7 @@ import ( sc "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/sexp" tr "github.com/consensys/go-corset/pkg/trace" + "github.com/consensys/go-corset/pkg/util" ) // Circuit represents the root of the Abstract Syntax Tree. This is also @@ -43,12 +44,53 @@ type ColumnAssignment struct { Type sc.Type } +// Symbol represents a variable or function access within a declaration. +// Initially, such the proper interpretation of such accesses is unclear and it +// is only later when we can distinguish them (e.g. whether its a column access, +// a constant access, etc). +type Symbol interface { + Node + // Determines whether this symbol is qualfied or not (i.e. has an explicitly + // module specifier). + IsQualified() bool + // Indicates whether or not this is a function. + IsFunction() bool + // Checks whether this symbol has been resolved already, or not. + IsResolved() bool + // Optional module qualification + Module() string + // Name of the symbol + Name() string + // Get binding associated with this interface. This will panic if this + // symbol is not yet resolved. + Binding() Binding + // Resolve this symbol by associating it with the binding associated with + // the definition of the symbol to which this refers. + Resolve(Binding) +} + +// SymbolDefinition represents a declaration (or part thereof) which defines a +// particular symbol. For example, "defcolumns" will define one or more symbols +// representing columns, etc. +type SymbolDefinition interface { + Node + // Name of symbol being defined + Name() string + // Indicates whether or not this is a function definition. + IsFunction() bool + // Allocated binding for the symbol which may or may not be finalised. + Binding() Binding +} + // Declaration represents a top-level declaration in a Corset source file (e.g. // defconstraint, defcolumns, etc). type Declaration interface { Node - // Simple marker to indicate this is really a declaration. - IsDeclaration() + // Returns the set of symbols being defined this declaration. Observe that + // these may not yet have been finalised. + Definitions() util.Iterator[SymbolDefinition] + // Return set of columns on which this declaration depends. + Dependencies() util.Iterator[Symbol] } // Assignment is a declaration which introduces one (or more) computed columns. @@ -66,13 +108,85 @@ type Assignment interface { Resolve(*Environment) ([]ColumnAssignment, []SyntaxError) } +// ColumnName represents a name within some syntactic item. Essentially this wraps a +// string and provides a mechanism for it to be associated with source line +// information. +type ColumnName struct { + name string + binding Binding +} + +// IsQualified determines whether this symbol is qualfied or not (i.e. has an +// explicit module specifier). Column names are never qualified. +func (e *ColumnName) IsQualified() bool { + return false +} + +// IsFunction indicates whether or not this symbol refers to a function (which +// of course it never does). +func (e *ColumnName) IsFunction() bool { + return false +} + +// IsResolved checks whether this symbol has been resolved already, or not. +func (e *ColumnName) IsResolved() bool { + return e.binding != nil +} + +// Module returns the optional module qualification. This always panics because +// column name's are never qualified. +func (e *ColumnName) Module() string { + panic("undefined") +} + +// Name returns the (unqualified) name of the column to which this symbol +// refers. +func (e *ColumnName) Name() string { + return e.name +} + +// Binding gets binding associated with this interface. This will panic if this +// symbol is not yet resolved. +func (e *ColumnName) Binding() Binding { + if e.binding == nil { + panic("name not yet resolved") + } + // + return e.binding +} + +// Resolve this symbol by associating it with the binding associated with +// the definition of the symbol to which this refers. +func (e *ColumnName) Resolve(binding Binding) { + if e.binding != nil { + panic("name already resolved") + } + // + e.binding = binding +} + +// Lisp converts this node into its lisp representation. This is primarily used +// for debugging purposes. +func (e *ColumnName) Lisp() sexp.SExp { + return sexp.NewSymbol(e.name) +} + // DefColumns captures a set of one or more columns being declared. type DefColumns struct { Columns []*DefColumn } -// IsDeclaration needed to signal declaration. -func (p *DefColumns) IsDeclaration() {} +// Dependencies needed to signal declaration. +func (p *DefColumns) Dependencies() util.Iterator[Symbol] { + return util.NewArrayIterator[Symbol](nil) +} + +// Definitions returns the set of symbols defined by this declaration. Observe +// that these may not yet have been finalised. +func (p *DefColumns) Definitions() util.Iterator[SymbolDefinition] { + iter := util.NewArrayIterator(p.Columns) + return util.NewCastIterator[*DefColumn, SymbolDefinition](iter) +} // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. @@ -84,20 +198,62 @@ func (p *DefColumns) Lisp() sexp.SExp { // column, such its name and type. type DefColumn struct { // Column name - Name string - // The datatype which all values in this column should inhabit. - DataType sc.Type - // Determines whether or not values in this column should be proven to be - // within the given type (i.e. using a range constraint). - MustProve bool - // Determines the length of this column as a multiple of the enclosing - // module. - LengthMultiplier uint + name string + // Binding of this column (which may or may not be finalised). + binding ColumnBinding +} + +// IsFunction is never true for a column definition. +func (e *DefColumn) IsFunction() bool { + return false +} + +// Binding returns the allocated binding for this symbol (which may or may not +// be finalised). +func (e *DefColumn) Binding() Binding { + return &e.binding +} + +// Name of symbol being defined +func (e *DefColumn) Name() string { + return e.name +} + +// DataType returns the type of this column. If this column have not yet been +// finalised, then this will panic. +func (e *DefColumn) DataType() sc.Type { + if !e.binding.IsFinalised() { + panic("unfinalised column") + } + // + return e.binding.dataType +} + +// LengthMultiplier returns the length multiplier of this column (where the +// height of this column is determined as the product of the enclosing module's +// height and this length multiplier). If this column have not yet been +// finalised, then this will panic. +func (e *DefColumn) LengthMultiplier() uint { + if !e.binding.IsFinalised() { + panic("unfinalised column") + } + // + return e.binding.multiplier +} + +// MustProve determines whether or not the type of this column must be +// established by the prover (e.g. a range constraint or similar). +func (e *DefColumn) MustProve() bool { + if !e.binding.IsFinalised() { + panic("unfinalised column") + } + // + return e.binding.mustProve } // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. -func (p *DefColumn) Lisp() sexp.SExp { +func (e *DefColumn) Lisp() sexp.SExp { panic("got here") } @@ -130,8 +286,24 @@ type DefConstraint struct { Constraint Expr } -// IsDeclaration needed to signal declaration. -func (p *DefConstraint) IsDeclaration() {} +// Definitions returns the set of symbols defined by this declaration. Observe +// that these may not yet have been finalised. +func (p *DefConstraint) Definitions() util.Iterator[SymbolDefinition] { + return util.NewArrayIterator[SymbolDefinition](nil) +} + +// Dependencies needed to signal declaration. +func (p *DefConstraint) Dependencies() util.Iterator[Symbol] { + var guard_deps []Symbol + // Extract guard's dependencies (if applicable) + if p.Guard != nil { + guard_deps = p.Guard.Dependencies() + } + // Extract bodies dependencies + body_deps := p.Constraint.Dependencies() + // Done + return util.NewArrayIterator[Symbol](append(guard_deps, body_deps...)) +} // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. @@ -153,8 +325,16 @@ type DefInRange struct { Bound fr.Element } -// IsDeclaration needed to signal declaration. -func (p *DefInRange) IsDeclaration() {} +// Definitions returns the set of symbols defined by this declaration. Observe +// that these may not yet have been finalised. +func (p *DefInRange) Definitions() util.Iterator[SymbolDefinition] { + return util.NewArrayIterator[SymbolDefinition](nil) +} + +// Dependencies needed to signal declaration. +func (p *DefInRange) Dependencies() util.Iterator[Symbol] { + return util.NewArrayIterator[Symbol](p.Expr.Dependencies()) +} // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. @@ -172,13 +352,22 @@ func (p *DefInRange) Lisp() sexp.SExp { // is required to hold an element from any source column). type DefInterleaved struct { // The target column being defined - Target string + Target *DefColumn // The source columns used to define the interleaved target column. - Sources []*DefName + Sources []Symbol } -// IsDeclaration needed to signal declaration. -func (p *DefInterleaved) IsDeclaration() {} +// Definitions returns the set of symbols defined by this declaration. Observe +// that these may not yet have been finalised. +func (p *DefInterleaved) Definitions() util.Iterator[SymbolDefinition] { + iter := util.NewUnitIterator(p.Target) + return util.NewCastIterator[*DefColumn, SymbolDefinition](iter) +} + +// Dependencies needed to signal declaration. +func (p *DefInterleaved) Dependencies() util.Iterator[Symbol] { + return util.NewArrayIterator(p.Sources) +} // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. @@ -210,8 +399,19 @@ type DefLookup struct { Targets []Expr } -// IsDeclaration needed to signal declaration. -func (p *DefLookup) IsDeclaration() {} +// Definitions returns the set of symbols defined by this declaration. Observe +// that these may not yet have been finalised. +func (p *DefLookup) Definitions() util.Iterator[SymbolDefinition] { + return util.NewArrayIterator[SymbolDefinition](nil) +} + +// Dependencies needed to signal declaration. +func (p *DefLookup) Dependencies() util.Iterator[Symbol] { + sourceDeps := DependenciesOfExpressions(p.Sources) + targetDeps := DependenciesOfExpressions(p.Targets) + // Combine deps + return util.NewArrayIterator(append(sourceDeps, targetDeps...)) +} // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. @@ -225,12 +425,21 @@ func (p *DefLookup) Lisp() sexp.SExp { // source columns can be specified as increasing or decreasing. type DefPermutation struct { Targets []*DefColumn - Sources []*DefName + Sources []Symbol Signs []bool } -// IsDeclaration needed to signal declaration. -func (p *DefPermutation) IsDeclaration() {} +// Definitions returns the set of symbols defined by this declaration. Observe +// that these may not yet have been finalised. +func (p *DefPermutation) Definitions() util.Iterator[SymbolDefinition] { + iter := util.NewArrayIterator(p.Targets) + return util.NewCastIterator[*DefColumn, SymbolDefinition](iter) +} + +// Dependencies needed to signal declaration. +func (p *DefPermutation) Dependencies() util.Iterator[Symbol] { + return util.NewArrayIterator(p.Sources) +} // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. @@ -254,8 +463,16 @@ type DefProperty struct { Assertion Expr } -// IsDeclaration needed to signal declaration. -func (p *DefProperty) IsDeclaration() {} +// Definitions returns the set of symbols defined by this declaration. Observe that +// these may not yet have been finalised. +func (p *DefProperty) Definitions() util.Iterator[SymbolDefinition] { + return util.NewArrayIterator[SymbolDefinition](nil) +} + +// Dependencies needed to signal declaration. +func (p *DefProperty) Dependencies() util.Iterator[Symbol] { + return util.NewArrayIterator(p.Assertion.Dependencies()) +} // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. @@ -271,19 +488,68 @@ func (p *DefProperty) Lisp() sexp.SExp { // parameters). In contrast, an impure function can access those columns // defined within its enclosing context. type DefFun struct { - Name *DefName - // Flag whether or not is pure function - Pure bool - // Return type - Return sc.Type + name string // Parameters - Parameters []*DefParameter - // Body - Body Expr + parameters []*DefParameter + // + binding FunctionBinding } -// IsDeclaration needed to signal declaration. -func (p *DefFun) IsDeclaration() {} +// IsFunction is always true for a function definition! +func (p *DefFun) IsFunction() bool { + return true +} + +// IsPure indicates whether or not this is a pure function. That is, a function +// which is not permitted to access any columns from the enclosing environment +// (either directly itself, or indirectly via functions it calls). +func (p *DefFun) IsPure() bool { + return p.binding.pure +} + +// Parameters returns information about the parameters defined by this +// declaration. +func (p *DefFun) Parameters() []*DefParameter { + return p.parameters +} + +// Body Access information about the parameters defined by this declaration. +func (p *DefFun) Body() Expr { + return p.binding.body +} + +// Binding returns the allocated binding for this symbol (which may or may not +// be finalised). +func (p *DefFun) Binding() Binding { + return &p.binding +} + +// Name of symbol being defined +func (p *DefFun) Name() string { + return p.name +} + +// Definitions returns the set of symbols defined by this declaration. Observe +// that these may not yet have been finalised. +func (p *DefFun) Definitions() util.Iterator[SymbolDefinition] { + iter := util.NewUnitIterator(p) + return util.NewCastIterator[*DefFun, SymbolDefinition](iter) +} + +// Dependencies needed to signal declaration. +func (p *DefFun) Dependencies() util.Iterator[Symbol] { + deps := p.binding.body.Dependencies() + ndeps := make([]Symbol, 0) + // Filter out all parameters declared in this function, since these are not + // external dependencies. + for _, d := range deps { + if d.IsQualified() || d.IsFunction() || !p.hasParameter(d.Name()) { + ndeps = append(ndeps, d) + } + } + // Done + return util.NewArrayIterator(ndeps) +} // Lisp converts this node into its lisp representation. This is primarily used // for debugging purposes. @@ -291,6 +557,18 @@ func (p *DefFun) Lisp() sexp.SExp { panic("got here") } +// hasParameter checks whether this function has a parameter with the given +// name, or not. +func (p *DefFun) hasParameter(name string) bool { + for _, v := range p.parameters { + if v.Name == name { + return true + } + } + // + return false +} + // DefParameter packages together those piece relevant to declaring an individual // parameter, such its name and type. type DefParameter struct { @@ -306,19 +584,6 @@ func (p *DefParameter) Lisp() sexp.SExp { panic("got here") } -// DefName is simply a wrapper around a string which can be associated with -// source information for producing syntax errors. -type DefName struct { - // Name of the column to be permuted - Name string -} - -// Lisp converts this node into its lisp representation. This is primarily used -// for debugging purposes. -func (p *DefName) Lisp() sexp.SExp { - panic("got here") -} - // Expr represents an arbitrary expression over the columns of a given context // (or the parameters of an enclosing function). Such expressions are pitched // at a higher-level than those of the underlying constraint system. For @@ -343,6 +608,9 @@ type Expr interface { // Substitute all variables (such as for function parameters) arising in // this expression. Substitute(args []Expr) Expr + + // Return set of columns on which this declaration depends. + Dependencies() []Symbol } // Context represents the evaluation context for a given expression. @@ -380,6 +648,11 @@ func (e *Add) Substitute(args []Expr) Expr { return &Add{SubstituteExpressions(e.Args, args)} } +// Dependencies needed to signal declaration. +func (e *Add) Dependencies() []Symbol { + return DependenciesOfExpressions(e.Args) +} + // ============================================================================ // Constants // ============================================================================ @@ -412,6 +685,11 @@ func (e *Constant) Substitute(args []Expr) Expr { return e } +// Dependencies needed to signal declaration. +func (e *Constant) Dependencies() []Symbol { + return nil +} + // ============================================================================ // Exponentiation // ============================================================================ @@ -447,6 +725,11 @@ func (e *Exp) Substitute(args []Expr) Expr { return &Exp{e.Arg.Substitute(args), e.Pow} } +// Dependencies needed to signal declaration. +func (e *Exp) Dependencies() []Symbol { + return e.Arg.Dependencies() +} + // ============================================================================ // IfZero // ============================================================================ @@ -490,6 +773,11 @@ func (e *IfZero) Substitute(args []Expr) Expr { } } +// Dependencies needed to signal declaration. +func (e *IfZero) Dependencies() []Symbol { + return DependenciesOfExpressions([]Expr{e.Condition, e.TrueBranch, e.FalseBranch}) +} + // ============================================================================ // List // ============================================================================ @@ -522,6 +810,11 @@ func (e *List) Substitute(args []Expr) Expr { return &List{SubstituteExpressions(e.Args, args)} } +// Dependencies needed to signal declaration. +func (e *List) Dependencies() []Symbol { + return DependenciesOfExpressions(e.Args) +} + // ============================================================================ // Multiplication // ============================================================================ @@ -554,6 +847,11 @@ func (e *Mul) Substitute(args []Expr) Expr { return &Mul{SubstituteExpressions(e.Args, args)} } +// Dependencies needed to signal declaration. +func (e *Mul) Dependencies() []Symbol { + return DependenciesOfExpressions(e.Args) +} + // ============================================================================ // Normalise // ============================================================================ @@ -587,6 +885,11 @@ func (e *Normalise) Substitute(args []Expr) Expr { return &Normalise{e.Arg.Substitute(args)} } +// Dependencies needed to signal declaration. +func (e *Normalise) Dependencies() []Symbol { + return e.Arg.Dependencies() +} + // ============================================================================ // Subtraction // ============================================================================ @@ -619,27 +922,91 @@ func (e *Sub) Substitute(args []Expr) Expr { return &Sub{SubstituteExpressions(e.Args, args)} } +// Dependencies needed to signal declaration. +func (e *Sub) Dependencies() []Symbol { + return DependenciesOfExpressions(e.Args) +} + // ============================================================================ -// VariableAccess +// Function Invocation // ============================================================================ // Invoke represents an attempt to invoke a given function. type Invoke struct { - Module *string - Name string - Args []Expr - Binding *FunctionBinding + module *string + name string + args []Expr + binding *FunctionBinding +} + +// IsQualified determines whether this symbol is qualfied or not (i.e. has an +// explicitly module specifier). +func (e *Invoke) IsQualified() bool { + return e.module != nil +} + +// IsFunction indicates whether or not this symbol refers to a function (which +// of course it always does). +func (e *Invoke) IsFunction() bool { + return true +} + +// IsResolved checks whether this symbol has been resolved already, or not. +func (e *Invoke) IsResolved() bool { + return e.binding != nil +} + +// Resolve this symbol by associating it with the binding associated with +// the definition of the symbol to which this refers. +func (e *Invoke) Resolve(binding Binding) { + if fb, ok := binding.(*FunctionBinding); ok { + e.binding = fb + return + } + // Problem + panic("cannot resolve function invocation with anything other than a function binding") +} + +// Module returns the optional module qualification. This will panic if this +// invocation is unqualified. +func (e *Invoke) Module() string { + if e.module == nil { + panic("invocation has no module qualifier") + } + + return *e.module +} + +// Name of the function being invoked. +func (e *Invoke) Name() string { + return e.name +} + +// Args returns the arguments provided by this invocation to the function being +// invoked. +func (e *Invoke) Args() []Expr { + return e.args +} + +// Binding gets binding associated with this interface. This will panic if this +// symbol is not yet resolved. +func (e *Invoke) Binding() Binding { + if e.binding == nil { + panic("invocation not yet resolved") + } + + return e.binding } // Context returns the context for this expression. Observe that the // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). func (e *Invoke) Context() Context { - if e.Binding == nil { + if e.binding == nil { panic("unresolved expressions encountered whilst resolving context") } // TODO: impure functions can have their own context. - return ContextOfExpressions(e.Args) + return ContextOfExpressions(e.args) } // Multiplicity determines the number of values that evaluating this expression @@ -658,7 +1025,15 @@ func (e *Invoke) Lisp() sexp.SExp { // Substitute all variables (such as for function parameters) arising in // this expression. func (e *Invoke) Substitute(args []Expr) Expr { - return &Invoke{e.Module, e.Name, SubstituteExpressions(e.Args, args), e.Binding} + return &Invoke{e.module, e.name, SubstituteExpressions(e.args, args), e.binding} +} + +// Dependencies needed to signal declaration. +func (e *Invoke) Dependencies() []Symbol { + deps := DependenciesOfExpressions(e.args) + // Include this expression as a symbol (which must be bound to the function + // being invoked) + return append(deps, e) } // ============================================================================ @@ -668,10 +1043,65 @@ func (e *Invoke) Substitute(args []Expr) Expr { // VariableAccess represents reading the value of a given local variable (such // as a function parameter). type VariableAccess struct { - Module *string - Name string - Shift int - Binding Binding + module *string + name string + shift int + binding Binding +} + +// IsQualified determines whether this symbol is qualfied or not (i.e. has an +// explicitly module specifier). +func (e *VariableAccess) IsQualified() bool { + return e.module != nil +} + +// IsFunction determines whether this symbol refers to a function (which, of +// course, variable accesses never do). +func (e *VariableAccess) IsFunction() bool { + return false +} + +// IsResolved checks whether this symbol has been resolved already, or not. +func (e *VariableAccess) IsResolved() bool { + return e.binding != nil +} + +// Resolve this symbol by associating it with the binding associated with +// the definition of the symbol to which this refers. +func (e *VariableAccess) Resolve(binding Binding) { + if binding == nil { + panic("empty binding") + } else if e.binding != nil { + panic("already resolved") + } + + e.binding = binding +} + +// Module returns the optional module qualification. This will panic if this +// invocation is unqualified. +func (e *VariableAccess) Module() string { + return *e.module +} + +// Name returns the (unqualified) name of this symbol +func (e *VariableAccess) Name() string { + return e.name +} + +// Binding gets binding associated with this interface. This will panic if this +// symbol is not yet resolved. +func (e *VariableAccess) Binding() Binding { + if e.binding == nil { + panic("variable access is unresolved") + } + // + return e.binding +} + +// Shift returns the row shift (if any) associated with this variable access. +func (e *VariableAccess) Shift() int { + return e.shift } // Multiplicity determines the number of values that evaluating this expression @@ -684,12 +1114,10 @@ func (e *VariableAccess) Multiplicity() uint { // expression must have been resolved for this to be defined (i.e. it may // panic if it has not been resolved yet). func (e *VariableAccess) Context() Context { - binding, ok := e.Binding.(*ColumnBinding) + binding, ok := e.binding.(*ColumnBinding) // if ok { return binding.Context() - } else if binding == nil { - panic("unresolved column access") } // panic("invalid column access") @@ -704,9 +1132,9 @@ func (e *VariableAccess) Lisp() sexp.SExp { // Substitute all variables (such as for function parameters) arising in // this expression. func (e *VariableAccess) Substitute(args []Expr) Expr { - if b, ok := e.Binding.(*ParameterBinding); ok { + if b, ok := e.binding.(*ParameterBinding); ok { // This is a variable to be substituted. - if e.Shift != 0 { + if e.shift != 0 { panic("support variable shifts") } // @@ -716,6 +1144,11 @@ func (e *VariableAccess) Substitute(args []Expr) Expr { return e } +// Dependencies needed to signal declaration. +func (e *VariableAccess) Dependencies() []Symbol { + return []Symbol{e} +} + // ============================================================================ // Helpers // ============================================================================ @@ -757,6 +1190,20 @@ func SubstituteOptionalExpression(expr Expr, vars []Expr) Expr { return expr } +// DependenciesOfExpressions determines the dependencies for a given set of zero +// or more expressions. +func DependenciesOfExpressions(exprs []Expr) []Symbol { + var deps []Symbol + // + for _, e := range exprs { + if e != nil { + deps = append(deps, e.Dependencies()...) + } + } + // + return deps +} + func determineMultiplicity(exprs []Expr) uint { width := uint(1) // diff --git a/pkg/corset/binding.go b/pkg/corset/binding.go index e17c26f..71c2d67 100644 --- a/pkg/corset/binding.go +++ b/pkg/corset/binding.go @@ -21,8 +21,8 @@ type BindingId struct { // Binding represents an association between a name, as found in a source file, // and concrete item (e.g. a column, function, etc). type Binding interface { - // Returns the context associated with this binding. - IsBinding() + // Determine whether this binding is finalised or not. + IsFinalised() bool } // ColumnBinding represents something bound to a given column. @@ -33,19 +33,23 @@ type ColumnBinding struct { module string // Determines whether this is a computed column, or not. computed bool + // Determines whether this column must be proven (or not). + mustProve bool // Column's length multiplier multiplier uint // Column's datatype - datatype sc.Type + dataType sc.Type } // NewColumnBinding constructs a new column binding in a given module. -func NewColumnBinding(module string, computed bool, multiplier uint, datatype sc.Type) *ColumnBinding { - return &ColumnBinding{math.MaxUint, module, computed, multiplier, datatype} +func NewColumnBinding(module string, computed bool, mustProve bool, multiplier uint, datatype sc.Type) *ColumnBinding { + return &ColumnBinding{math.MaxUint, module, computed, mustProve, multiplier, datatype} } -// IsBinding ensures this is an instance of Binding. -func (p *ColumnBinding) IsBinding() {} +// IsFinalised checks whether this binding has been finalised yet or not. +func (p *ColumnBinding) IsFinalised() bool { + return p.multiplier != 0 +} // Context returns the of this column. That is, the module in which this colunm // was declared and also the length multiplier of that module it requires. @@ -74,20 +78,38 @@ type ParameterBinding struct { index uint } -// IsBinding ensures this is an instance of Binding. -func (p *ParameterBinding) IsBinding() {} +// IsFinalised checks whether this binding has been finalised yet or not. +func (p *ParameterBinding) IsFinalised() bool { + panic("") +} // FunctionBinding represents the binding of a function application to its // physical definition. type FunctionBinding struct { - // arity determines the number of arguments this function takes. - arity uint + // Flag whether or not is pure function + pure bool + // Types of parameters + paramTypes []sc.Type + // Type of return + returnType sc.Type // body of the function in question. body Expr } -// IsBinding ensures this is an instance of Binding. -func (p *FunctionBinding) IsBinding() {} +// NewFunctionBinding constructs a new function binding. +func NewFunctionBinding(pure bool, paramTypes []sc.Type, returnType sc.Type, body Expr) FunctionBinding { + return FunctionBinding{pure, paramTypes, returnType, body} +} + +// IsFinalised checks whether this binding has been finalised yet or not. +func (p *FunctionBinding) IsFinalised() bool { + return p.returnType != nil +} + +// Arity returns the number of parameters that this function accepts. +func (p *FunctionBinding) Arity() uint { + return uint(len(p.paramTypes)) +} // Apply a given set of arguments to this function binding. func (p *FunctionBinding) Apply(args []Expr) Expr { diff --git a/pkg/corset/compiler.go b/pkg/corset/compiler.go index 0906fd1..2b048ec 100644 --- a/pkg/corset/compiler.go +++ b/pkg/corset/compiler.go @@ -1,8 +1,6 @@ package corset import ( - "fmt" - "github.com/consensys/go-corset/pkg/hir" "github.com/consensys/go-corset/pkg/sexp" ) @@ -72,9 +70,6 @@ func (p *Compiler) Compile() (*hir.Schema, []SyntaxError) { } // Convert global scope into an environment by allocating all columns. environment := scope.ToEnvironment() - // Check constraint contexts (e.g. for constraints, lookups, etc) - // Type check constraints - fmt.Println("Translating Circuit...") // Finally, translate everything and add it to the schema. return TranslateCircuit(environment, p.srcmap, &p.circuit) } diff --git a/pkg/corset/environment.go b/pkg/corset/environment.go index a801e86..e4e37e7 100644 --- a/pkg/corset/environment.go +++ b/pkg/corset/environment.go @@ -65,7 +65,7 @@ func (p GlobalEnvironment) Module(name string) *ModuleScope { func (p GlobalEnvironment) Column(module string, name string) *ColumnBinding { // Lookup the given binding, expecting that it is a column binding. If not, // then this will fail. - return p.Module(module).Bind(nil, name, false).(*ColumnBinding) + return p.Module(module).Column(name) } // ContextFrom constructs a trace context for a given module and length diff --git a/pkg/corset/parser.go b/pkg/corset/parser.go index af201d4..cca8901 100644 --- a/pkg/corset/parser.go +++ b/pkg/corset/parser.go @@ -97,7 +97,7 @@ func ParseSourceFile(srcfile *sexp.SourceFile) (Circuit, *sexp.SourceMap[Node], p := NewParser(srcfile, srcmap) // Parse whatever is declared at the beginning of the file before the first // module declaration. These declarations form part of the "prelude". - if circuit.Declarations, terms, errors = p.parseModuleContents(terms); len(errors) > 0 { + if circuit.Declarations, terms, errors = p.parseModuleContents("", terms); len(errors) > 0 { return circuit, nil, errors } // Continue parsing string until nothing remains. @@ -111,7 +111,7 @@ func ParseSourceFile(srcfile *sexp.SourceFile) (Circuit, *sexp.SourceMap[Node], return circuit, nil, errors } // Parse module contents - if decls, terms, errors = p.parseModuleContents(terms[1:]); len(errors) > 0 { + if decls, terms, errors = p.parseModuleContents(name, terms[1:]); len(errors) > 0 { return circuit, nil, errors } else if len(decls) != 0 { circuit.Modules = append(circuit.Modules, Module{name, decls}) @@ -177,7 +177,7 @@ func (p *Parser) mapSourceNode(from sexp.SExp, to Node) { } // Extract all declarations associated with a given module and package them up. -func (p *Parser) parseModuleContents(terms []sexp.SExp) ([]Declaration, []sexp.SExp, []SyntaxError) { +func (p *Parser) parseModuleContents(module string, terms []sexp.SExp) ([]Declaration, []sexp.SExp, []SyntaxError) { var errors []SyntaxError // decls := make([]Declaration, 0) @@ -190,7 +190,7 @@ func (p *Parser) parseModuleContents(terms []sexp.SExp) ([]Declaration, []sexp.S errors = append(errors, *err) } else if e.MatchSymbols(2, "module") { return decls, terms[i:], nil - } else if decl, errs := p.parseDeclaration(e); errs != nil { + } else if decl, errs := p.parseDeclaration(module, e); errs != nil { errors = append(errors, errs...) } else { // Continue accumulating declarations for this module. @@ -225,7 +225,7 @@ func (p *Parser) parseModuleStart(s sexp.SExp) (string, []SyntaxError) { return name, nil } -func (p *Parser) parseDeclaration(s *sexp.List) (Declaration, []SyntaxError) { +func (p *Parser) parseDeclaration(module string, s *sexp.List) (Declaration, []SyntaxError) { var ( decl Declaration errors []SyntaxError @@ -233,7 +233,7 @@ func (p *Parser) parseDeclaration(s *sexp.List) (Declaration, []SyntaxError) { ) // if s.MatchSymbols(1, "defcolumns") { - decl, errors = p.parseDefColumns(s) + decl, errors = p.parseDefColumns(module, s) } else if s.Len() == 4 && s.MatchSymbols(2, "defconstraint") { decl, errors = p.parseDefConstraint(s.Elements) } else if s.Len() == 3 && s.MatchSymbols(1, "defpurefun") { @@ -241,11 +241,11 @@ func (p *Parser) parseDeclaration(s *sexp.List) (Declaration, []SyntaxError) { } else if s.Len() == 3 && s.MatchSymbols(1, "definrange") { decl, err = p.parseDefInRange(s.Elements) } else if s.Len() == 3 && s.MatchSymbols(1, "definterleaved") { - decl, err = p.parseDefInterleaved(s.Elements) + decl, err = p.parseDefInterleaved(module, s.Elements) } else if s.Len() == 4 && s.MatchSymbols(1, "deflookup") { decl, err = p.parseDefLookup(s.Elements) } else if s.Len() == 3 && s.MatchSymbols(2, "defpermutation") { - decl, err = p.parseDefPermutation(s.Elements) + decl, err = p.parseDefPermutation(module, s.Elements) } else if s.Len() == 3 && s.MatchSymbols(2, "defproperty") { decl, err = p.parseDefProperty(s.Elements) } else { @@ -264,7 +264,7 @@ func (p *Parser) parseDeclaration(s *sexp.List) (Declaration, []SyntaxError) { } // Parse a column declaration -func (p *Parser) parseDefColumns(l *sexp.List) (*DefColumns, []SyntaxError) { +func (p *Parser) parseDefColumns(module string, l *sexp.List) (*DefColumns, []SyntaxError) { columns := make([]*DefColumn, l.Len()-1) // Sanity check declaration if len(l.Elements) == 1 { @@ -275,7 +275,7 @@ func (p *Parser) parseDefColumns(l *sexp.List) (*DefColumns, []SyntaxError) { var errors []SyntaxError // Process column declarations one by one. for i := 1; i < len(l.Elements); i++ { - decl, err := p.parseColumnDeclaration(l.Elements[i]) + decl, err := p.parseColumnDeclaration(module, l.Elements[i]) // Extract column name if err != nil { errors = append(errors, *err) @@ -291,35 +291,37 @@ func (p *Parser) parseDefColumns(l *sexp.List) (*DefColumns, []SyntaxError) { return &DefColumns{columns}, nil } -func (p *Parser) parseColumnDeclaration(e sexp.SExp) (*DefColumn, *SyntaxError) { - defcolumn := &DefColumn{"", nil, false, 1} - // Default to field type - defcolumn.DataType = &sc.FieldType{} +func (p *Parser) parseColumnDeclaration(module string, e sexp.SExp) (*DefColumn, *SyntaxError) { + var name string + // + binding := NewColumnBinding(module, false, false, 1, &sc.FieldType{}) // Check whether extended declaration or not. if l := e.AsList(); l != nil { // Check at least the name provided. if len(l.Elements) == 0 { - return defcolumn, p.translator.SyntaxError(l, "empty column declaration") + return nil, p.translator.SyntaxError(l, "empty column declaration") } // Column name is always first - defcolumn.Name = l.Elements[0].String(false) + name = l.Elements[0].String(false) // Parse type (if applicable) if len(l.Elements) == 2 { var err *SyntaxError - if defcolumn.DataType, defcolumn.MustProve, err = p.parseType(l.Elements[1]); err != nil { - return defcolumn, err + if binding.dataType, binding.mustProve, err = p.parseType(l.Elements[1]); err != nil { + return nil, err } } else if len(l.Elements) > 2 { // For now. - return defcolumn, p.translator.SyntaxError(l, "unknown column declaration attributes") + return nil, p.translator.SyntaxError(l, "unknown column declaration attributes") } } else { - defcolumn.Name = e.String(false) + name = e.String(false) } + // + def := &DefColumn{name, *binding} // Update source mapping - p.mapSourceNode(e, defcolumn) + p.mapSourceNode(e, def) // - return defcolumn, nil + return def, nil } // Parse a vanishing declaration @@ -353,7 +355,7 @@ func (p *Parser) parseDefConstraint(elements []sexp.SExp) (*DefConstraint, []Syn } // Parse a interleaved declaration -func (p *Parser) parseDefInterleaved(elements []sexp.SExp) (*DefInterleaved, *SyntaxError) { +func (p *Parser) parseDefInterleaved(module string, elements []sexp.SExp) (*DefInterleaved, *SyntaxError) { // Initial sanity checks if elements[1].AsSymbol() == nil { return nil, p.translator.SyntaxError(elements[1], "malformed target column") @@ -361,9 +363,8 @@ func (p *Parser) parseDefInterleaved(elements []sexp.SExp) (*DefInterleaved, *Sy return nil, p.translator.SyntaxError(elements[2], "malformed source columns") } // Extract target and source columns - target := elements[1].AsSymbol().Value sexpSources := elements[2].AsList() - sources := make([]*DefName, sexpSources.Len()) + sources := make([]Symbol, sexpSources.Len()) // for i := 0; i != sexpSources.Len(); i++ { ith := sexpSources.Get(i) @@ -371,8 +372,14 @@ func (p *Parser) parseDefInterleaved(elements []sexp.SExp) (*DefInterleaved, *Sy return nil, p.translator.SyntaxError(ith, "malformed source column") } // Extract column name - sources[i] = &DefName{ith.AsSymbol().Value} + sources[i] = &ColumnName{ith.AsSymbol().Value, nil} + p.mapSourceNode(ith, sources[i]) } + // + binding := NewColumnBinding(module, false, false, 1, &sc.FieldType{}) + target := &DefColumn{elements[1].AsSymbol().Value, *binding} + // Updating mapping for target definition + p.mapSourceNode(elements[1], target) // Done return &DefInterleaved{target, sources}, nil } @@ -415,7 +422,7 @@ func (p *Parser) parseDefLookup(elements []sexp.SExp) (*DefLookup, *SyntaxError) } // Parse a permutation declaration -func (p *Parser) parseDefPermutation(elements []sexp.SExp) (*DefPermutation, *SyntaxError) { +func (p *Parser) parseDefPermutation(module string, elements []sexp.SExp) (*DefPermutation, *SyntaxError) { var err *SyntaxError // sexpTargets := elements[1].AsList() @@ -432,12 +439,12 @@ func (p *Parser) parseDefPermutation(elements []sexp.SExp) (*DefPermutation, *Sy } // targets := make([]*DefColumn, sexpTargets.Len()) - sources := make([]*DefName, sexpSources.Len()) + sources := make([]Symbol, sexpSources.Len()) signs := make([]bool, sexpSources.Len()) // for i := 0; i < len(targets); i++ { // Parse target column - if targets[i], err = p.parseColumnDeclaration(sexpTargets.Get(i)); err != nil { + if targets[i], err = p.parseColumnDeclaration(module, sexpTargets.Get(i)); err != nil { return nil, err } // Parse source column @@ -449,10 +456,10 @@ func (p *Parser) parseDefPermutation(elements []sexp.SExp) (*DefPermutation, *Sy return &DefPermutation{targets, sources, signs}, nil } -func (p *Parser) parsePermutedColumnDeclaration(signRequired bool, e sexp.SExp) (*DefName, bool, *SyntaxError) { +func (p *Parser) parsePermutedColumnDeclaration(signRequired bool, e sexp.SExp) (*ColumnName, bool, *SyntaxError) { var ( err *SyntaxError - name DefName + name *ColumnName sign bool ) // Check whether extended declaration or not. @@ -470,16 +477,16 @@ func (p *Parser) parsePermutedColumnDeclaration(signRequired bool, e sexp.SExp) return nil, false, err } // Parse column name - name.Name = l.Get(1).AsSymbol().Value + name = &ColumnName{l.Get(1).AsSymbol().Value, nil} } else if signRequired { return nil, false, p.translator.SyntaxError(e, "missing sort direction") } else { - name.Name = e.String(false) + name = &ColumnName{e.String(false), nil} } // Update source mapping - p.mapSourceNode(e, &name) + p.mapSourceNode(e, name) // - return &name, sign, nil + return name, sign, nil } func (p *Parser) parsePermutedColumnSign(sign *sexp.Symbol) (bool, *SyntaxError) { @@ -513,7 +520,7 @@ func (p *Parser) parseDefProperty(elements []sexp.SExp) (*DefProperty, *SyntaxEr // Parse a permutation declaration func (p *Parser) parseDefPureFun(elements []sexp.SExp) (*DefFun, []SyntaxError) { var ( - name *DefName + name string ret sc.Type params []*DefParameter errors []SyntaxError @@ -535,11 +542,18 @@ func (p *Parser) parseDefPureFun(elements []sexp.SExp) (*DefFun, []SyntaxError) if len(errors) > 0 { return nil, errors } + // Extract parameter types + paramTypes := make([]sc.Type, len(params)) + for i, p := range params { + paramTypes[i] = p.DataType + } + // Construct binding + binding := NewFunctionBinding(true, paramTypes, ret, body) // - return &DefFun{name, true, ret, params, body}, nil + return &DefFun{name, params, binding}, nil } -func (p *Parser) parseFunctionSignature(elements []sexp.SExp) (*DefName, sc.Type, []*DefParameter, []SyntaxError) { +func (p *Parser) parseFunctionSignature(elements []sexp.SExp) (string, sc.Type, []*DefParameter, []SyntaxError) { var ( name *sexp.Symbol = elements[0].AsSymbol() params []*DefParameter = make([]*DefParameter, len(elements)-1) @@ -561,10 +575,10 @@ func (p *Parser) parseFunctionSignature(elements []sexp.SExp) (*DefName, sc.Type } // Check for any errors arising if len(errors) > 0 { - return nil, nil, nil, errors + return "", nil, nil, errors } // - return &DefName{name.Value}, ret, params, nil + return name.Value, ret, params, nil } func (p *Parser) parseFunctionParameter(element sexp.SExp) (*DefParameter, []SyntaxError) { @@ -722,9 +736,9 @@ func varAccessParserRule(col string) (Expr, bool, error) { return &VariableAccess{&split[0], split[1], 0, nil}, true, nil } else if len(split) > 2 { return nil, true, errors.New("malformed column access") + } else { + return &VariableAccess{nil, col, 0, nil}, true, nil } - // Done - return &VariableAccess{nil, col, 0, nil}, true, nil } func addParserRule(_ string, args []Expr) (Expr, error) { @@ -761,9 +775,9 @@ func invokeParserRule(name string, args []Expr) (Expr, error) { return &Invoke{&split[0], split[1], args, nil}, nil } else if len(split) > 2 { return nil, errors.New("malformed function invocation") + } else { + return &Invoke{nil, name, args, nil}, nil } - // Done - return &Invoke{nil, name, args, nil}, nil } func shiftParserRule(col string, amt string) (Expr, error) { diff --git a/pkg/corset/resolver.go b/pkg/corset/resolver.go index 74180b0..240e90f 100644 --- a/pkg/corset/resolver.go +++ b/pkg/corset/resolver.go @@ -5,6 +5,7 @@ import ( "github.com/consensys/go-corset/pkg/schema" "github.com/consensys/go-corset/pkg/sexp" + "github.com/consensys/go-corset/pkg/util" ) // ResolveCircuit resolves all symbols declared and used within a circuit, @@ -25,9 +26,11 @@ func ResolveCircuit(srcmap *sexp.SourceMaps[Node], circuit *Circuit) (*GlobalSco // Construct resolver r := resolver{srcmap} // Allocate declared input columns - errs := r.resolveColumns(scope, circuit) - // Check expressions - errs = append(errs, r.resolveConstraints(scope, circuit)...) + errs := r.resolveDeclarations(scope, circuit) + // + if len(errs) > 0 { + return nil, errs + } // Done return scope, errs } @@ -41,69 +44,17 @@ type resolver struct { srcmap *sexp.SourceMaps[Node] } -// Process all input column or column assignment declarations. -func (r *resolver) resolveColumns(scope *GlobalScope, circuit *Circuit) []SyntaxError { - // Allocate input columns first. These must all be done before any - // assignments are allocated, since the hir.Schema separates these out. - ierrs := r.resolveInputColumns(scope, circuit) - // Now we can resolve any assignments. - aerrs := r.resolveAssignments(scope, circuit) - // - return append(ierrs, aerrs...) -} - -// Process all input column declarations. -func (r *resolver) resolveInputColumns(scope *GlobalScope, circuit *Circuit) []SyntaxError { - // Input columns must be allocated before assignemts, since the hir.Schema - // separates these out. - errs := r.resolveInputColumnsInModule(scope.Module(""), circuit.Declarations) - // - for _, m := range circuit.Modules { - // Process all declarations in the module - merrs := r.resolveInputColumnsInModule(scope.Module(m.Name), m.Declarations) - // Package up all errors - errs = append(errs, merrs...) - } - // - return errs -} - -// Resolve all input columns in a given module. -func (r *resolver) resolveInputColumnsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { - errors := make([]SyntaxError, 0) - // - for _, d := range decls { - if dcols, ok := d.(*DefColumns); ok { - // Found one. - for _, col := range dcols.Columns { - // Check whether column already exists - if scope.Bind(nil, col.Name, false) != nil { - msg := fmt.Sprintf("symbol %s already declared in %s", col.Name, scope.EnclosingModule()) - err := r.srcmap.SyntaxError(col, msg) - errors = append(errors, *err) - } else { - // Declare new column - scope.Declare(col.Name, false, NewColumnBinding(scope.EnclosingModule(), - false, col.LengthMultiplier, col.DataType)) - } - } - } - } - // Done - return errors -} - // Process all assignment column declarations. These are more complex than for // input columns, since there can be dependencies between them. Thus, we cannot // simply resolve them in one linear scan. -func (r *resolver) resolveAssignments(scope *GlobalScope, circuit *Circuit) []SyntaxError { +func (r *resolver) resolveDeclarations(scope *GlobalScope, circuit *Circuit) []SyntaxError { // Input columns must be allocated before assignemts, since the hir.Schema // separates these out. - errs := r.resolveAssignmentsInModule(scope.Module(""), circuit.Declarations) + errs := r.resolveDeclarationsInModule(scope.Module(""), circuit.Declarations) // for _, m := range circuit.Modules { // Process all declarations in the module - merrs := r.resolveAssignmentsInModule(scope.Module(m.Name), m.Declarations) + merrs := r.resolveDeclarationsInModule(scope.Module(m.Name), m.Declarations) // Package up all errors errs = append(errs, merrs...) } @@ -115,45 +66,32 @@ func (r *resolver) resolveAssignments(scope *GlobalScope, circuit *Circuit) []Sy // assignments can depend on the declaration of other columns. Hence, we have // to process all columns before we can sure that they are all declared // correctly. -func (r *resolver) resolveAssignmentsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { - if errors := r.initialiseAssignmentsInModule(scope, decls); len(errors) > 0 { - return errors - } - // Check assignments - if errors := r.checkAssignmentsInModule(scope, decls); len(errors) > 0 { +func (r *resolver) resolveDeclarationsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { + if errors := r.initialiseDeclarationsInModule(scope, decls); len(errors) > 0 { return errors } // Iterate until all columns finalised - return r.finaliseAssignmentsInModule(scope, decls) + return r.finaliseDeclarationsInModule(scope, decls) } -// Initialise the column allocation from the available declarations, whilst -// identifying any duplicate declarations. Observe that, for some declarations, -// the initial assignment is incomplete because information about dependent -// columns may not be available. So, the goal of the subsequent phase is to -// flesh out this missing information. -func (r *resolver) initialiseAssignmentsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { +// Initialise all declarations in the given module scope. That means allocating +// all bindings into the scope, whilst also ensuring that we never have two +// bindings for the same symbol, etc. The key is that, at this stage, all +// bindings are potentially "non-finalised". That means they may be missing key +// information which is yet to be determined (e.g. information about types, or +// contexts, etc). +func (r *resolver) initialiseDeclarationsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { module := scope.EnclosingModule() errors := make([]SyntaxError, 0) // for _, d := range decls { - if col, ok := d.(*DefInterleaved); ok { - if binding := scope.Bind(nil, col.Target, false); binding != nil { - err := r.srcmap.SyntaxError(col, fmt.Sprintf("symbol %s already declared in %s", col.Target, module)) + for iter := d.Definitions(); iter.HasNext(); { + def := iter.Next() + // Attempt to declare symbol + if !scope.Declare(def) { + msg := fmt.Sprintf("symbol %s already declared in %s", def.Name(), module) + err := r.srcmap.SyntaxError(def, msg) errors = append(errors, *err) - } else { - // Register incomplete (assignment) column. - scope.Declare(col.Target, false, NewColumnBinding(module, true, 0, nil)) - } - } else if col, ok := d.(*DefPermutation); ok { - for _, c := range col.Targets { - if binding := scope.Bind(nil, c.Name, false); binding != nil { - err := r.srcmap.SyntaxError(col, fmt.Sprintf("symbol %s already declared in %s", c.Name, module)) - errors = append(errors, *err) - } else { - // Register incomplete (assignment) column. - scope.Declare(c.Name, false, NewColumnBinding(scope.EnclosingModule(), true, 0, nil)) - } } } } @@ -161,30 +99,12 @@ func (r *resolver) initialiseAssignmentsInModule(scope *ModuleScope, decls []Dec return errors } -func (r *resolver) checkAssignmentsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { - errors := make([]SyntaxError, 0) - // - for _, d := range decls { - if col, ok := d.(*DefInterleaved); ok { - for _, c := range col.Sources { - if scope.Bind(nil, c.Name, false) == nil { - errors = append(errors, *r.srcmap.SyntaxError(c, "unknown source column")) - } - } - } else if col, ok := d.(*DefPermutation); ok { - for _, c := range col.Sources { - if scope.Bind(nil, c.Name, false) == nil { - errors = append(errors, *r.srcmap.SyntaxError(c, "unknown source column")) - } - } - } - } - // Done - return errors -} - -// Iterate the column allocation to a fix point by iteratively fleshing out column information. -func (r *resolver) finaliseAssignmentsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { +// Finalise all declarations given in a module. This requires an iterative +// process as we cannot finalise a declaration until all of its dependencies +// have been themselves finalised. For example, a function which depends upon +// an interleaved column. Until the interleaved column is finalised, its type +// won't be available and, hence, we cannot type the function. +func (r *resolver) finaliseDeclarationsInModule(scope *ModuleScope, decls []Declaration) []SyntaxError { // Changed indicates whether or not a new assignment was finalised during a // given iteration. This is important to know since, if the assignment is // not complete and we didn't finalise any more assignments --- then, we've @@ -198,47 +118,45 @@ func (r *resolver) finaliseAssignmentsInModule(scope *ModuleScope, decls []Decla // For an incomplete assignment, this identifies the last declaration that // could not be finalised (i.e. as an example so we have at least one for // error reporting). - var incomplete Node = nil + var ( + incomplete Node = nil + counter uint = 4 + ) // - for changed && !complete { + for changed && !complete && counter > 0 { errors := make([]SyntaxError, 0) changed = false complete = true // for _, d := range decls { - if col, ok := d.(*DefInterleaved); ok { - // Check whether dependencies are resolved or not. - if r.columnsAreFinalised(scope, col.Sources) { - // Finalise assignment and handle any errors - errs := r.finaliseInterleavedAssignment(scope, col) - errors = append(errors, errs...) - // Record that a new assignment is available. - changed = changed || len(errs) == 0 - } else { - complete = false - incomplete = d - } - } else if col, ok := d.(*DefPermutation); ok { - // Check whether dependencies are resolved or not. - if r.columnsAreFinalised(scope, col.Sources) { - // Finalise assignment and handle any errors - errs := r.finalisePermutationAssignment(scope, col) - errors = append(errors, errs...) - // Record that a new assignment is available. - changed = changed || len(errs) == 0 - } else { - complete = false - incomplete = d - } + ready, errs := r.declarationDependenciesAreFinalised(scope, d.Dependencies()) + // See what arosed + if errs != nil { + errors = append(errors, errs...) + } else if ready { + // Finalise declaration and handle errors + errs := r.finaliseDeclaration(scope, d) + errors = append(errors, errs...) + // Record that a new assignment is available. + changed = changed || len(errs) == 0 + } else { + // Declaration not ready yet + complete = false + incomplete = d } } // Sanity check for any errors caught during this iteration. if len(errors) > 0 { return errors } + // Decrement counter + counter-- } // Check whether we actually finished the allocation. - if !complete { + if counter == 0 { + err := r.srcmap.SyntaxError(incomplete, "unable to complete resolution") + return []SyntaxError{*err} + } else if !complete { // No, we didn't. So, something is wrong --- assume it must be a cyclic // definition for now. err := r.srcmap.SyntaxError(incomplete, "cyclic declaration") @@ -249,20 +167,68 @@ func (r *resolver) finaliseAssignmentsInModule(scope *ModuleScope, decls []Decla } // Check that a given set of source columns have been finalised. This is -// important, since we cannot finalise an assignment until all of its +// important, since we cannot finalise a declaration until all of its // dependencies have themselves been finalised. -func (r *resolver) columnsAreFinalised(scope *ModuleScope, columns []*DefName) bool { - for _, col := range columns { - // Look up information - info := scope.Bind(nil, col.Name, false).(*ColumnBinding) - // Check whether its finalised - if info.multiplier == 0 { - // Nope, not yet. - return false +func (r *resolver) declarationDependenciesAreFinalised(scope *ModuleScope, + symbols util.Iterator[Symbol]) (bool, []SyntaxError) { + var ( + errors []SyntaxError + finalised bool = true + ) + // + for iter := symbols; iter.HasNext(); { + symbol := iter.Next() + // Attempt to resolve + if !symbol.IsResolved() && !scope.Bind(symbol) { + errors = append(errors, *r.srcmap.SyntaxError(symbol, "unknown symbol")) + // not finalised yet + finalised = false + } else if !symbol.Binding().IsFinalised() { + // no, not finalised + finalised = false } } // - return true + return finalised, errors +} + +// Finalise a declaration. +func (r *resolver) finaliseDeclaration(scope *ModuleScope, decl Declaration) []SyntaxError { + if d, ok := decl.(*DefConstraint); ok { + return r.finaliseDefConstraintInModule(scope, d) + } else if d, ok := decl.(*DefFun); ok { + return r.finaliseDefFunInModule(scope, d) + } else if d, ok := decl.(*DefInRange); ok { + return r.finaliseDefInRangeInModule(scope, d) + } else if d, ok := decl.(*DefInterleaved); ok { + return r.finaliseDefInterleavedInModule(d) + } else if d, ok := decl.(*DefLookup); ok { + return r.finaliseDefLookupInModule(scope, d) + } else if d, ok := decl.(*DefPermutation); ok { + return r.finaliseDefPermutationInModule(d) + } else if d, ok := decl.(*DefProperty); ok { + return r.finaliseDefPropertyInModule(scope, d) + } + // + return nil +} + +// Finalise a vanishing constraint declaration after all symbols have been +// resolved. This involves: (a) checking the context is valid; (b) checking the +// expressions are well-typed. +func (r *resolver) finaliseDefConstraintInModule(enclosing Scope, decl *DefConstraint) []SyntaxError { + var ( + errors []SyntaxError + scope = NewLocalScope(enclosing, false) + ) + // Resolve guard + if decl.Guard != nil { + errors = r.finaliseExpressionInModule(scope, decl.Guard) + } + // Resolve constraint body + errors = append(errors, r.finaliseExpressionInModule(scope, decl.Constraint)...) + // Done + return errors } // Finalise an interleaving assignment. Since the assignment would already been @@ -270,7 +236,7 @@ func (r *resolver) columnsAreFinalised(scope *ModuleScope, columns []*DefName) b // multiplier for the interleaved column. This can still result in an error, // for example, if the multipliers between interleaved columns are incompatible, // etc. -func (r *resolver) finaliseInterleavedAssignment(scope *ModuleScope, decl *DefInterleaved) []SyntaxError { +func (r *resolver) finaliseDefInterleavedInModule(decl *DefInterleaved) []SyntaxError { var ( // Length multiplier being determined length_multiplier uint @@ -281,37 +247,37 @@ func (r *resolver) finaliseInterleavedAssignment(scope *ModuleScope, decl *DefIn ) // Determine type and length multiplier for i, source := range decl.Sources { - // Lookup info of column being interleaved. - info := scope.Bind(nil, source.Name, false).(*ColumnBinding) + // Lookup binding of column being interleaved. + binding := source.Binding().(*ColumnBinding) // if i == 0 { - length_multiplier = info.multiplier - datatype = info.datatype - } else if info.multiplier != length_multiplier { + length_multiplier = binding.multiplier + datatype = binding.dataType + } else if binding.multiplier != length_multiplier { // Columns to be interleaved must have the same length multiplier. - err := r.srcmap.SyntaxError(decl, fmt.Sprintf("source column %s has incompatible length multiplier", source)) + err := r.srcmap.SyntaxError(decl, fmt.Sprintf("source column %s has incompatible length multiplier", source.Name())) errors = append(errors, *err) } // Combine datatypes. - datatype = schema.Join(datatype, info.datatype) + datatype = schema.Join(datatype, binding.dataType) } // Finalise details only if no errors if len(errors) == 0 { // Determine actual length multiplier length_multiplier *= uint(len(decl.Sources)) // Lookup existing declaration - info := scope.Bind(nil, decl.Target, false).(*ColumnBinding) + binding := decl.Target.Binding().(*ColumnBinding) // Update with completed information - info.multiplier = length_multiplier - info.datatype = datatype + binding.multiplier = length_multiplier + binding.dataType = datatype } // Done return errors } -// Finalise a permutation assignment. Since the assignment would already been -// initialised, this is actually quite easy to do. -func (r *resolver) finalisePermutationAssignment(scope *ModuleScope, decl *DefPermutation) []SyntaxError { +// Finalise a permutation assignment after all symbols have been resolved. This +// requires checking the contexts of all columns is consistent. +func (r *resolver) finaliseDefPermutationInModule(decl *DefPermutation) []SyntaxError { var ( multiplier uint = 0 errors []SyntaxError @@ -320,9 +286,9 @@ func (r *resolver) finalisePermutationAssignment(scope *ModuleScope, decl *DefPe for i := 0; i < len(decl.Sources); i++ { ith := decl.Sources[i] // Lookup source of column being permuted - source := scope.Bind(nil, ith.Name, false).(*ColumnBinding) + source := ith.Binding().(*ColumnBinding) // Sanity check length multiplier - if i == 0 && source.datatype.AsUint() == nil { + if i == 0 && source.dataType.AsUint() == nil { errors = append(errors, *r.srcmap.SyntaxError(ith, "fixed-width type required")) } else if i == 0 { multiplier = source.multiplier @@ -331,149 +297,84 @@ func (r *resolver) finalisePermutationAssignment(scope *ModuleScope, decl *DefPe errors = append(errors, *r.srcmap.SyntaxError(ith, "incompatible length multiplier")) } // All good, finalise target column - target := scope.Bind(nil, decl.Targets[i].Name, false).(*ColumnBinding) + target := decl.Targets[i].Binding().(*ColumnBinding) // Update with completed information target.multiplier = source.multiplier - target.datatype = source.datatype + target.dataType = source.dataType } // Done return errors } -// Examine each constraint and attempt to resolve any variables used within -// them. For example, a vanishing constraint may refer to some variable "X". -// Prior to this function being called, its not clear what "X" refers to --- it -// could refer to a column a constant, or even an alias. The purpose of this -// pass is to: firstly, check that every variable refers to something which was -// declared; secondly, to determine what each variable represents (i.e. column -// access, a constant, etc). -func (r *resolver) resolveConstraints(scope *GlobalScope, circuit *Circuit) []SyntaxError { - errs := r.resolveConstraintsInModule(scope.Module(""), circuit.Declarations) - // - for _, m := range circuit.Modules { - // Process all declarations in the module - merrs := r.resolveConstraintsInModule(scope.Module(m.Name), m.Declarations) - // Package up all errors - errs = append(errs, merrs...) - } - // - return errs -} - -// Helper for resolve constraints which considers those constraints declared in -// a particular module. -func (r *resolver) resolveConstraintsInModule(enclosing Scope, decls []Declaration) []SyntaxError { - var errors []SyntaxError - // - for _, d := range decls { - // Look for defcolumns decalarations only - if _, ok := d.(*DefColumns); ok { - // Safe to ignore. - } else if c, ok := d.(*DefConstraint); ok { - errors = append(errors, r.resolveDefConstraintInModule(enclosing, c)...) - } else if c, ok := d.(*DefInRange); ok { - errors = append(errors, r.resolveDefInRangeInModule(enclosing, c)...) - } else if _, ok := d.(*DefInterleaved); ok { - // Nothing to do here, since this assignment form contains no - // expressions to be resolved. - } else if c, ok := d.(*DefLookup); ok { - errors = append(errors, r.resolveDefLookupInModule(enclosing, c)...) - } else if _, ok := d.(*DefPermutation); ok { - // Nothing to do here, since this assignment form contains no - // expressions to be resolved. - } else if c, ok := d.(*DefFun); ok { - errors = append(errors, r.resolveDefFunInModule(enclosing, c)...) - } else if c, ok := d.(*DefProperty); ok { - errors = append(errors, r.resolveDefPropertyInModule(enclosing, c)...) - } else { - errors = append(errors, *r.srcmap.SyntaxError(d, "unknown declaration")) - } - } - // - return errors -} - -// Resolve those variables appearing in either the guard or the body of this constraint. -func (r *resolver) resolveDefConstraintInModule(enclosing Scope, decl *DefConstraint) []SyntaxError { - var ( - errors []SyntaxError - scope = NewLocalScope(enclosing, false) - ) - // Resolve guard - if decl.Guard != nil { - errors = r.resolveExpressionInModule(scope, decl.Guard) - } - // Resolve constraint body - errors = append(errors, r.resolveExpressionInModule(scope, decl.Constraint)...) - // Done - return errors -} - -// Resolve those variables appearing in the body of this range constraint. -func (r *resolver) resolveDefInRangeInModule(enclosing Scope, decl *DefInRange) []SyntaxError { +// Finalise a range constraint declaration after all symbols have been +// resolved. This involves: (a) checking the context is valid; (b) checking the +// expressions are well-typed. +func (r *resolver) finaliseDefInRangeInModule(enclosing Scope, decl *DefInRange) []SyntaxError { var ( errors []SyntaxError scope = NewLocalScope(enclosing, false) ) // Resolve property body - errors = append(errors, r.resolveExpressionInModule(scope, decl.Expr)...) + errors = append(errors, r.finaliseExpressionInModule(scope, decl.Expr)...) // Done return errors } -// Resolve those variables appearing in the body of this function. -func (r *resolver) resolveDefFunInModule(enclosing Scope, decl *DefFun) []SyntaxError { +// Finalise a function definition after all symbols have been resolved. This +// involves: (a) checking the context is valid for the body; (b) checking the +// body is well-typed; (c) for pure functions checking that no columns are +// accessed; (d) finally, resolving any parameters used within the body of this +// function. +func (r *resolver) finaliseDefFunInModule(enclosing Scope, decl *DefFun) []SyntaxError { var ( errors []SyntaxError scope = NewLocalScope(enclosing, false) ) // Declare parameters in local scope - for _, p := range decl.Parameters { + for _, p := range decl.Parameters() { scope.DeclareLocal(p.Name) } // Resolve property body - errors = append(errors, r.resolveExpressionInModule(scope, decl.Body)...) - // Remove parameters from enclosing environment + errors = append(errors, r.finaliseExpressionInModule(scope, decl.Body())...) // Done return errors } // Resolve those variables appearing in the body of this lookup constraint. -func (r *resolver) resolveDefLookupInModule(enclosing Scope, decl *DefLookup) []SyntaxError { +func (r *resolver) finaliseDefLookupInModule(enclosing Scope, decl *DefLookup) []SyntaxError { var ( errors []SyntaxError sourceScope = NewLocalScope(enclosing, true) targetScope = NewLocalScope(enclosing, true) ) - // Resolve source expressions - errors = append(errors, r.resolveExpressionsInModule(sourceScope, decl.Sources)...) + errors = append(errors, r.finaliseExpressionsInModule(sourceScope, decl.Sources)...) // Resolve target expressions - errors = append(errors, r.resolveExpressionsInModule(targetScope, decl.Targets)...) + errors = append(errors, r.finaliseExpressionsInModule(targetScope, decl.Targets)...) // Done return errors } // Resolve those variables appearing in the body of this property assertion. -func (r *resolver) resolveDefPropertyInModule(enclosing Scope, decl *DefProperty) []SyntaxError { +func (r *resolver) finaliseDefPropertyInModule(enclosing Scope, decl *DefProperty) []SyntaxError { var ( errors []SyntaxError scope = NewLocalScope(enclosing, false) ) // Resolve property body - errors = append(errors, r.resolveExpressionInModule(scope, decl.Assertion)...) + errors = append(errors, r.finaliseExpressionInModule(scope, decl.Assertion)...) // Done return errors } // Resolve a sequence of zero or more expressions within a given module. This // simply resolves each of the arguments in turn, collecting any errors arising. -func (r *resolver) resolveExpressionsInModule(scope LocalScope, args []Expr) []SyntaxError { +func (r *resolver) finaliseExpressionsInModule(scope LocalScope, args []Expr) []SyntaxError { var errors []SyntaxError // Visit each argument for _, arg := range args { if arg != nil { - errs := r.resolveExpressionInModule(scope, arg) + errs := r.finaliseExpressionInModule(scope, arg) errors = append(errors, errs...) } } @@ -486,27 +387,27 @@ func (r *resolver) resolveExpressionsInModule(scope LocalScope, args []Expr) []S // variable accesses. As above, the goal is ensure variable refers to something // that was declared and, more specifically, what kind of access it is (e.g. // column access, constant access, etc). -func (r *resolver) resolveExpressionInModule(scope LocalScope, expr Expr) []SyntaxError { +func (r *resolver) finaliseExpressionInModule(scope LocalScope, expr Expr) []SyntaxError { if _, ok := expr.(*Constant); ok { return nil } else if v, ok := expr.(*Add); ok { - return r.resolveExpressionsInModule(scope, v.Args) + return r.finaliseExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*Exp); ok { - return r.resolveExpressionInModule(scope, v.Arg) + return r.finaliseExpressionInModule(scope, v.Arg) } else if v, ok := expr.(*IfZero); ok { - return r.resolveExpressionsInModule(scope, []Expr{v.Condition, v.TrueBranch, v.FalseBranch}) + return r.finaliseExpressionsInModule(scope, []Expr{v.Condition, v.TrueBranch, v.FalseBranch}) } else if v, ok := expr.(*Invoke); ok { - return r.resolveInvokeInModule(scope, v) + return r.finaliseInvokeInModule(scope, v) } else if v, ok := expr.(*List); ok { - return r.resolveExpressionsInModule(scope, v.Args) + return r.finaliseExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*Mul); ok { - return r.resolveExpressionsInModule(scope, v.Args) + return r.finaliseExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*Normalise); ok { - return r.resolveExpressionInModule(scope, v.Arg) + return r.finaliseExpressionInModule(scope, v.Arg) } else if v, ok := expr.(*Sub); ok { - return r.resolveExpressionsInModule(scope, v.Args) + return r.finaliseExpressionsInModule(scope, v.Args) } else if v, ok := expr.(*VariableAccess); ok { - return r.resolveVariableInModule(scope, v) + return r.finaliseVariableInModule(scope, v) } else { return r.srcmap.SyntaxErrors(expr, "unknown expression") } @@ -515,44 +416,47 @@ func (r *resolver) resolveExpressionInModule(scope LocalScope, expr Expr) []Synt // Resolve a specific invocation contained within some expression which, in // turn, is contained within some module. Note, qualified accesses are only // permitted in a global context. -func (r *resolver) resolveInvokeInModule(scope LocalScope, expr *Invoke) []SyntaxError { +func (r *resolver) finaliseInvokeInModule(scope LocalScope, expr *Invoke) []SyntaxError { // Resolve arguments - if errors := r.resolveExpressionsInModule(scope, expr.Args); errors != nil { + if errors := r.finaliseExpressionsInModule(scope, expr.Args()); errors != nil { return errors } // Lookup the corresponding function definition. - binding := scope.Bind(nil, expr.Name, true) - // Check what we got - if fnBinding, ok := binding.(*FunctionBinding); ok { - expr.Binding = fnBinding - return nil + if !scope.Bind(expr) { + return r.srcmap.SyntaxErrors(expr, "unknown function") } - // - return r.srcmap.SyntaxErrors(expr, "unknown function") + // Success + return nil } // Resolve a specific variable access contained within some expression which, in // turn, is contained within some module. Note, qualified accesses are only // permitted in a global context. -func (r *resolver) resolveVariableInModule(scope LocalScope, +func (r *resolver) finaliseVariableInModule(scope LocalScope, expr *VariableAccess) []SyntaxError { // Check whether this is a qualified access, or not. - if !scope.IsGlobal() && expr.Module != nil { + if !scope.IsGlobal() && expr.IsQualified() { return r.srcmap.SyntaxErrors(expr, "qualified access not permitted here") - } else if expr.Module != nil && !scope.HasModule(*expr.Module) { - return r.srcmap.SyntaxErrors(expr, fmt.Sprintf("unknown module %s", *expr.Module)) + } else if expr.IsQualified() && !scope.HasModule(expr.Module()) { + return r.srcmap.SyntaxErrors(expr, fmt.Sprintf("unknown module %s", expr.Module())) } - // Attempt resolve this variable access, noting that it definitely does not - // refer to a function. - if expr.Binding = scope.Bind(expr.Module, expr.Name, false); expr.Binding != nil { + // Symbol should be resolved at this point, but we still need to check the + // context. + if expr.IsResolved() { // Update context - binding, ok := expr.Binding.(*ColumnBinding) + binding, ok := expr.Binding().(*ColumnBinding) if ok && !scope.FixContext(binding.Context()) { return r.srcmap.SyntaxErrors(expr, "conflicting context") + } else if !ok { + // Unable to resolve variable + return r.srcmap.SyntaxErrors(expr, "not a column") } // Done return nil + } else if scope.Bind(expr) { + // Must be a local variable or parameter access, so we're all good. + return nil } // Unable to resolve variable - return r.srcmap.SyntaxErrors(expr, "unknown symbol") + return r.srcmap.SyntaxErrors(expr, "unresolved symbol") } diff --git a/pkg/corset/scope.go b/pkg/corset/scope.go index 984f6ea..fab5c1f 100644 --- a/pkg/corset/scope.go +++ b/pkg/corset/scope.go @@ -11,16 +11,12 @@ import ( // variable used within an expression refers to. For example, a variable can // refer to a column, or a parameter, etc. type Scope interface { - // Get the name of the enclosing module. This is generally useful for - // reporting errors. - EnclosingModule() string // HasModule checks whether a given module exists, or not. HasModule(string) bool - // Lookup a given variable being referenced with an optional module - // specifier. This variable could correspond to a column, a function, a - // parameter, or a local variable. Furthermore, the returned binding will - // be nil if this variable does not exist. - Bind(*string, string, bool) Binding + // Attempt to bind a given symbol within this scope. If successful, the + // symbol is then resolved with the appropriate binding. Return value + // indicates whether successful or not. + Bind(Symbol) bool } // ============================================================================= @@ -56,12 +52,6 @@ func (p *GlobalScope) DeclareModule(module string) { p.ids[module] = mid } -// EnclosingModule returns the name of the enclosing module. For a global -// scope, this has no meaning. -func (p *GlobalScope) EnclosingModule() string { - panic("unreachable") -} - // HasModule checks whether a given module exists, or not. func (p *GlobalScope) HasModule(module string) bool { // Attempt to lookup the module @@ -72,18 +62,25 @@ func (p *GlobalScope) HasModule(module string) bool { // Bind looks up a given variable being referenced within a given module. For a // root context, this is either a column, an alias or a function declaration. -func (p *GlobalScope) Bind(module *string, name string, fn bool) Binding { - if module == nil { +func (p *GlobalScope) Bind(symbol Symbol) bool { + if !symbol.IsQualified() { panic("cannot bind unqualified symbol in the global scope") + } else if !p.HasModule(symbol.Module()) { + // Pontially, it might be better to report a more useful error message. + return false } // - return p.Module(*module).Bind(nil, name, fn) + return p.Module(symbol.Module()).Bind(symbol) } -// Module returns the identifier of the module with the given name. +// Module returns the identifier of the module with the given name. Observe +// that this will panic if the module in question does not exist. func (p *GlobalScope) Module(name string) *ModuleScope { - mid := p.ids[name] - return &p.modules[mid] + if mid, ok := p.ids[name]; ok { + return &p.modules[mid] + } + // Problem. + panic(fmt.Sprintf("unknown module \"%s\"", name)) } // ToEnvironment converts this global scope into a concrete environment by @@ -123,32 +120,51 @@ func (p *ModuleScope) HasModule(module string) bool { // Bind looks up a given variable being referenced within a given module. For a // root context, this is either a column, an alias or a function declaration. -func (p *ModuleScope) Bind(module *string, name string, fn bool) Binding { +func (p *ModuleScope) Bind(symbol Symbol) bool { // Determine module for this lookup. - if module != nil { + if symbol.IsQualified() && symbol.Module() != p.module { // non-local lookup - return p.enclosing.Bind(module, name, fn) + return p.enclosing.Bind(symbol) } // construct binding identifier - if bid, ok := p.ids[BindingId{name, fn}]; ok { - return p.bindings[bid] + id := BindingId{symbol.Name(), symbol.IsFunction()} + // Look for it. + if bid, ok := p.ids[id]; ok { + // Extract binding + binding := p.bindings[bid] + // Resolve symbol + symbol.Resolve(binding) + // Success + return true } // failed - return nil + return false +} + +// Column returns information about a particular column declared within this +// module. +func (p *ModuleScope) Column(name string) *ColumnBinding { + // construct binding identifier + bid := p.ids[BindingId{name, false}] + // + return p.bindings[bid].(*ColumnBinding) } // Declare declares a given binding within this module scope. -func (p *ModuleScope) Declare(name string, fn bool, binding Binding) { +func (p *ModuleScope) Declare(symbol SymbolDefinition) bool { // construct binding identifier - bid := BindingId{name, fn} + bid := BindingId{symbol.Name(), symbol.IsFunction()} // Sanity check not already declared if _, ok := p.ids[bid]; ok { - panic(fmt.Sprintf("attempt to redeclare binding for \"%s\"", name)) + // Cannot redeclare + return false } // Done id := uint(len(p.bindings)) - p.bindings = append(p.bindings, binding) + p.bindings = append(p.bindings, symbol.Binding()) p.ids[bid] = id + // + return true } // ============================================================================= @@ -197,12 +213,6 @@ func (p LocalScope) IsGlobal() bool { return p.global } -// EnclosingModule returns the name of the enclosing module. This is generally -// useful for reporting errors. -func (p LocalScope) EnclosingModule() string { - return p.enclosing.EnclosingModule() -} - // FixContext fixes the context for this scope. Since every scope requires // exactly one context, this fails if we fix it to incompatible contexts. func (p LocalScope) FixContext(context Context) bool { @@ -219,14 +229,16 @@ func (p LocalScope) HasModule(module string) bool { // Bind looks up a given variable or function being referenced either within the // enclosing scope (module==nil) or within a specified module. -func (p LocalScope) Bind(module *string, name string, fn bool) Binding { +func (p LocalScope) Bind(symbol Symbol) bool { // Check whether this is a local variable access. - if id, ok := p.locals[name]; ok && !fn && module == nil { + if id, ok := p.locals[symbol.Name()]; ok && !symbol.IsFunction() && !symbol.IsQualified() { // Yes, this is a local variable access. - return &ParameterBinding{id} + symbol.Resolve(&ParameterBinding{id}) + // Success + return true } // No, this is not a local variable access. - return p.enclosing.Bind(module, name, fn) + return p.enclosing.Bind(symbol) } // DeclareLocal registers a new local variable (e.g. a parameter). diff --git a/pkg/corset/translator.go b/pkg/corset/translator.go index 194c6b0..c00d2e8 100644 --- a/pkg/corset/translator.go +++ b/pkg/corset/translator.go @@ -91,15 +91,15 @@ func (t *translator) translateDefColumns(decl *DefColumns, module string) []Synt var errors []SyntaxError // Add each column to schema for _, c := range decl.Columns { - context := t.env.ContextFrom(module, c.LengthMultiplier) - cid := t.schema.AddDataColumn(context, c.Name, c.DataType) + context := t.env.ContextFrom(module, c.LengthMultiplier()) + cid := t.schema.AddDataColumn(context, c.Name(), c.DataType()) // Prove type (if requested) - if c.MustProve { - bound := c.DataType.AsUint().Bound() - t.schema.AddRangeConstraint(c.Name, context, &hir.ColumnAccess{Column: cid, Shift: 0}, bound) + if c.MustProve() { + bound := c.DataType().AsUint().Bound() + t.schema.AddRangeConstraint(c.Name(), context, &hir.ColumnAccess{Column: cid, Shift: 0}, bound) } // Sanity check column identifier - if info := t.env.Column(module, c.Name); info.ColumnId() != cid { + if info := t.env.Column(module, c.Name()); info.ColumnId() != cid { errors = append(errors, *t.srcmap.SyntaxError(c, "invalid column identifier")) } } @@ -228,15 +228,15 @@ func (t *translator) translateDefInterleaved(decl *DefInterleaved, module string // sources := make([]uint, len(decl.Sources)) // Lookup target column info - info := t.env.Column(module, decl.Target) + info := t.env.Column(module, decl.Target.Name()) // Determine source column identifiers for i, source := range decl.Sources { - sources[i] = t.env.Column(module, source.Name).ColumnId() + sources[i] = t.env.Column(module, source.Name()).ColumnId() } // Construct context for this assignment context := t.env.ContextFrom(module, info.multiplier) // Register assignment - cid := t.schema.AddAssignment(assignment.NewInterleaving(context, decl.Target, sources, info.datatype)) + cid := t.schema.AddAssignment(assignment.NewInterleaving(context, decl.Target.Name(), sources, info.dataType)) // Sanity check column identifiers align. if cid != info.ColumnId() { errors = append(errors, *t.srcmap.SyntaxError(decl, "invalid column identifier")) @@ -258,10 +258,10 @@ func (t *translator) translateDefPermutation(decl *DefPermutation, module string sources := make([]uint, len(decl.Sources)) // for i := 0; i < len(decl.Sources); i++ { - target := t.env.Column(module, decl.Targets[i].Name) + target := t.env.Column(module, decl.Targets[i].Name()) context = t.env.ContextFrom(module, target.multiplier) - targets[i] = sc.NewColumn(context, decl.Targets[i].Name, target.datatype) - sources[i] = t.env.Column(module, decl.Sources[i].Name).ColumnId() + targets[i] = sc.NewColumn(context, decl.Targets[i].Name(), target.dataType) + sources[i] = t.env.Column(module, decl.Sources[i].Name()).ColumnId() signs[i] = decl.Signs[i] // Record first CID if i == 0 { @@ -354,12 +354,14 @@ func (t *translator) translateExpressionInModule(expr Expr, module string) (hir. args, errs := t.translateExpressionsInModule([]Expr{v.Condition, v.TrueBranch, v.FalseBranch}, module) return &hir.IfZero{Condition: args[0], TrueBranch: args[1], FalseBranch: args[2]}, errs } else if e, ok := expr.(*Invoke); ok { - if e.Binding != nil && e.Binding.arity == uint(len(e.Args)) { - body := e.Binding.Apply(e.Args) - return t.translateExpressionInModule(body, module) - } else if e.Binding != nil { - msg := fmt.Sprintf("incorrect number of arguments (expected %d, found %d)", e.Binding.arity, len(e.Args)) - return nil, t.srcmap.SyntaxErrors(expr, msg) + if binding, ok := e.Binding().(*FunctionBinding); ok { + if binding.Arity() == uint(len(e.Args())) { + body := binding.Apply(e.Args()) + return t.translateExpressionInModule(body, module) + } else { + msg := fmt.Sprintf("incorrect number of arguments (expected %d, found %d)", binding.Arity(), len(e.Args())) + return nil, t.srcmap.SyntaxErrors(expr, msg) + } } // return nil, t.srcmap.SyntaxErrors(expr, "unbound function") @@ -376,11 +378,11 @@ func (t *translator) translateExpressionInModule(expr Expr, module string) (hir. args, errs := t.translateExpressionsInModule(v.Args, module) return &hir.Sub{Args: args}, errs } else if e, ok := expr.(*VariableAccess); ok { - if binding, ok := e.Binding.(*ColumnBinding); ok { + if binding, ok := e.Binding().(*ColumnBinding); ok { // Lookup column binding - cinfo := t.env.Column(binding.module, e.Name) + cinfo := t.env.Column(binding.module, e.Name()) // Done - return &hir.ColumnAccess{Column: cinfo.ColumnId(), Shift: e.Shift}, nil + return &hir.ColumnAccess{Column: cinfo.ColumnId(), Shift: e.Shift()}, nil } // error return nil, t.srcmap.SyntaxErrors(expr, "unbound variable") diff --git a/pkg/trace/context.go b/pkg/trace/context.go index 611cfcc..c436fb3 100644 --- a/pkg/trace/context.go +++ b/pkg/trace/context.go @@ -49,7 +49,7 @@ func VoidContext[T comparable]() RawContext[T] { // deteremed. This value is generally considered to indicate an error. func ConflictingContext[T comparable]() RawContext[T] { var empty T - return RawContext[T]{empty, math.MaxUint - 1} + return RawContext[T]{empty, math.MaxUint} } // NewContext returns a context representing the given module with the given @@ -119,5 +119,11 @@ func (p RawContext[T]) Join(other RawContext[T]) RawContext[T] { } func (p RawContext[T]) String() string { + if p.IsVoid() { + return "⊥" + } else if p.IsConflicted() { + return "⊤" + } + // Valid multiplier. return fmt.Sprintf("%v*%d", p.module, p.multiplier) }