Skip to content

Commit

Permalink
Fix AIR permutations
Browse files Browse the repository at this point in the history
  • Loading branch information
DavePearce committed Jun 17, 2024
1 parent 407b8c5 commit f9f8261
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 93 deletions.
4 changes: 2 additions & 2 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ func (p *Schema) AddComputation(c table.TraceComputation) {

// AddPermutationConstraint appends a new permutation constraint which
// ensures that one column is a permutation of another.
func (p *Schema) AddPermutationConstraint(target string, source string) {
p.permutations = append(p.permutations, table.NewPermutation(target, source))
func (p *Schema) AddPermutationConstraint(targets []string, sources []string) {
p.permutations = append(p.permutations, table.NewPermutation(targets, sources))
}

// AddVanishingConstraint appends a new vanishing constraint.
Expand Down
3 changes: 2 additions & 1 deletion pkg/mir/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,9 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche
// Add individual permutation constraints
for i := 0; i < ncols; i++ {
airSchema.AddColumn(c.Targets[i], true)
airSchema.AddPermutationConstraint(c.Targets[i], c.Sources[i])
}
//
airSchema.AddPermutationConstraint(c.Targets, c.Sources)
// Add the trace computation.
airSchema.AddComputation(c)
// Add sorting constraints + synthetic columns as necessary.
Expand Down
142 changes: 83 additions & 59 deletions pkg/table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,19 @@ func (c *ComputedColumn[E]) String() string {
// Permutation declares a constraint that one column is a permutation
// of another.
type Permutation struct {
// The target column
Target string
// The target columns
Targets []string
// The so columns
Source string
Sources []string
}

// NewPermutation creates a new permutation
func NewPermutation(target string, source string) *Permutation {
return &Permutation{target, source}
func NewPermutation(targets []string, sources []string) *Permutation {
if len(targets) != len(sources) {
panic("differeng number of target / source permutation columns")
}

return &Permutation{targets, sources}
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
Expand All @@ -174,18 +178,44 @@ func (p *Permutation) RequiredSpillage() uint {
// Accepts checks whether a permutation holds between the source and
// target columns.
func (p *Permutation) Accepts(tr Trace) error {
// Check column in trace!
if !tr.HasColumn(p.Target) {
return fmt.Errorf("Trace missing permutation target column ({%s})", p.Target)
} else if !tr.HasColumn(p.Source) {
return fmt.Errorf("Trace missing permutation source column ({%s})", p.Source)
// Sanity check columns well formed.
if err := validPermutationColumns(p.Targets, p.Sources, tr); err != nil {
return err
}

return IsPermutationOf(p.Target, p.Source, tr)
// Slice out data
src := sliceMatchingColumns(p.Sources, tr)
dst := sliceMatchingColumns(p.Targets, tr)
// Sanity check whether column exists
if !util.ArePermutationOf(dst, src) {
msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})",
p.Targets, p.Sources)
return errors.New(msg)
}
// Success
return nil
}

func (p *Permutation) String() string {
return fmt.Sprintf("(permutation %s %s)", p.Target, p.Source)
targets := ""
sources := ""

for i, s := range p.Targets {
if i != 0 {
targets += " "
}

targets += s
}

for i, s := range p.Sources {
if i != 0 {
sources += " "
}

sources += s
}

return fmt.Sprintf("(permutation (%s) (%s))", targets, sources)
}

// ===================================================================
Expand Down Expand Up @@ -221,41 +251,26 @@ func (p *SortedPermutation) RequiredSpillage() uint {
// Accepts checks whether a sorted permutation holds between the
// source and target columns.
func (p *SortedPermutation) Accepts(tr Trace) error {
ncols := len(p.Sources)
cols := make([][]*fr.Element, ncols)
// Check required columns in trace
for _, n := range p.Targets {
if !tr.HasColumn(n) {
return fmt.Errorf("Trace missing permutation target column ({%s})", n)
}
// Sanity check columns well formed.
if err := validPermutationColumns(p.Targets, p.Sources, tr); err != nil {
return err
}

for _, n := range p.Sources {
if !tr.HasColumn(n) {
return fmt.Errorf("Trace missing permutation source ({%s})", n)
}
}
// Check that target and source columns exist and are permutations of source
// columns.
for i := 0; i < ncols; i++ {
dstName := p.Targets[i]
srcName := p.Sources[i]
// Access column data based on column name.
err := IsPermutationOf(dstName, srcName, tr)
if err != nil {
return err
}

cols[i] = tr.ColumnByName(dstName).Data()
// Slice out data
src := sliceMatchingColumns(p.Sources, tr)
dst := sliceMatchingColumns(p.Targets, tr)
// Sanity check whether column exists
if !util.ArePermutationOf(dst, src) {
msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})",
p.Targets, p.Sources)
return errors.New(msg)
}

// Check that target columns are sorted lexicographically.
if util.AreLexicographicallySorted(cols, p.Signs) {
if util.AreLexicographicallySorted(dst, p.Signs) {
return nil
}

// Error case
msg := fmt.Sprintf("Permutation columns not lexicographically sorted ({%s})", p.Targets)

// Done
return errors.New(msg)
}

Expand Down Expand Up @@ -321,23 +336,32 @@ func (p *SortedPermutation) String() string {
return fmt.Sprintf("(permute (%s) (%s))", targets, sources)
}

// IsPermutationOf checks whether (or not) one column is a permutation
// of another in given trace. The order in which columns are given is
// not important.
func IsPermutationOf(target string, source string, tr Trace) error {
dst := tr.ColumnByName(target).Data()
src := tr.ColumnByName(source).Data()
// Sanity check whether column exists
if dst == nil {
msg := fmt.Sprintf("Invalid target column for permutation ({%s})", target)
return errors.New(msg)
} else if src == nil {
msg := fmt.Sprintf("Invalid source column for permutation ({%s})", source)
return errors.New(msg)
} else if !util.IsPermutationOf(dst, src) {
msg := fmt.Sprintf("Target column (%s) not permutation of source ({%s})", target, source)
return errors.New(msg)
func validPermutationColumns(targets []string, sources []string, tr Trace) error {
ncols := len(targets)
// Sanity check matching length
if len(sources) != ncols {
return fmt.Errorf("Number of source and target columns differs")
}

// Check required columns in trace
for i := 0; i < ncols; i++ {
if !tr.HasColumn(targets[i]) {
return fmt.Errorf("Trace missing permutation target column ({%s})", targets[i])
} else if !tr.HasColumn(sources[i]) {
return fmt.Errorf("Trace missing permutation source ({%s})", sources[i])
}
}
//
return nil
}

func sliceMatchingColumns(names []string, tr Trace) [][]*fr.Element {
// Allocate return array
cols := make([][]*fr.Element, len(names))
// Slice out the data
for i, n := range names {
nth := tr.ColumnByName(n)
cols[i] = nth.Data()
}
// Done
return cols
}
26 changes: 26 additions & 0 deletions pkg/util/arrays.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
package util

import "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"

// Equals2d returns true if two 2D arrays are equal.
func Equals2d(lhs [][]*fr.Element, rhs [][]*fr.Element) bool {
if len(lhs) != len(rhs) {
return false
}

for i := 0; i < len(lhs); i++ {
lhs_i := lhs[i]
rhs_i := rhs[i]
// Check lengths match
if len(lhs_i) != len(rhs_i) {
return false
}
// Check elements match
for j := 0; j < len(lhs_i); j++ {
if lhs_i[j].Cmp(rhs_i[j]) != 0 {
return false
}
}
}
//
return true
}

// FlatArrayIndexOf_2 returns the ith element of the flattened form of a 2d
// array. Consider the array "[[0,7],[4]]". Then its flattened form is
// "[0,7,4]" and, for example, the element at index 1 is "7".
Expand Down
80 changes: 49 additions & 31 deletions pkg/util/permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,43 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
)

// IsPermutationOf checks whether or not a given destination column is a valid
// permutation of a given source column. This function does not modify either
// column (though it does allocate an intermediate array).
// ArePermutationOf checks whether or not a set of given destination columns are
// a valid permutation of a given set of source columns. The number of source
// and target columns must match. Likewise, they are expected to have the same
// height. This function does not modify any columns (though it does allocate
// intermediate arrays).
//
// This function operators by cloning both arrays, sorting them and checking
// they are the same.
func IsPermutationOf(dst []*fr.Element, src []*fr.Element) bool {
// This function operators by cloning the arrays, sorting them and checking they
// are the same.
func ArePermutationOf(dst [][]*fr.Element, src [][]*fr.Element) bool {
if len(dst) != len(src) {
return false
}
// Copy arrays
dstCopy := make([]*fr.Element, len(dst))
srcCopy := make([]*fr.Element, len(src))
// Determine geometry
ncols := len(dst)
nrows := len(dst[0])
// Rotate input arrays
dstCopy := rotate(dst, ncols, nrows)
srcCopy := rotate(src, ncols, nrows)
// Sort rotated arrays
slices.SortFunc(dstCopy, permutationFunc)
slices.SortFunc(srcCopy, permutationFunc)
// Check rotated arrays match
return Equals2d(dstCopy, srcCopy)
}

copy(dstCopy, dst)
copy(srcCopy, src)
// Sort arrays
slices.SortFunc(dstCopy, func(l *fr.Element, r *fr.Element) int { return l.Cmp(r) })
slices.SortFunc(srcCopy, func(l *fr.Element, r *fr.Element) int { return l.Cmp(r) })
// Check they are equal
for i := 0; i < len(dst); i++ {
if dstCopy[i].Cmp(srcCopy[i]) != 0 {
return false
func permutationFunc(lhs []*fr.Element, rhs []*fr.Element) int {
for i := 0; i < len(lhs); i++ {
// Compare ith elements
c := lhs[i].Cmp(rhs[i])
// Check whether same
if c != 0 {
// Positive
return c
}
}
// Match
return true
// Identical
return 0
}

// PermutationSort sorts an array of columns in row-wise fashion. For
Expand All @@ -54,17 +64,8 @@ func IsPermutationOf(dst []*fr.Element, src []*fr.Element) bool {
func PermutationSort(cols [][]*fr.Element, signs []bool) {
n := len(cols[0])
m := len(cols)
//
rows := make([][]*fr.Element, n)
// project into row-wise form
for i := 0; i < n; i++ {
row := make([]*fr.Element, m)
for j := 0; j < m; j++ {
row[j] = cols[j][i]
}

rows[i] = row
}
// Rotate input matrix
rows := rotate(cols, m, n)
// Perform the permutation sort
slices.SortFunc(rows, func(l []*fr.Element, r []*fr.Element) int {
return permutationSortFunc(l, r, signs)
Expand Down Expand Up @@ -120,3 +121,20 @@ func permutationSortFunc(lhs []*fr.Element, rhs []*fr.Element, signs []bool) int
// Identical
return 0
}

// Clone and rotate a 2-dimensional array assuming a given geometry.
func rotate(src [][]*fr.Element, ncols int, nrows int) [][]*fr.Element {
// Copy outer arrays
dst := make([][]*fr.Element, nrows)
// Copy inner arrays
for i := 0; i < nrows; i++ {
row := make([]*fr.Element, ncols)
for j := 0; j < ncols; j++ {
row[j] = src[j][i]
}

dst[i] = row
}
//
return dst
}

0 comments on commit f9f8261

Please sign in to comment.