Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AIR permutations #173

Merged
merged 1 commit into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/air/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ func (p *Schema) AddComputation(c table.TraceComputation) {

// AddPermutationConstraint appends a new permutation constraint which
// ensures that one column is a permutation of another.
func (p *Schema) AddPermutationConstraint(target string, source string) {
p.permutations = append(p.permutations, table.NewPermutation(target, source))
func (p *Schema) AddPermutationConstraint(targets []string, sources []string) {
p.permutations = append(p.permutations, table.NewPermutation(targets, sources))
}

// AddVanishingConstraint appends a new vanishing constraint.
Expand Down
3 changes: 2 additions & 1 deletion pkg/mir/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,9 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche
// Add individual permutation constraints
for i := 0; i < ncols; i++ {
airSchema.AddColumn(c.Targets[i], true)
airSchema.AddPermutationConstraint(c.Targets[i], c.Sources[i])
}
//
airSchema.AddPermutationConstraint(c.Targets, c.Sources)
// Add the trace computation.
airSchema.AddComputation(c)
// Add sorting constraints + synthetic columns as necessary.
Expand Down
142 changes: 83 additions & 59 deletions pkg/table/column.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,19 @@ func (c *ComputedColumn[E]) String() string {
// Permutation declares a constraint that one column is a permutation
// of another.
type Permutation struct {
// The target column
Target string
// The target columns
Targets []string
// The so columns
Source string
Sources []string
}

// NewPermutation creates a new permutation
func NewPermutation(target string, source string) *Permutation {
return &Permutation{target, source}
func NewPermutation(targets []string, sources []string) *Permutation {
if len(targets) != len(sources) {
panic("differeng number of target / source permutation columns")
}

return &Permutation{targets, sources}
}

// RequiredSpillage returns the minimum amount of spillage required to ensure
Expand All @@ -174,18 +178,44 @@ func (p *Permutation) RequiredSpillage() uint {
// Accepts checks whether a permutation holds between the source and
// target columns.
func (p *Permutation) Accepts(tr Trace) error {
// Check column in trace!
if !tr.HasColumn(p.Target) {
return fmt.Errorf("Trace missing permutation target column ({%s})", p.Target)
} else if !tr.HasColumn(p.Source) {
return fmt.Errorf("Trace missing permutation source column ({%s})", p.Source)
// Sanity check columns well formed.
if err := validPermutationColumns(p.Targets, p.Sources, tr); err != nil {
return err
}

return IsPermutationOf(p.Target, p.Source, tr)
// Slice out data
src := sliceMatchingColumns(p.Sources, tr)
dst := sliceMatchingColumns(p.Targets, tr)
// Sanity check whether column exists
if !util.ArePermutationOf(dst, src) {
msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})",
p.Targets, p.Sources)
return errors.New(msg)
}
// Success
return nil
}

func (p *Permutation) String() string {
return fmt.Sprintf("(permutation %s %s)", p.Target, p.Source)
targets := ""
sources := ""

for i, s := range p.Targets {
if i != 0 {
targets += " "
}

targets += s
}

for i, s := range p.Sources {
if i != 0 {
sources += " "
}

sources += s
}

return fmt.Sprintf("(permutation (%s) (%s))", targets, sources)
}

// ===================================================================
Expand Down Expand Up @@ -221,41 +251,26 @@ func (p *SortedPermutation) RequiredSpillage() uint {
// Accepts checks whether a sorted permutation holds between the
// source and target columns.
func (p *SortedPermutation) Accepts(tr Trace) error {
ncols := len(p.Sources)
cols := make([][]*fr.Element, ncols)
// Check required columns in trace
for _, n := range p.Targets {
if !tr.HasColumn(n) {
return fmt.Errorf("Trace missing permutation target column ({%s})", n)
}
// Sanity check columns well formed.
if err := validPermutationColumns(p.Targets, p.Sources, tr); err != nil {
return err
}

for _, n := range p.Sources {
if !tr.HasColumn(n) {
return fmt.Errorf("Trace missing permutation source ({%s})", n)
}
}
// Check that target and source columns exist and are permutations of source
// columns.
for i := 0; i < ncols; i++ {
dstName := p.Targets[i]
srcName := p.Sources[i]
// Access column data based on column name.
err := IsPermutationOf(dstName, srcName, tr)
if err != nil {
return err
}

cols[i] = tr.ColumnByName(dstName).Data()
// Slice out data
src := sliceMatchingColumns(p.Sources, tr)
dst := sliceMatchingColumns(p.Targets, tr)
// Sanity check whether column exists
if !util.ArePermutationOf(dst, src) {
msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})",
p.Targets, p.Sources)
return errors.New(msg)
}

// Check that target columns are sorted lexicographically.
if util.AreLexicographicallySorted(cols, p.Signs) {
if util.AreLexicographicallySorted(dst, p.Signs) {
return nil
}

// Error case
msg := fmt.Sprintf("Permutation columns not lexicographically sorted ({%s})", p.Targets)

// Done
return errors.New(msg)
}

Expand Down Expand Up @@ -321,23 +336,32 @@ func (p *SortedPermutation) String() string {
return fmt.Sprintf("(permute (%s) (%s))", targets, sources)
}

// IsPermutationOf checks whether (or not) one column is a permutation
// of another in given trace. The order in which columns are given is
// not important.
func IsPermutationOf(target string, source string, tr Trace) error {
dst := tr.ColumnByName(target).Data()
src := tr.ColumnByName(source).Data()
// Sanity check whether column exists
if dst == nil {
msg := fmt.Sprintf("Invalid target column for permutation ({%s})", target)
return errors.New(msg)
} else if src == nil {
msg := fmt.Sprintf("Invalid source column for permutation ({%s})", source)
return errors.New(msg)
} else if !util.IsPermutationOf(dst, src) {
msg := fmt.Sprintf("Target column (%s) not permutation of source ({%s})", target, source)
return errors.New(msg)
func validPermutationColumns(targets []string, sources []string, tr Trace) error {
ncols := len(targets)
// Sanity check matching length
if len(sources) != ncols {
return fmt.Errorf("Number of source and target columns differs")
}

// Check required columns in trace
for i := 0; i < ncols; i++ {
if !tr.HasColumn(targets[i]) {
return fmt.Errorf("Trace missing permutation target column ({%s})", targets[i])
} else if !tr.HasColumn(sources[i]) {
return fmt.Errorf("Trace missing permutation source ({%s})", sources[i])
}
}
//
return nil
}

func sliceMatchingColumns(names []string, tr Trace) [][]*fr.Element {
// Allocate return array
cols := make([][]*fr.Element, len(names))
// Slice out the data
for i, n := range names {
nth := tr.ColumnByName(n)
cols[i] = nth.Data()
}
// Done
return cols
}
26 changes: 26 additions & 0 deletions pkg/util/arrays.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
package util

import "github.com/consensys/gnark-crypto/ecc/bls12-377/fr"

// Equals2d returns true if two 2D arrays are equal.
func Equals2d(lhs [][]*fr.Element, rhs [][]*fr.Element) bool {
if len(lhs) != len(rhs) {
return false
}

for i := 0; i < len(lhs); i++ {
lhs_i := lhs[i]
rhs_i := rhs[i]
// Check lengths match
if len(lhs_i) != len(rhs_i) {
return false
}
// Check elements match
for j := 0; j < len(lhs_i); j++ {
if lhs_i[j].Cmp(rhs_i[j]) != 0 {
return false
}
}
}
//
return true
}

// FlatArrayIndexOf_2 returns the ith element of the flattened form of a 2d
// array. Consider the array "[[0,7],[4]]". Then its flattened form is
// "[0,7,4]" and, for example, the element at index 1 is "7".
Expand Down
80 changes: 49 additions & 31 deletions pkg/util/permutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,43 @@ import (
"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
)

// IsPermutationOf checks whether or not a given destination column is a valid
// permutation of a given source column. This function does not modify either
// column (though it does allocate an intermediate array).
// ArePermutationOf checks whether or not a set of given destination columns are
// a valid permutation of a given set of source columns. The number of source
// and target columns must match. Likewise, they are expected to have the same
// height. This function does not modify any columns (though it does allocate
// intermediate arrays).
//
// This function operators by cloning both arrays, sorting them and checking
// they are the same.
func IsPermutationOf(dst []*fr.Element, src []*fr.Element) bool {
// This function operators by cloning the arrays, sorting them and checking they
// are the same.
func ArePermutationOf(dst [][]*fr.Element, src [][]*fr.Element) bool {
if len(dst) != len(src) {
return false
}
// Copy arrays
dstCopy := make([]*fr.Element, len(dst))
srcCopy := make([]*fr.Element, len(src))
// Determine geometry
ncols := len(dst)
nrows := len(dst[0])
// Rotate input arrays
dstCopy := rotate(dst, ncols, nrows)
srcCopy := rotate(src, ncols, nrows)
// Sort rotated arrays
slices.SortFunc(dstCopy, permutationFunc)
slices.SortFunc(srcCopy, permutationFunc)
// Check rotated arrays match
return Equals2d(dstCopy, srcCopy)
}

copy(dstCopy, dst)
copy(srcCopy, src)
// Sort arrays
slices.SortFunc(dstCopy, func(l *fr.Element, r *fr.Element) int { return l.Cmp(r) })
slices.SortFunc(srcCopy, func(l *fr.Element, r *fr.Element) int { return l.Cmp(r) })
// Check they are equal
for i := 0; i < len(dst); i++ {
if dstCopy[i].Cmp(srcCopy[i]) != 0 {
return false
func permutationFunc(lhs []*fr.Element, rhs []*fr.Element) int {
for i := 0; i < len(lhs); i++ {
// Compare ith elements
c := lhs[i].Cmp(rhs[i])
// Check whether same
if c != 0 {
// Positive
return c
}
}
// Match
return true
// Identical
return 0
}

// PermutationSort sorts an array of columns in row-wise fashion. For
Expand All @@ -54,17 +64,8 @@ func IsPermutationOf(dst []*fr.Element, src []*fr.Element) bool {
func PermutationSort(cols [][]*fr.Element, signs []bool) {
n := len(cols[0])
m := len(cols)
//
rows := make([][]*fr.Element, n)
// project into row-wise form
for i := 0; i < n; i++ {
row := make([]*fr.Element, m)
for j := 0; j < m; j++ {
row[j] = cols[j][i]
}

rows[i] = row
}
// Rotate input matrix
rows := rotate(cols, m, n)
// Perform the permutation sort
slices.SortFunc(rows, func(l []*fr.Element, r []*fr.Element) int {
return permutationSortFunc(l, r, signs)
Expand Down Expand Up @@ -120,3 +121,20 @@ func permutationSortFunc(lhs []*fr.Element, rhs []*fr.Element, signs []bool) int
// Identical
return 0
}

// Clone and rotate a 2-dimensional array assuming a given geometry.
func rotate(src [][]*fr.Element, ncols int, nrows int) [][]*fr.Element {
// Copy outer arrays
dst := make([][]*fr.Element, nrows)
// Copy inner arrays
for i := 0; i < nrows; i++ {
row := make([]*fr.Element, ncols)
for j := 0; j < ncols; j++ {
row[j] = src[j][i]
}

dst[i] = row
}
//
return dst
}