diff --git a/README.md b/README.md index fa712b0..9630f1f 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,8 @@ New features will be prioritised based on demand. If there's something you'd lik - [X] `PUSH(v)` length detection - [x] Macros - [x] Compiler-state assertions (e.g. expected stack depth) -- [ ] Automatic (most-efficient) stack transformation - - [ ] Permutations +- [x] Determination of optimal (least-gas) stack transformations + - [x] Permutations - [ ] Duplications w/ permutations - [ ] Standalone compiler - [x] In-process EVM execution (geth) diff --git a/compile.go b/compile.go index ced204b..baa8483 100644 --- a/compile.go +++ b/compile.go @@ -163,15 +163,21 @@ CodeLoop: return nil, err } - op := vm.OpCode(code[0]) - d, ok := stackDeltas[op] - if !ok { - return nil, posErr("invalid %T(%v) as first byte returned by Bytecode()", op, op) - } - if stackDepth < d.pop { - return nil, posErr("popping %d values with stack depth %d", d.pop, stackDepth) + for i, n := 0, len(code); i < n; i++ { + op := vm.OpCode(code[i]) + d, ok := stackDeltas[op] + if !ok { + return nil, posErr("invalid %T(%v) as byte [%d] returned by Bytecode()", op, op, i) + } + if stackDepth < d.pop { + return nil, posErr("Bytecode()[%d] popping %d values with stack depth %d", i, d.pop, stackDepth) + } + stackDepth += d.push - d.pop // we're not in Solidity anymore ;) + + if op.IsPush() { + i += int(op - vm.PUSH0) + } } - stackDepth += d.push - d.pop // we're not in Solidity anymore ;) buf.Write(code) } diff --git a/runopts/debugger_test.go b/runopts/debugger_test.go index 70be9ac..7ef6022 100644 --- a/runopts/debugger_test.go +++ b/runopts/debugger_test.go @@ -88,12 +88,6 @@ func TestDebuggerCompilationError(t *testing.T) { } func TestDebuggerErrors(t *testing.T) { - const invalid = vm.OpCode(0xf8) - if vm.StringToOp(invalid.String()) != 0 { - // This may happen if the above opcode is added. Any invalid value suffices. 
- t.Fatalf("Bad test setup; %[1]T(%[1]d) = %[1]v is valid; want invalid", invalid) - } - tests := []struct { name string code Code @@ -116,7 +110,7 @@ func TestDebuggerErrors(t *testing.T) { { name: "invalid opcode", code: Code{ - Raw{byte(invalid)}, + Raw{byte(INVALID)}, }, wantErrType: reflect.TypeOf(new(vm.ErrInvalidOpCode)), }, diff --git a/stack/BUILD.bazel b/stack/BUILD.bazel index 97e1890..f9fb8be 100644 --- a/stack/BUILD.bazel +++ b/stack/BUILD.bazel @@ -1,8 +1,24 @@ -load("@rules_go//go:def.bzl", "go_library") +load("@rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "stack", - srcs = ["stack.go"], + srcs = [ + "stack.go", + "transform.go", + ], importpath = "github.com/solidifylabs/specops/stack", visibility = ["//visibility:public"], + deps = ["@com_github_ethereum_go_ethereum//core/vm"], +) + +go_test( + name = "stack_test", + srcs = ["transform_test.go"], + deps = [ + ":stack", + "//:specops", + "//evmdebug", + "@com_github_ethereum_go_ethereum//core/vm", + "@com_github_google_go_cmp//cmp", + ], ) diff --git a/stack/transform.go b/stack/transform.go new file mode 100644 index 0000000..60fb56d --- /dev/null +++ b/stack/transform.go @@ -0,0 +1,204 @@ +package stack + +import ( + "errors" + "fmt" + "strings" + + "github.com/ethereum/go-ethereum/core/vm" +) + +type xFormType int + +const ( + unknownXform xFormType = iota + permutation +) + +// A Transformation transforms the stack by modifying its order, growing, and/or +// shrinking it. +type Transformation struct { + typ xFormType + indices []uint8 + cache string +} + +// Permute returns a Transformation that reorders the top n = len(indices) stack +// items. The indices MUST be the distinct values [0,n), 0 < n <= 16, in any order. +func Permute(indices ...uint8) *Transformation { + return &Transformation{ + typ: permutation, + indices: indices, + } +} + +// Bytecode returns the stack-transforming opcodes (SWAP, DUP, etc) necessary to +// achieve the transformation in the most efficient manner. 
+func (t *Transformation) Bytecode() ([]byte, error) {
+	if t.cache != "" {
+		return t.cached()
+	}
+
+	switch t.typ {
+	case permutation:
+		return t.permute()
+	default:
+		return nil, fmt.Errorf("invalid %T.typ = %d", t, t.typ)
+	}
+}
+
+func (t *Transformation) cached() ([]byte, error) {
+	return nil, errors.New("cached transformations unimplemented")
+}
+
+// permute verifies that t.indices holds every value in [0,n) exactly once, for
+// n <= 16, and then returns t.bfs().
+func (t *Transformation) permute() ([]byte, error) {
+	if n := len(t.indices); n > 16 {
+		return nil, fmt.Errorf("can only permute up to 16 stack items; got %d", n)
+	}
+
+	set := make(map[uint8]bool)
+	for _, idx := range t.indices {
+		if set[idx] {
+			return nil, fmt.Errorf("duplicate index %d in permutation %v", idx, t.indices)
+		}
+		set[idx] = true
+	}
+
+	for i := range t.indices { // explicitly not `_, i` like last loop
+		if !set[uint8(i)] {
+			return nil, fmt.Errorf("non-contiguous indices in permutation %v; missing %d", t.indices, i)
+		}
+	}
+	return t.bfs(len(t.indices))
+}
+
+// bfs performs a breadth-first search over a graph of stack-value orders,
+// starting from the root, in-order node [0, size). Edges connect nodes that are
+// reachable from one another with a single opcode (currently only SWAPs).
+//
+// bfs should be called by the transformation-type-specific methods that first
+// check for valid indices. bfs itself is, however, type-agnostic.
+func (t *Transformation) bfs(size int) ([]byte, error) {
+	if size == 0 || size > 16 {
+		return nil, fmt.Errorf("invalid %T size %d", t, size)
+	}
+
+	root := rootNode(uint8(size))
+	want := nodeFromIndices(t.indices)
+	if want == root {
+		return nil, nil
+	}
+
+	// An implicit graph representation that only has nodes added when enqueued
+	// by the BFS; each value is the opcode path from the root to the key node.
+	graph := transformationPaths{
+		root: nil,
+	}
+
+	for queue := []node{root}; len(queue) > 0; {
+		curr := queue[0]
+		queue = queue[1:]
+		currPath, ok := graph[curr]
+		if !ok {
+			return nil, fmt.Errorf("BUG: node %q in queue but not in graph", curr)
+		}
+
+		// Only SWAP1..SWAP(n-1) are usable: SWAPk exchanges index 0 (the top) with index k.
+		for i, n := 0, len(t.indices)-1; i < n; i++ {
+			op := vm.SWAP1 + vm.OpCode(i)
+			next, err := curr.apply(op)
+			if err != nil {
+				return nil, err
+			}
+			// The next node has already been visited and, since this is an
+			// unweighted graph, BFS ordering is sufficient for the shortest
+			// path.
+			if _, ok := graph[next]; ok {
+				continue
+			}
+
+			nextPath := make(path, len(currPath)+1)
+			copy(nextPath, currPath)
+			nextPath[len(currPath)] = op
+
+			if next == want {
+				return nextPath.bytes(), nil
+			}
+
+			graph[next] = nextPath
+			queue = append(queue, next)
+		}
+	}
+
+	// This should never happen (famous last words!)
+	return nil, fmt.Errorf("stack transformation %v not reached by BFS", t.indices)
+}
+
+// transformationPaths maps each visited node to the path that reaches it from
+// the rootNode().
+type transformationPaths map[node]path
+
+// A node represents a slice of stack indices as a string so it can be used as a
+// map key. To aid in debugging, it represents each index as a hex character,
+// however this MUST NOT be relied upon to be stable.
+type node string
+
+// A path is a sequence of opcodes which, if applied in order, transform the
+// root node into another.
+type path []vm.OpCode
+
+// nodeFromIndices converts the indices into a node.
+func nodeFromIndices(is []uint8) node {
+	var s strings.Builder
+	for _, i := range is {
+		switch {
+		case i < 10:
+			s.WriteByte('0' + i)
+		case i < 16:
+			s.WriteByte('a' + i - 10)
+		default:
+			// If this happens then there's a broken invariant that should have
+			// been prevented by an error-returning path. Panicking here is only
+			// possible if there's a bug.
+			panic(fmt.Sprintf("BUG: invalid index value %d > 15", i))
+		}
+	}
+	return node(s.String())
+}
+
+// rootNode returns the node representing [0, …, size).
+func rootNode(size uint8) node {
+	buf := make([]byte, size)
+	for i := range buf {
+		buf[i] = byte(i)
+	}
+	return nodeFromIndices(buf)
+}
+
+// apply returns a *new* node equivalent to applying o, which must be a SWAP, to n.
+func (n node) apply(o vm.OpCode) (node, error) {
+	switch base := o & 0xf0; {
+	case base == vm.SWAP1:
+		out := make([]byte, len(n))
+		copy(out, []byte(n))
+
+		i := o - vm.SWAP1 + 1
+		out[0], out[i] = out[i], out[0] // invariants in the BFS loop guarantee that these are in range
+
+		return node(out), nil
+
+	default:
+		return "", fmt.Errorf("unsupported transformation %T(%v)", o, o)
+	}
+}
+
+// bytes returns p, verbatim, as bytes.
+func (p path) bytes() []byte {
+	out := make([]byte, len(p))
+	for i, pp := range p {
+		out[i] = byte(pp)
+	}
+	return out
+}
diff --git a/stack/transform_test.go b/stack/transform_test.go
new file mode 100644
index 0000000..c537219
--- /dev/null
+++ b/stack/transform_test.go
@@ -0,0 +1,131 @@
+// Package stack_test avoids a circular dependency between the specops and stack
+// packages.
+package stack_test
+
+import (
+	"fmt"
+	"math/rand"
+	"testing"
+
+	"github.com/ethereum/go-ethereum/core/vm"
+	"github.com/google/go-cmp/cmp"
+	"github.com/solidifylabs/specops"
+	"github.com/solidifylabs/specops/evmdebug"
+	"github.com/solidifylabs/specops/stack"
+)
+
+func intPtr(x int) *int {
+	return &x
+}
+
+func TestPermute(t *testing.T) {
+	type test struct {
+		indices      []uint8
+		wantNumSwaps *int // don't know when fuzzing so only test if non-nil
+	}
+
+	tests := []test{
+		{
+			indices:      []uint8{0, 1, 2, 3},
+			wantNumSwaps: intPtr(0),
+		},
+		{
+			indices:      []uint8{7, 1, 2, 3, 4, 5, 6, 0},
+			wantNumSwaps: intPtr(1),
+		},
+		{
+			indices:      []uint8{4, 1, 2, 3, 0, 5, 6},
+			wantNumSwaps: intPtr(1),
+		},
+		{
+			indices: []uint8{2, 1, 0, 3},
+		},
+		{
+			indices: []uint8{3, 2, 1, 0},
+		},
+		{
+			indices: []uint8{5, 0, 6, 3, 4, 2, 1},
+		},
+	}
+
+	rng := rand.New(rand.NewSource(42))
+	for i := 0; i < 20; i++ {
+		in := []uint8{0, 1, 2, 3, 4, 5, 6, 7}
+		rng.Shuffle(len(in), func(i, j int) {
+			in[i], in[j] = in[j], in[i]
+		})
+		tests = append(tests, test{indices: in})
+	}
+
+	for _, tt := range tests {
+		t.Run(fmt.Sprintf("%v", tt.indices), func(t *testing.T) {
+			var code specops.Code
+			for i := range tt.indices { // explicitly not _, i
+				code = append(code, specops.PUSH(len(tt.indices)-i-1)) // {0 … n-1} top to bottom
+			}
+
+			perm := stack.Permute(tt.indices...)
+			code = append(code, perm)
+
+			swaps, err := perm.Bytecode()
+			if err != nil {
+				t.Fatalf("Bad test setup; Permute(%v).Bytecode() error %v", tt.indices, err)
+			}
+			for _, s := range swaps {
+				t.Log(vm.OpCode(s))
+			}
+			if got := len(swaps); tt.wantNumSwaps != nil && got != *tt.wantNumSwaps {
+				t.Errorf("Permute(%v) got %d swaps; want %d", tt.indices, got, *tt.wantNumSwaps)
+			}
+
+			dbg, _, err := code.StartDebugging(nil)
+			if err != nil {
+				t.Fatalf("%T.StartDebugging(nil) error %v", code, err)
+			}
+			defer dbg.FastForward()
+
+			for i := 0; i < len(tt.indices); i++ {
+				dbg.Step()
+			}
+			inOrder := make([]uint8, len(tt.indices))
+			for i := range inOrder {
+				inOrder[i] = uint8(i)
+			}
+			t.Run("after PUSHing indices in order", stackTest(dbg, inOrder))
+
+			for i := 0; i < len(swaps); i++ {
+				dbg.Step()
+			}
+			t.Run("after SWAPing based on Permute()", stackTest(dbg, tt.indices))
+		})
+	}
+}
+
+// stackTest returns a test function that checks the debugger's current stack, top to bottom, against want8.
+func stackTest(dbg *evmdebug.Debugger, want8 []uint8) func(*testing.T) {
+	return func(t *testing.T) {
+		t.Helper()
+		st := dbg.State()
+		if st.Err != nil {
+			t.Fatalf("%T.State().Err = %v; want nil", dbg, st.Err)
+		}
+
+		var got []uint64
+		stack := st.ScopeContext.Stack
+		for i, n := 0, len(stack.Data()); i < n; i++ {
+			g := stack.Back(i)
+			if !g.IsUint64() {
+				t.Fatalf("%T.State().ScopeContext.Stack.Data()[%d] not representable as uint64", dbg, i)
+			}
+			got = append(got, g.Uint64())
+		}
+
+		want := make([]uint64, len(want8))
+		for i, w := range want8 {
+			want[i] = uint64(w)
+		}
+		if diff := cmp.Diff(want, got); diff != "" {
+			t.Errorf("Stack [top to bottom] diff (-want +got):\n%s", diff)
+		}
+	}
+}