From f9f82614a8729121fc5273189318f413ab5e551b Mon Sep 17 00:00:00 2001 From: DavePearce Date: Mon, 17 Jun 2024 22:05:45 +1200 Subject: [PATCH] Fix AIR permutations --- pkg/air/schema.go | 4 +- pkg/mir/schema.go | 3 +- pkg/table/column.go | 142 +++++++++++++++++++++++----------------- pkg/util/arrays.go | 26 ++++++++ pkg/util/permutation.go | 80 +++++++++++++--------- 5 files changed, 162 insertions(+), 93 deletions(-) diff --git a/pkg/air/schema.go b/pkg/air/schema.go index e102751..6ee0b0c 100644 --- a/pkg/air/schema.go +++ b/pkg/air/schema.go @@ -186,8 +186,8 @@ func (p *Schema) AddComputation(c table.TraceComputation) { // AddPermutationConstraint appends a new permutation constraint which // ensures that one column is a permutation of another. -func (p *Schema) AddPermutationConstraint(target string, source string) { - p.permutations = append(p.permutations, table.NewPermutation(target, source)) +func (p *Schema) AddPermutationConstraint(targets []string, sources []string) { + p.permutations = append(p.permutations, table.NewPermutation(targets, sources)) } // AddVanishingConstraint appends a new vanishing constraint. diff --git a/pkg/mir/schema.go b/pkg/mir/schema.go index 43e3a47..80b0900 100644 --- a/pkg/mir/schema.go +++ b/pkg/mir/schema.go @@ -192,8 +192,9 @@ func lowerPermutationToAir(c Permutation, mirSchema *Schema, airSchema *air.Sche // Add individual permutation constraints for i := 0; i < ncols; i++ { airSchema.AddColumn(c.Targets[i], true) - airSchema.AddPermutationConstraint(c.Targets[i], c.Sources[i]) } + // + airSchema.AddPermutationConstraint(c.Targets, c.Sources) // Add the trace computation. airSchema.AddComputation(c) // Add sorting constraints + synthetic columns as necessary. diff --git a/pkg/table/column.go b/pkg/table/column.go index 4158f85..58c5cd6 100644 --- a/pkg/table/column.go +++ b/pkg/table/column.go @@ -154,15 +154,19 @@ func (c *ComputedColumn[E]) String() string { // Permutation declares a constraint that one column is a permutation // of another. type Permutation struct { - // The target column - Target string + // The target columns + Targets []string // The so columns - Source string + Sources []string } // NewPermutation creates a new permutation -func NewPermutation(target string, source string) *Permutation { - return &Permutation{target, source} +func NewPermutation(targets []string, sources []string) *Permutation { + if len(targets) != len(sources) { + panic("differeng number of target / source permutation columns") + } + + return &Permutation{targets, sources} } // RequiredSpillage returns the minimum amount of spillage required to ensure @@ -174,18 +178,44 @@ func (p *Permutation) RequiredSpillage() uint { // Accepts checks whether a permutation holds between the source and // target columns. func (p *Permutation) Accepts(tr Trace) error { - // Check column in trace! - if !tr.HasColumn(p.Target) { - return fmt.Errorf("Trace missing permutation target column ({%s})", p.Target) - } else if !tr.HasColumn(p.Source) { - return fmt.Errorf("Trace missing permutation source column ({%s})", p.Source) + // Sanity check columns well formed. + if err := validPermutationColumns(p.Targets, p.Sources, tr); err != nil { + return err } - - return IsPermutationOf(p.Target, p.Source, tr) + // Slice out data + src := sliceMatchingColumns(p.Sources, tr) + dst := sliceMatchingColumns(p.Targets, tr) + // Sanity check whether column exists + if !util.ArePermutationOf(dst, src) { + msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})", + p.Targets, p.Sources) + return errors.New(msg) + } + // Success + return nil } func (p *Permutation) String() string { - return fmt.Sprintf("(permutation %s %s)", p.Target, p.Source) + targets := "" + sources := "" + + for i, s := range p.Targets { + if i != 0 { + targets += " " + } + + targets += s + } + + for i, s := range p.Sources { + if i != 0 { + sources += " " + } + + sources += s + } + + return fmt.Sprintf("(permutation (%s) (%s))", targets, sources) } // =================================================================== @@ -221,41 +251,26 @@ func (p *SortedPermutation) RequiredSpillage() uint { // Accepts checks whether a sorted permutation holds between the // source and target columns. func (p *SortedPermutation) Accepts(tr Trace) error { - ncols := len(p.Sources) - cols := make([][]*fr.Element, ncols) - // Check required columns in trace - for _, n := range p.Targets { - if !tr.HasColumn(n) { - return fmt.Errorf("Trace missing permutation target column ({%s})", n) - } + // Sanity check columns well formed. + if err := validPermutationColumns(p.Targets, p.Sources, tr); err != nil { + return err } - - for _, n := range p.Sources { - if !tr.HasColumn(n) { - return fmt.Errorf("Trace missing permutation source ({%s})", n) - } - } - // Check that target and source columns exist and are permutations of source - // columns. - for i := 0; i < ncols; i++ { - dstName := p.Targets[i] - srcName := p.Sources[i] - // Access column data based on column name. - err := IsPermutationOf(dstName, srcName, tr) - if err != nil { - return err - } - - cols[i] = tr.ColumnByName(dstName).Data() + // Slice out data + src := sliceMatchingColumns(p.Sources, tr) + dst := sliceMatchingColumns(p.Targets, tr) + // Sanity check whether column exists + if !util.ArePermutationOf(dst, src) { + msg := fmt.Sprintf("Target columns (%v) not permutation of source columns ({%v})", + p.Targets, p.Sources) + return errors.New(msg) } - // Check that target columns are sorted lexicographically. - if util.AreLexicographicallySorted(cols, p.Signs) { + if util.AreLexicographicallySorted(dst, p.Signs) { return nil } - + // Error case msg := fmt.Sprintf("Permutation columns not lexicographically sorted ({%s})", p.Targets) - + // Done return errors.New(msg) } @@ -321,23 +336,32 @@ func (p *SortedPermutation) String() string { return fmt.Sprintf("(permute (%s) (%s))", targets, sources) } -// IsPermutationOf checks whether (or not) one column is a permutation -// of another in given trace. The order in which columns are given is -// not important. -func IsPermutationOf(target string, source string, tr Trace) error { - dst := tr.ColumnByName(target).Data() - src := tr.ColumnByName(source).Data() - // Sanity check whether column exists - if dst == nil { - msg := fmt.Sprintf("Invalid target column for permutation ({%s})", target) - return errors.New(msg) - } else if src == nil { - msg := fmt.Sprintf("Invalid source column for permutation ({%s})", source) - return errors.New(msg) - } else if !util.IsPermutationOf(dst, src) { - msg := fmt.Sprintf("Target column (%s) not permutation of source ({%s})", target, source) - return errors.New(msg) +func validPermutationColumns(targets []string, sources []string, tr Trace) error { + ncols := len(targets) + // Sanity check matching length + if len(sources) != ncols { + return fmt.Errorf("Number of source and target columns differs") } - + // Check required columns in trace + for i := 0; i < ncols; i++ { + if !tr.HasColumn(targets[i]) { + return fmt.Errorf("Trace missing permutation target column ({%s})", targets[i]) + } else if !tr.HasColumn(sources[i]) { + return fmt.Errorf("Trace missing permutation source ({%s})", sources[i]) + } + } + // return nil } + +func sliceMatchingColumns(names []string, tr Trace) [][]*fr.Element { + // Allocate return array + cols := make([][]*fr.Element, len(names)) + // Slice out the data + for i, n := range names { + nth := tr.ColumnByName(n) + cols[i] = nth.Data() + } + // Done + return cols +} diff --git a/pkg/util/arrays.go b/pkg/util/arrays.go index 7e9c702..f46cb7f 100644 --- a/pkg/util/arrays.go +++ b/pkg/util/arrays.go @@ -1,5 +1,31 @@ package util +import "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" + +// Equals2d returns true if two 2D arrays are equal. +func Equals2d(lhs [][]*fr.Element, rhs [][]*fr.Element) bool { + if len(lhs) != len(rhs) { + return false + } + + for i := 0; i < len(lhs); i++ { + lhs_i := lhs[i] + rhs_i := rhs[i] + // Check lengths match + if len(lhs_i) != len(rhs_i) { + return false + } + // Check elements match + for j := 0; j < len(lhs_i); j++ { + if lhs_i[j].Cmp(rhs_i[j]) != 0 { + return false + } + } + } + // + return true +} + // FlatArrayIndexOf_2 returns the ith element of the flattened form of a 2d // array. Consider the array "[[0,7],[4]]". Then its flattened form is // "[0,7,4]" and, for example, the element at index 1 is "7". diff --git a/pkg/util/permutation.go b/pkg/util/permutation.go index c5f5f27..73474f2 100644 --- a/pkg/util/permutation.go +++ b/pkg/util/permutation.go @@ -8,33 +8,43 @@ import ( "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" ) -// IsPermutationOf checks whether or not a given destination column is a valid -// permutation of a given source column. This function does not modify either -// column (though it does allocate an intermediate array). +// ArePermutationOf checks whether or not a set of given destination columns are +// a valid permutation of a given set of source columns. The number of source +// and target columns must match. Likewise, they are expected to have the same +// height. This function does not modify any columns (though it does allocate +// intermediate arrays). // -// This function operators by cloning both arrays, sorting them and checking -// they are the same. -func IsPermutationOf(dst []*fr.Element, src []*fr.Element) bool { +// This function operators by cloning the arrays, sorting them and checking they +// are the same. +func ArePermutationOf(dst [][]*fr.Element, src [][]*fr.Element) bool { if len(dst) != len(src) { return false } - // Copy arrays - dstCopy := make([]*fr.Element, len(dst)) - srcCopy := make([]*fr.Element, len(src)) + // Determine geometry + ncols := len(dst) + nrows := len(dst[0]) + // Rotate input arrays + dstCopy := rotate(dst, ncols, nrows) + srcCopy := rotate(src, ncols, nrows) + // Sort rotated arrays + slices.SortFunc(dstCopy, permutationFunc) + slices.SortFunc(srcCopy, permutationFunc) + // Check rotated arrays match + return Equals2d(dstCopy, srcCopy) +} - copy(dstCopy, dst) - copy(srcCopy, src) - // Sort arrays - slices.SortFunc(dstCopy, func(l *fr.Element, r *fr.Element) int { return l.Cmp(r) }) - slices.SortFunc(srcCopy, func(l *fr.Element, r *fr.Element) int { return l.Cmp(r) }) - // Check they are equal - for i := 0; i < len(dst); i++ { - if dstCopy[i].Cmp(srcCopy[i]) != 0 { - return false +func permutationFunc(lhs []*fr.Element, rhs []*fr.Element) int { + for i := 0; i < len(lhs); i++ { + // Compare ith elements + c := lhs[i].Cmp(rhs[i]) + // Check whether same + if c != 0 { + // Positive + return c } } - // Match - return true + // Identical + return 0 } // PermutationSort sorts an array of columns in row-wise fashion. For @@ -54,17 +64,8 @@ func IsPermutationOf(dst []*fr.Element, src []*fr.Element) bool { func PermutationSort(cols [][]*fr.Element, signs []bool) { n := len(cols[0]) m := len(cols) - // - rows := make([][]*fr.Element, n) - // project into row-wise form - for i := 0; i < n; i++ { - row := make([]*fr.Element, m) - for j := 0; j < m; j++ { - row[j] = cols[j][i] - } - - rows[i] = row - } + // Rotate input matrix + rows := rotate(cols, m, n) // Perform the permutation sort slices.SortFunc(rows, func(l []*fr.Element, r []*fr.Element) int { return permutationSortFunc(l, r, signs) @@ -120,3 +121,20 @@ func permutationSortFunc(lhs []*fr.Element, rhs []*fr.Element, signs []bool) int // Identical return 0 } + +// Clone and rotate a 2-dimensional array assuming a given geometry. +func rotate(src [][]*fr.Element, ncols int, nrows int) [][]*fr.Element { + // Copy outer arrays + dst := make([][]*fr.Element, nrows) + // Copy inner arrays + for i := 0; i < nrows; i++ { + row := make([]*fr.Element, ncols) + for j := 0; j < ncols; j++ { + row[j] = src[j][i] + } + + dst[i] = row + } + // + return dst +}