Skip to content

Commit

Permalink
feat: support corset pure functions (#404)
Browse files Browse the repository at this point in the history
This adds support for `defpurefun` declarations. However, there are still a number of things outstanding to be done.
  • Loading branch information
DavePearce authored Nov 29, 2024
1 parent 5aeed0b commit f1d3d94
Show file tree
Hide file tree
Showing 23 changed files with 781 additions and 239 deletions.
212 changes: 182 additions & 30 deletions pkg/corset/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
sc "github.com/consensys/go-corset/pkg/schema"
"github.com/consensys/go-corset/pkg/sexp"
"github.com/consensys/go-corset/pkg/trace"
tr "github.com/consensys/go-corset/pkg/trace"
)

Expand Down Expand Up @@ -175,7 +174,7 @@ type DefInterleaved struct {
// The target column being defined
Target string
// The source columns used to define the interleaved target column.
Sources []*DefSourceColumn
Sources []*DefName
}

// CanFinalise checks whether or not this interleaving is ready to be finalised.
Expand All @@ -200,19 +199,6 @@ func (p *DefInterleaved) Lisp() sexp.SExp {
panic("got here")
}

// DefSourceColumn provides information about a column being permuted by a
// sorted permutation.
type DefSourceColumn struct {
// Name of the column to be permuted
Name string
}

// Lisp converts this node into its lisp representation. This is primarily used
// for debugging purposes.
func (p *DefSourceColumn) Lisp() sexp.SExp {
panic("got here")
}

// DefLookup represents a lookup constraint between a set N of source
// expressions and a set of N target expressions. The source expressions must
// have a single context (i.e. all be in the same module) and likewise for the
Expand Down Expand Up @@ -325,6 +311,52 @@ func (p *DefProperty) Lisp() sexp.SExp {
// parameters). In contrast, an impure function can access those columns
// defined within its enclosing context.
type DefFun struct {
Name *DefName
// Flag whether or not is pure function
Pure bool
// Return type
Return sc.Type
// Parameters
Parameters []*DefParameter
// Body
Body Expr
}

// IsDeclaration needed to signal declaration.
func (p *DefFun) IsDeclaration() {}

// Lisp converts this node into its lisp representation. This is primarily used
// for debugging purposes.
func (p *DefFun) Lisp() sexp.SExp {
panic("got here")
}

// DefParameter packages together those piece relevant to declaring an individual
// parameter, such its name and type.
type DefParameter struct {
// Column name
Name string
// The datatype which all values in this parameter should inhabit.
DataType sc.Type
}

// Lisp converts this node into its lisp representation. This is primarily used
// for debugging purposes.
func (p *DefParameter) Lisp() sexp.SExp {
panic("got here")
}

// DefName is simply a wrapper around a string which can be associated with
// source information for producing syntax errors.
type DefName struct {
// Name of the column to be permuted
Name string
}

// Lisp converts this node into its lisp representation. This is primarily used
// for debugging purposes.
func (p *DefName) Lisp() sexp.SExp {
panic("got here")
}

// Expr represents an arbitrary expression over the columns of a given context
Expand All @@ -347,6 +379,10 @@ type Expr interface {
// expression must have been resolved for this to be defined (i.e. it may
// panic if it has not been resolved yet).
Context() tr.Context

// Substitute all variables (such as for function parameters) arising in
// this expression.
Substitute(args []Expr) Expr
}

// ============================================================================
Expand Down Expand Up @@ -375,6 +411,12 @@ func (e *Add) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Add) Substitute(args []Expr) Expr {
return &Add{SubstituteExpressions(e.Args, args)}
}

// ============================================================================
// Constants
// ============================================================================
Expand All @@ -401,6 +443,12 @@ func (e *Constant) Lisp() sexp.SExp {
return sexp.NewSymbol(e.Val.String())
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Constant) Substitute(args []Expr) Expr {
return e
}

// ============================================================================
// Exponentiation
// ============================================================================
Expand Down Expand Up @@ -430,6 +478,12 @@ func (e *Exp) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Exp) Substitute(args []Expr) Expr {
return &Exp{e.Arg.Substitute(args), e.Pow}
}

// ============================================================================
// IfZero
// ============================================================================
Expand Down Expand Up @@ -464,6 +518,15 @@ func (e *IfZero) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *IfZero) Substitute(args []Expr) Expr {
return &IfZero{e.Condition.Substitute(args),
SubstituteOptionalExpression(e.TrueBranch, args),
SubstituteOptionalExpression(e.FalseBranch, args),
}
}

// ============================================================================
// List
// ============================================================================
Expand All @@ -490,6 +553,12 @@ func (e *List) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *List) Substitute(args []Expr) Expr {
return &List{SubstituteExpressions(e.Args, args)}
}

// ============================================================================
// Multiplication
// ============================================================================
Expand All @@ -516,6 +585,12 @@ func (e *Mul) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Mul) Substitute(args []Expr) Expr {
return &Mul{SubstituteExpressions(e.Args, args)}
}

// ============================================================================
// Normalise
// ============================================================================
Expand Down Expand Up @@ -543,6 +618,12 @@ func (e *Normalise) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Normalise) Substitute(args []Expr) Expr {
return &Normalise{e.Arg.Substitute(args)}
}

// ============================================================================
// Subtraction
// ============================================================================
Expand All @@ -569,6 +650,54 @@ func (e *Sub) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Sub) Substitute(args []Expr) Expr {
return &Sub{SubstituteExpressions(e.Args, args)}
}

// ============================================================================
// VariableAccess
// ============================================================================

// Invoke represents an attempt to invoke a given function.
type Invoke struct {
Module *string
Name string
Args []Expr
Binding *FunctionBinding
}

// Context returns the context for this expression. Observe that the
// expression must have been resolved for this to be defined (i.e. it may
// panic if it has not been resolved yet).
func (e *Invoke) Context() tr.Context {
if e.Binding == nil {
panic("unresolved expressions encountered whilst resolving context")
}
// TODO: impure functions can have their own context.
return ContextOfExpressions(e.Args)
}

// Multiplicity determines the number of values that evaluating this expression
// can generate.
func (e *Invoke) Multiplicity() uint {
// FIXME: is this always correct?
return 1
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
func (e *Invoke) Lisp() sexp.SExp {
panic("todo")
}

// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *Invoke) Substitute(args []Expr) Expr {
return &Invoke{e.Module, e.Name, SubstituteExpressions(e.Args, args), e.Binding}
}

// ============================================================================
// VariableAccess
// ============================================================================
Expand All @@ -579,13 +708,12 @@ type VariableAccess struct {
Module *string
Name string
Shift int
Binding *Binder
Binding Binding
}

// Multiplicity determines the number of values that evaluating this expression
// can generate.
func (e *VariableAccess) Multiplicity() uint {
// NOTE: this might not be true for invocations.
return 1
}

Expand All @@ -597,26 +725,28 @@ func (e *VariableAccess) Context() tr.Context {
panic("unresolved expressions encountered whilst resolving context")
}
// Extract saved context
return e.Binding.Context
return e.Binding.Context()
}

// Lisp converts this schema element into a simple S-Expression, for example
// so it can be printed.
// so it can be printed.a
func (e *VariableAccess) Lisp() sexp.SExp {
panic("todo")
}

// Binder provides additional information determined during the resolution
// phase. Specifically, it clarifies the meaning of a given variable name used
// within an expression (i.e. is it a column access, a local variable access,
// etc).
type Binder struct {
// Identifies whether this is a column access, or a variable access.
Column bool
// For a column access, this identifies the enclosing context.
Context trace.Context
// Identifies the variable or column index (as appropriate).
Index uint
// Substitute all variables (such as for function parameters) arising in
// this expression.
func (e *VariableAccess) Substitute(args []Expr) Expr {
if b, ok := e.Binding.(*ParameterBinding); ok {
// This is a variable to be substituted.
if e.Shift != 0 {
panic("support variable shifts")
}
//
return args[b.index]
}
// Nothing to do here
return e
}

// ============================================================================
Expand All @@ -638,6 +768,28 @@ func ContextOfExpressions(exprs []Expr) tr.Context {
return context
}

// SubstituteExpressions substitutes all variables found in a given set of
// expressions.
func SubstituteExpressions(exprs []Expr, vars []Expr) []Expr {
nexprs := make([]Expr, len(exprs))
//
for i := 0; i < len(nexprs); i++ {
nexprs[i] = exprs[i].Substitute(vars)
}
//
return nexprs
}

// SubstituteOptionalExpression substitutes through an expression which is
// optional (i.e. might be nil). In such case, nil is returned.
func SubstituteOptionalExpression(expr Expr, vars []Expr) Expr {
if expr != nil {
expr = expr.Substitute(vars)
}
//
return expr
}

func determineMultiplicity(exprs []Expr) uint {
width := uint(1)
//
Expand Down
62 changes: 62 additions & 0 deletions pkg/corset/binding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package corset

import (
tr "github.com/consensys/go-corset/pkg/trace"
)

// Binding represents an association between a name, as found in a source file,
// and concrete item (e.g. a column, function, etc).
type Binding interface {
// Returns the context associated with this binding.
Context() tr.Context
}

// ColumnBinding represents something bound to a given column.
type ColumnBinding struct {
// For a column access, this identifies the enclosing context.
context tr.Context
// Identifies the variable or column index (as appropriate).
index uint
}

// Context returns the enclosing context for this column access.
func (p *ColumnBinding) Context() tr.Context {
return p.context
}

// ColumnID returns the column identifier that this column access refers to.
func (p *ColumnBinding) ColumnID() uint {
return p.index
}

// ParameterBinding represents something bound to a given column.
type ParameterBinding struct {
// Identifies the variable or column index (as appropriate).
index uint
}

// Context for a parameter is always void, as it does not correspond to a column
// in given module.
func (p *ParameterBinding) Context() tr.Context {
return tr.VoidContext()
}

// FunctionBinding represents the binding of a function application to its
// physical definition.
type FunctionBinding struct {
// arity determines the number of arguments this function takes.
arity uint
// body of the function in question.
body Expr
}

// Context for a parameter is always void, as it does not correspond to a column
// in given module.
func (p *FunctionBinding) Context() tr.Context {
return tr.VoidContext()
}

// Apply a given set of arguments to this function binding.
func (p *FunctionBinding) Apply(args []Expr) Expr {
return p.body.Substitute(args)
}
Loading

0 comments on commit f1d3d94

Please sign in to comment.