Skip to content

Commit

Permalink
Merge pull request #170 from Consensys/117-feat-accounting-for-spillage
Browse files Browse the repository at this point in the history
feat: Accounting for Spillage
  • Loading branch information
DavePearce authored Jun 17, 2024
2 parents 8a12829 + e0bdd65 commit 28dbfe5
Show file tree
Hide file tree
Showing 39 changed files with 672 additions and 288 deletions.
49 changes: 49 additions & 0 deletions pkg/air/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package air
import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/table"
"github.com/consensys/go-corset/pkg/util"
)

// Expr represents an expression in the Arithmetic Intermediate Representation
Expand Down Expand Up @@ -33,6 +34,10 @@ type Expr interface {

// Equate one expression with another
Equate(Expr) Expr

// Determine the maximum shift in this expression in either the negative
// (left) or positive direction (right).
MaxShift() util.Pair[uint, uint]
}

// Add represents the sum over zero or more expressions.
Expand All @@ -50,6 +55,10 @@ func (p *Add) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Add) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (p *Add) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }

// Sub represents the subtraction over zero or more expressions.
type Sub struct{ Args []Expr }

Expand All @@ -65,6 +74,10 @@ func (p *Sub) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Sub) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (p *Sub) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }

// Mul represents the product over zero or more expressions.
type Mul struct{ Args []Expr }

Expand All @@ -80,6 +93,10 @@ func (p *Mul) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Mul) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (p *Mul) MaxShift() util.Pair[uint, uint] { return maxShiftOfArray(p.Args) }

// Constant represents a constant value within an expression.
type Constant struct{ Value *fr.Element }

Expand Down Expand Up @@ -118,6 +135,10 @@ func (p *Constant) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}} }
// Equate one expression with another (equivalent to subtraction).
func (p *Constant) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right). A constant has zero shift.
func (p *Constant) MaxShift() util.Pair[uint, uint] { return util.NewPair[uint, uint](0, 0) }

// ColumnAccess represents reading the value held at a given column in the
// tabular context. Furthermore, the current row maybe shifted up (or down) by
// a given amount. Suppose we are evaluating a constraint on row k=5 which
Expand Down Expand Up @@ -146,3 +167,31 @@ func (p *ColumnAccess) Mul(other Expr) Expr { return &Mul{Args: []Expr{p, other}

// Equate one expression with another (equivalent to subtraction).
func (p *ColumnAccess) Equate(other Expr) Expr { return &Sub{Args: []Expr{p, other}} }

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (p *ColumnAccess) MaxShift() util.Pair[uint, uint] {
if p.Shift >= 0 {
// Positive shift
return util.NewPair[uint, uint](0, uint(p.Shift))
}
// Negative shift
return util.NewPair[uint, uint](uint(-p.Shift), 0)
}

// ==========================================================================
// Helpers
// ==========================================================================

func maxShiftOfArray(args []Expr) util.Pair[uint, uint] {
neg := uint(0)
pos := uint(0)

for _, e := range args {
mx := e.MaxShift()
neg = max(neg, mx.Left)
pos = max(pos, mx.Right)
}
// Done
return util.NewPair(neg, pos)
}
3 changes: 2 additions & 1 deletion pkg/air/gadgets/bits.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ func ApplyBitwidthGadget(col string, nbits uint, schema *air.Schema) {
// Construct X == (X:0 * 1) + ... + (X:n * 2^n)
X := air.NewColumnAccess(col, 0)
eq := X.Equate(sum)
schema.AddVanishingConstraint(col, nil, eq)
// Construct column name
schema.AddVanishingConstraint(fmt.Sprintf("%s:u%d", col, nbits), nil, eq)
// Finally, add the necessary byte decomposition computation.
schema.AddComputation(table.NewByteDecomposition(col, nbits))
}
Expand Down
12 changes: 9 additions & 3 deletions pkg/air/gadgets/lexicographic_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,12 @@ type lexicographicSortExpander struct {
bitwidth uint
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
// valid traces are accepted in the presence of arbitrary padding.
func (p *lexicographicSortExpander) RequiredSpillage() uint {
return uint(0)
}

// Accepts checks whether a given trace has the necessary columns
func (p *lexicographicSortExpander) Accepts(tr table.Trace) error {
prefix := constructLexicographicSortingPrefix(p.columns, p.signs)
Expand Down Expand Up @@ -194,14 +200,14 @@ func (p *lexicographicSortExpander) ExpandTrace(tr table.Trace) error {
bit[i] = make([]*fr.Element, nrows)
}

for i := 0; i < nrows; i++ {
for i := uint(0); i < nrows; i++ {
set := false
// Initialise delta to zero
delta[i] = &zero
// Decide which row is the winner (if any)
for j := 0; j < ncols; j++ {
prev := tr.GetByName(p.columns[j], i-1)
curr := tr.GetByName(p.columns[j], i)
prev := tr.GetByName(p.columns[j], int(i-1))
curr := tr.GetByName(p.columns[j], int(i))

if !set && prev != nil && prev.Cmp(curr) != 0 {
var diff fr.Element
Expand Down
5 changes: 5 additions & 0 deletions pkg/air/gadgets/normalisation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
"github.com/consensys/go-corset/pkg/air"
"github.com/consensys/go-corset/pkg/table"
"github.com/consensys/go-corset/pkg/util"
)

// Normalise constructs an expression representing the normalised value of e.
Expand Down Expand Up @@ -73,6 +74,10 @@ func (e *Inverse) EvalAt(k int, tbl table.Trace) *fr.Element {
return inv.Inverse(val)
}

// MaxShift returns max shift in either the negative (left) or positive
// direction (right).
func (e *Inverse) MaxShift() util.Pair[uint, uint] { return e.Expr.MaxShift() }

func (e *Inverse) String() string {
return fmt.Sprintf("(inv %s)", e.Expr)
}
40 changes: 33 additions & 7 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,38 @@ func (p *Schema) HasColumn(name string) bool {
return false
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
// valid traces are accepted in the presence of arbitrary padding. Spillage can
// only arise from computations as this is where values outside of the user's
// control are determined.
func (p *Schema) RequiredSpillage() uint {
// Ensures always at least one row of spillage (referred to as the "initial
// padding row")
mx := uint(1)
// Determine if any more spillage required
for _, c := range p.computations {
mx = max(mx, c.RequiredSpillage())
}

return mx
}

// ApplyPadding adds n items of padding to each column of the trace.
// Padding values are placed either at the front or the back of a given
// column, depending on their interpretation.
func (p *Schema) ApplyPadding(n uint, tr table.Trace) {
tr.Pad(n, func(j int) *fr.Element {
// Extract front value to use for padding.
return tr.GetByIndex(j, 0)
})
}

// IsInputTrace determines whether a given input trace is a suitable
// input (i.e. non-expanded) trace for this schema. Specifically, the
// input trace must contain a matching column for each non-synthetic
// column in this trace.
func (p *Schema) IsInputTrace(tr table.Trace) error {
count := 0
count := uint(0)

for _, c := range p.dataColumns {
if !c.Synthetic && !tr.HasColumn(c.Name) {
Expand All @@ -112,8 +138,8 @@ func (p *Schema) IsInputTrace(tr table.Trace) error {
// Determine the unknown columns for error reporting.
unknown := make([]string, 0)

for i := 0; i < tr.Width(); i++ {
n := tr.ColumnName(i)
for i := uint(0); i < tr.Width(); i++ {
n := tr.ColumnName(int(i))
if !p.HasColumn(n) {
unknown = append(unknown, n)
}
Expand All @@ -132,7 +158,7 @@ func (p *Schema) IsInputTrace(tr table.Trace) error {
// output trace must contain a matching column for each column in this
// trace (synthetic or otherwise).
func (p *Schema) IsOutputTrace(tr table.Trace) error {
count := 0
count := uint(0)

for _, c := range p.dataColumns {
if !tr.HasColumn(c.Name) {
Expand All @@ -153,7 +179,9 @@ func (p *Schema) IsOutputTrace(tr table.Trace) error {
// AddColumn appends a new data column which is either synthetic or
// not. A synthetic column is one which has been introduced by the
// process of lowering from HIR / MIR to AIR. That is, it is not a
// column which was original specified by the user.
// column which was original specified by the user. Columns also support a
// "padding sign", which indicates whether padding should occur at the front
// (positive sign) or the back (negative sign).
func (p *Schema) AddColumn(name string, synthetic bool) {
// NOTE: the air level has no ability to enforce the type specified for a
// given column.
Expand Down Expand Up @@ -219,8 +247,6 @@ func (p *Schema) Accepts(trace table.Trace) error {
// columns. Observe that computed columns have to be computed in the correct
// order.
func (p *Schema) ExpandTrace(tr table.Trace) error {
// Insert initial padding row
table.PadTrace(1, tr)
// Execute all computations
for _, c := range p.computations {
err := c.ExpandTrace(tr)
Expand Down
Loading

0 comments on commit 28dbfe5

Please sign in to comment.