Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Summary Information in trace Command #249

Merged
merged 2 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 127 additions & 28 deletions pkg/cmd/trace.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/consensys/go-corset/pkg/trace"
"github.com/consensys/go-corset/pkg/util"

"github.com/spf13/cobra"
)

Expand All @@ -27,6 +26,7 @@ var traceCmd = &cobra.Command{
// Parse trace
tr := readTraceFile(args[0])
list := getFlag(cmd, "list")
stats := getFlag(cmd, "stats")
print := getFlag(cmd, "print")
padding := getUint(cmd, "pad")
start := getUint(cmd, "start")
Expand All @@ -44,6 +44,9 @@ var traceCmd = &cobra.Command{
if list {
listColumns(tr)
}
if stats {
summaryStats(tr)
}
//
if output != "" {
writeTraceFile(output, tr)
Expand All @@ -55,6 +58,19 @@ var traceCmd = &cobra.Command{
},
}

func init() {
rootCmd.AddCommand(traceCmd)
traceCmd.Flags().BoolP("list", "l", false, "list only the columns in the trace file")
traceCmd.Flags().Bool("stats", false, "print summary information about the trace file")
traceCmd.Flags().BoolP("print", "p", false, "print entire trace file")
traceCmd.Flags().Uint("pad", 0, "add a given number of padding rows (to each module)")
traceCmd.Flags().UintP("start", "s", 0, "filter out rows below this")
traceCmd.Flags().UintP("end", "e", math.MaxUint, "filter out this and all following rows")
traceCmd.Flags().Uint("max-width", 32, "specify maximum display width for a column")
traceCmd.Flags().StringP("out", "o", "", "Specify output file to write trace")
traceCmd.Flags().StringP("filter", "f", "", "Filter columns beginning with prefix")
}

// Construct a new trace containing only those columns from the original who
// name begins with the given prefix.
func filterColumns(tr trace.Trace, prefix string) trace.Trace {
Expand Down Expand Up @@ -90,23 +106,6 @@ func filterColumns(tr trace.Trace, prefix string) trace.Trace {
// Done
return builder.Build()
}

func listColumns(tr trace.Trace) {
n := tr.Columns().Len()
tbl := util.NewTablePrinter(3, n)

for i := uint(0); i < n; i++ {
ith := tr.Columns().Get(i).Data()
elems := fmt.Sprintf("%d rows", ith.Len())
bytes := fmt.Sprintf("(%d*%d) = %d bytes", ith.Len(), ith.ByteWidth(), ith.ByteWidth()*ith.Len())
tbl.SetRow(i, QualifiedColumnName(i, tr), elems, bytes)
}

//
tbl.SetMaxWidth(64)
tbl.Print()
}

func printTrace(start uint, end uint, max_width uint, tr trace.Trace) {
cols := tr.Columns()
n := tr.Columns().Len()
Expand All @@ -133,14 +132,114 @@ func printTrace(start uint, end uint, max_width uint, tr trace.Trace) {
tbl.Print()
}

func init() {
rootCmd.AddCommand(traceCmd)
traceCmd.Flags().BoolP("list", "l", false, "list only the columns in the trace file")
traceCmd.Flags().BoolP("print", "p", false, "print entire trace file")
traceCmd.Flags().Uint("pad", 0, "add a given number of padding rows (to each module)")
traceCmd.Flags().UintP("start", "s", 0, "filter out rows below this")
traceCmd.Flags().UintP("end", "e", math.MaxUint, "filter out this and all following rows")
traceCmd.Flags().Uint("max-width", 32, "specify maximum display width for a column")
traceCmd.Flags().StringP("out", "o", "", "Specify output file to write trace")
traceCmd.Flags().StringP("filter", "f", "", "Filter columns beginning with prefix")
func listColumns(tr trace.Trace) {
// Determine number of columns
m := 1 + uint(len(summarisers))
// Determine number of rows
n := tr.Columns().Len()
// Go!
tbl := util.NewTablePrinter(m, n)

for i := uint(0); i < n; i++ {
ith := tr.Columns().Get(i)
row := make([]string, m)
row[0] = QualifiedColumnName(i, tr)
// Add summarises
for j := 0; j < len(summarisers); j++ {
row[j+1] = summarisers[j].summary(ith)
}
tbl.SetRow(i, row...)
}
//
tbl.SetMaxWidth(64)
tbl.Print()
}

func summaryStats(tr trace.Trace) {
m := uint(len(trSummarisers))
tbl := util.NewTablePrinter(2, m)
// Go!
for i := uint(0); i < m; i++ {
ith := trSummarisers[i]
summary := ith.summary(tr)
tbl.SetRow(i, ith.name, summary)
}
//
tbl.SetMaxWidth(64)
tbl.Print()
}

// ============================================================================
// Column Summarisers
// ============================================================================

// ColSummariser abstracts the notion of a function which summarises the
// contents of a given column.
type ColSummariser struct {
name string
summary func(trace.Column) string
}

var summarisers []ColSummariser = []ColSummariser{
{"count", rowSummariser},
{"width", widthSummariser},
{"bytes", bytesSummariser},
{"unique", uniqueSummariser},
}

func rowSummariser(col trace.Column) string {
return fmt.Sprintf("%d rows", col.Data().Len())
}

func widthSummariser(col trace.Column) string {
return fmt.Sprintf("%d bits", col.Data().ByteWidth()*8)
}

func bytesSummariser(col trace.Column) string {
return fmt.Sprintf("%d bytes", col.Data().Len()*col.Data().ByteWidth())
}

func uniqueSummariser(col trace.Column) string {
data := col.Data()
elems := util.NewHashSet[util.BytesKey](data.Len() / 2)
// Add all the elements
for i := uint(0); i < data.Len(); i++ {
bytes := util.FrElementToBytes(data.Get(i))
elems.Insert(util.NewBytesKey(bytes[:]))
}
// Done
return fmt.Sprintf("%d elements", elems.Size())
}

// ============================================================================
// Trace Summarisers
// ============================================================================

type traceSummariser struct {
name string
summary func(trace.Trace) string
}

var trSummarisers []traceSummariser = []traceSummariser{
trWidthSummariser(1, 8),
trWidthSummariser(9, 16),
trWidthSummariser(17, 32),
trWidthSummariser(33, 128),
trWidthSummariser(129, 256),
}

func trWidthSummariser(lowWidth uint, highWidth uint) traceSummariser {
return traceSummariser{
name: fmt.Sprintf("# Columns (%d..%d bits)", lowWidth, highWidth),
summary: func(tr trace.Trace) string {
count := 0
for i := uint(0); i < tr.Columns().Len(); i++ {
ithWidth := tr.Columns().Get(i).Data().ByteWidth() * 8
if ithWidth >= lowWidth && ithWidth <= highWidth {
count++
}
}
return fmt.Sprintf("%d", count)
},
}
}
2 changes: 1 addition & 1 deletion pkg/mir/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

func (e *ColumnAccess) String() string {
if e.Shift == 0 {
return fmt.Sprintf("#%d)", e.Column)
return fmt.Sprintf("#%d", e.Column)
}

return fmt.Sprintf("(shift #%d %d)", e.Column, e.Shift)
Expand Down
15 changes: 15 additions & 0 deletions pkg/util/fields.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package util

import (
"encoding/binary"

"github.com/consensys/gnark-crypto/ecc/bls12-377/fr"
)

Expand All @@ -25,3 +27,16 @@ func Pow(val *fr.Element, n uint64) {
}
}
}

// FrElementToBytes converts a given field element into a slice of 32 bytes.
func FrElementToBytes(element *fr.Element) [32]byte {
// Each fr.Element is 4 x 64bit words.
var bytes [32]byte
// Copy over each element
binary.BigEndian.PutUint64(bytes[:], element[0])
binary.BigEndian.PutUint64(bytes[8:], element[1])
binary.BigEndian.PutUint64(bytes[16:], element[2])
binary.BigEndian.PutUint64(bytes[24:], element[3])
// Done
return bytes
}