diff --git a/go.mod b/go.mod index b25a896..f34594e 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/consensys/gnark v0.8.0 github.com/iden3/go-iden3-crypto v0.0.13 github.com/prometheus/client_golang v1.14.0 + github.com/reilabs/gnark-lean-extractor v1.1.0 github.com/urfave/cli/v2 v2.10.2 ) @@ -25,7 +26,6 @@ require ( github.com/prometheus/client_model v0.3.0 // indirect github.com/prometheus/common v0.37.0 // indirect github.com/prometheus/procfs v0.8.0 // indirect - github.com/reilabs/gnark-lean-extractor v1.1.0 // indirect github.com/rogpeppe/go-internal v1.10.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect diff --git a/integration_test.go b/integration_test.go index a4006d1..a7414fe 100644 --- a/integration_test.go +++ b/integration_test.go @@ -1,6 +1,7 @@ -package main +package main_test import ( + gnarkLogger "github.com/consensys/gnark/logger" "io" "net/http" "strings" @@ -13,24 +14,39 @@ import ( const ProverAddress = "localhost:8080" const MetricsAddress = "localhost:9999" +var mode string + func TestMain(m *testing.M) { + gnarkLogger.Set(*logging.Logger()) logging.Logger().Info().Msg("Setting up the prover") - ps, err := prover.Setup(3, 2) + ps, err := prover.SetupInsertion(3, 2) if err != nil { panic(err) } cfg := server.Config{ ProverAddress: ProverAddress, MetricsAddress: MetricsAddress, + Mode: server.InsertionMode, } - logging.Logger().Info().Msg("Starting the server") + logging.Logger().Info().Msg("Starting the insertion server") instance := server.Run(&cfg, ps) - logging.Logger().Info().Msg("Running the tests") - defer func() { - instance.RequestStop() - instance.AwaitStop() - }() + logging.Logger().Info().Msg("Running the insertion tests") + mode = server.InsertionMode + m.Run() + instance.RequestStop() + instance.AwaitStop() + cfg.Mode = server.DeletionMode + ps, err = prover.SetupDeletion(3, 2) + if err != nil { + panic(err) + } + logging.Logger().Info().Msg("Starting the deletion server") + instance = server.Run(&cfg, ps) + logging.Logger().Info().Msg("Running the deletion tests") + mode = server.DeletionMode m.Run() + instance.RequestStop() + instance.AwaitStop() } func TestWrongMethod(t *testing.T) { @@ -43,7 +59,10 @@ func TestWrongMethod(t *testing.T) { } } -func TestHappyPath(t *testing.T) { +func TestInsertionHappyPath(t *testing.T) { + if mode != server.InsertionMode { + return + } body := `{ "inputHash":"0x5057a31740d54d42ac70c05e0768fb770c682cb2c559bdd03fe4099f7e584e4f", "startIndex":0, @@ -63,7 +82,33 @@ func TestHappyPath(t *testing.T) { } } -func TestWrongInput(t *testing.T) { +func TestDeletionHappyPath(t *testing.T) { + if mode != server.DeletionMode { + return + } + body := `{ + "inputHash":"0xdcd389a94b549222fadc9e335c358a3fe4d534155182f46927f82ea8491c7480", + "deletionIndices":[0,2], + "preRoot":"0xd11eefe87b985333c0d327b0cdd39a9641b5ac32c35c2bda84301ef3231a8ac", + "postRoot":"0x1912415186579e1d9ff6282b76d081f0acd527d8549ea803385b1382d9498f35", + "identityCommitments":["0x1","0x3"], + "merkleProofs":[ + ["0x2","0x20a3af0435914ccd84b806164531b0cd36e37d4efb93efab76913a93e1f30996","0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1"], + ["0x4","0x65e2c6cc08a36c4a943286bc91c216054a1981eb4f7570f67394ef8937a21b8","0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1"] + ]}` + response, err := http.Post("http://localhost:8080/prove", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + if response.StatusCode != http.StatusOK { + t.Fatalf("Expected status code %d, got %d", http.StatusOK, response.StatusCode) + } +} + +func TestInsertionWrongInput(t *testing.T) { + if mode != server.InsertionMode { + return + } body := `{ "inputHash":"0x5057a31740d54d42ac70c05e0768fb770c682cb2c559bdd03fe4099f7e584e4f", "startIndex":0, @@ -88,4 +133,35 @@ func TestWrongInput(t *testing.T) { if !strings.Contains(string(responseBody), "proving_error") { t.Fatalf("Expected error message to be tagged with 'proving_error', got %s", string(responseBody)) } + +} + +func TestDeletionWrongInput(t *testing.T) { + if mode != server.DeletionMode { + return + } + body := `{ + "inputHash":"0xdcd389a94b549222fadc9e335c358a3fe4d534155182f46927f82ea8491c7480", + "deletionIndices":[0,2], + "preRoot":"0xd11eefe87b985333c0d327b0cdd39a9641b5ac32c35c2bda84301ef3231a8ac", + "postRoot":"0x1912415186579e1d9ff6282b76d081f0acd527d8549ea803385b1382d9498f35", + "identityCommitments":["0x1","0x3"], + "merkleProofs":[ + ["0x2","0xD","0xD"], + ["0x4","0xD","0xD"] + ]}` + response, err := http.Post("http://localhost:8080/prove", "application/json", strings.NewReader(body)) + if err != nil { + t.Fatal(err) + } + if response.StatusCode != http.StatusBadRequest { + t.Fatalf("Expected status code %d, got %d", http.StatusBadRequest, response.StatusCode) + } + responseBody, err := io.ReadAll(response.Body) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(responseBody), "proving_error") { + t.Fatalf("Expected error message to be tagged with 'proving_error', got %s", string(responseBody)) + } } diff --git a/main.go b/main.go index e6c5e3f..9670744 100644 --- a/main.go +++ b/main.go @@ -3,8 +3,6 @@ package main import ( "encoding/json" "fmt" - gnarkLogger "github.com/consensys/gnark/logger" - "github.com/urfave/cli/v2" "io" "math/big" "os" @@ -12,6 +10,10 @@ import ( "worldcoin/gnark-mbu/logging" "worldcoin/gnark-mbu/prover" "worldcoin/gnark-mbu/server" + + "github.com/consensys/gnark/constraint" + gnarkLogger "github.com/consensys/gnark/logger" + "github.com/urfave/cli/v2" ) func main() { @@ -22,16 +24,29 @@ func main() { { Name: "setup", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "output", Usage: "Output file", Required: true}, &cli.UintFlag{Name: "tree-depth", Usage: "Merkle tree depth", Required: true}, &cli.UintFlag{Name: "batch-size", Usage: "Batch size", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + path := context.String("output") treeDepth := uint32(context.Uint("tree-depth")) batchSize := uint32(context.Uint("batch-size")) logging.Logger().Info().Msg("Running setup") - system, err := prover.Setup(treeDepth, batchSize) + + var system *prover.ProvingSystem + var err error + if mode == server.InsertionMode { + system, err = prover.SetupInsertion(treeDepth, batchSize) + } else if mode == server.DeletionMode { + system, err = prover.SetupDeletion(treeDepth, batchSize) + } else { + return fmt.Errorf("Invalid mode: %s", mode) + } + if err != nil { return err } @@ -51,16 +66,30 @@ func main() { { Name: "r1cs", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "output", Usage: "Output file", Required: true}, &cli.UintFlag{Name: "tree-depth", Usage: "Merkle tree depth", Required: true}, &cli.UintFlag{Name: "batch-size", Usage: "Batch size", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + path := context.String("output") treeDepth := uint32(context.Uint("tree-depth")) batchSize := uint32(context.Uint("batch-size")) logging.Logger().Info().Msg("Building R1CS") - cs, err := prover.BuildR1CS(treeDepth, batchSize) + + var cs constraint.ConstraintSystem + var err error + + if mode == server.InsertionMode { + cs, err = prover.BuildR1CSInsertion(treeDepth, batchSize) + } else if mode == server.DeletionMode { + cs, err = prover.BuildR1CSDeletion(treeDepth, batchSize) + } else { + return fmt.Errorf("Invalid mode: %s", mode) + } + if err != nil { return err } @@ -106,28 +135,62 @@ func main() { { Name: "gen-test-params", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.UintFlag{Name: "tree-depth", Usage: "depth of the mock tree", Required: true}, &cli.UintFlag{Name: "batch-size", Usage: "batch size", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + treeDepth := context.Int("tree-depth") batchSize := uint32(context.Uint("batch-size")) - logging.Logger().Info().Msg("Generating test params") + logging.Logger().Info().Msg("Generating test params for the insertion circuit") + + var r []byte + var err error - params := prover.Parameters{} - tree := NewTree(treeDepth) + if mode == server.InsertionMode { + params := prover.InsertionParameters{} + tree := NewTree(treeDepth) + + params.StartIndex = 0 + params.PreRoot = tree.Root() + params.IdComms = make([]big.Int, batchSize) + params.MerkleProofs = make([][]big.Int, batchSize) + for i := 0; i < int(batchSize); i++ { + params.IdComms[i] = *new(big.Int).SetUint64(uint64(i + 1)) + params.MerkleProofs[i] = tree.Update(i, params.IdComms[i]) + } + params.PostRoot = tree.Root() + params.ComputeInputHashInsertion() + r, err = json.Marshal(¶ms) + } else if mode == server.DeletionMode { + params := prover.DeletionParameters{} + tree := NewTree(treeDepth) + + params.DeletionIndices = make([]uint32, batchSize) + params.IdComms = make([]big.Int, batchSize) + params.MerkleProofs = make([][]big.Int, batchSize) + for i := 0; i < int(batchSize*2); i++ { + tree.Update(i, *new(big.Int).SetUint64(uint64(i + 1))) + } + params.PreRoot = tree.Root() + for i := 0; i < int(batchSize); i++ { + params.DeletionIndices[i] = uint32(2 * i) + params.IdComms[i] = *new(big.Int).SetUint64(uint64(2*i + 1)) + params.MerkleProofs[i] = tree.Update(2*i, *big.NewInt(0)) + } + params.PostRoot = tree.Root() + params.ComputeInputHashDeletion() + r, err = json.Marshal(¶ms) + } else { + return fmt.Errorf("Invalid mode: %s", mode) + } - params.StartIndex = 0 - params.PreRoot = tree.Root() - params.IdComms = make([]big.Int, batchSize) - params.MerkleProofs = make([][]big.Int, batchSize) - for i := 0; i < int(batchSize); i++ { - params.IdComms[i] = *new(big.Int).SetUint64(uint64(i + 1)) - params.MerkleProofs[i] = tree.Update(i, params.IdComms[i]) + if err != nil { + return err } - params.PostRoot = tree.Root() - params.ComputeInputHash() - r, _ := json.Marshal(¶ms) + fmt.Println(string(r)) return nil }, @@ -135,6 +198,7 @@ func main() { { Name: "start", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "keys-file", Usage: "proving system file", Required: true}, &cli.BoolFlag{Name: "json-logging", Usage: "enable JSON logging", Required: false}, &cli.StringFlag{Name: "prover-address", Usage: "address for the prover server", Value: "localhost:3001", Required: false}, @@ -145,6 +209,12 @@ func main() { logging.SetJSONOutput() } keys := context.String("keys-file") + mode := context.String("mode") + + if mode != server.DeletionMode && mode != server.InsertionMode { + return fmt.Errorf("invalid mode: %s", mode) + } + logging.Logger().Info().Msg("Reading proving system from file") ps, err := prover.ReadSystemFromFile(keys) if err != nil { @@ -154,6 +224,7 @@ func main() { config := server.Config{ ProverAddress: context.String("prover-address"), MetricsAddress: context.String("metrics-address"), + Mode: mode, } instance := server.Run(&config, ps) sigint := make(chan os.Signal, 1) @@ -169,9 +240,12 @@ func main() { { Name: "prove", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "keys-file", Usage: "proving system file", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + keys := context.String("keys-file") ps, err := prover.ReadSystemFromFile(keys) if err != nil { @@ -183,13 +257,28 @@ func main() { if err != nil { return err } - var params prover.Parameters - err = json.Unmarshal(bytes, ¶ms) - if err != nil { - return err + + var proof *prover.Proof + if mode == server.InsertionMode { + var params prover.InsertionParameters + err = json.Unmarshal(bytes, ¶ms) + if err != nil { + return err + } + logging.Logger().Info().Msg("params read successfully") + proof, err = ps.ProveInsertion(¶ms) + } else if mode == server.DeletionMode { + var params prover.DeletionParameters + err = json.Unmarshal(bytes, ¶ms) + if err != nil { + return err + } + logging.Logger().Info().Msg("params read successfully") + proof, err = ps.ProveDeletion(¶ms) + } else { + return fmt.Errorf("Invalid mode: %s", mode) } - logging.Logger().Info().Msg("params read successfully") - proof, err := ps.Prove(¶ms) + if err != nil { return err } @@ -201,10 +290,13 @@ func main() { { Name: "verify", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "keys-file", Usage: "proving system file", Required: true}, &cli.StringFlag{Name: "input-hash", Usage: "the hash of all public inputs", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + keys := context.String("keys-file") var inputHash big.Int _, ok := inputHash.SetString(context.String("input-hash"), 0) @@ -227,7 +319,15 @@ func main() { return err } logging.Logger().Info().Msg("proof read successfully") - err = ps.Verify(inputHash, &proof) + + if mode == server.InsertionMode { + err = ps.VerifyInsertion(inputHash, &proof) + } else if mode == server.DeletionMode { + err = ps.VerifyDeletion(inputHash, &proof) + } else { + return fmt.Errorf("Invalid mode: %s", mode) + } + if err != nil { return err } diff --git a/prover/circuit.go b/prover/circuit_utils.go similarity index 50% rename from prover/circuit.go rename to prover/circuit_utils.go index d3651b0..861957f 100644 --- a/prover/circuit.go +++ b/prover/circuit_utils.go @@ -1,33 +1,30 @@ package prover import ( + "io" "strconv" - "worldcoin/gnark-mbu/prover/keccak" "worldcoin/gnark-mbu/prover/poseidon" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/constraint" "github.com/consensys/gnark/frontend" "github.com/reilabs/gnark-lean-extractor/abstractor" ) -const emptyLeaf = 0 - -type MbuCircuit struct { - // single public input - InputHash frontend.Variable `gnark:",public"` - - // private inputs, but used as public inputs - StartIndex frontend.Variable `gnark:"input"` - PreRoot frontend.Variable `gnark:"input"` - PostRoot frontend.Variable `gnark:"input"` - IdComms []frontend.Variable `gnark:"input"` - - // private inputs - MerkleProofs [][]frontend.Variable `gnark:"input"` +type Proof struct { + Proof groth16.Proof +} - BatchSize int - Depth int +type ProvingSystem struct { + TreeDepth uint32 + BatchSize uint32 + ProvingKey groth16.ProvingKey + VerifyingKey groth16.VerifyingKey + ConstraintSystem constraint.ConstraintSystem } +const emptyLeaf = 0 + type bitPatternLengthError struct { actualLength int } @@ -100,79 +97,13 @@ func FromBinaryBigEndian(bitsBigEndian []frontend.Variable, api frontend.API) (v return api.FromBinary(bitsLittleEndian...), nil } -func (circuit *MbuCircuit) Define(api frontend.API) error { - // Hash private inputs. - // We keccak hash all input to save verification gas. Inputs are arranged as follows: - // StartIndex || PreRoot || PostRoot || IdComms[0] || IdComms[1] || ... || IdComms[batchSize-1] - // 32 || 256 || 256 || 256 || 256 || ... || 256 bits - - kh := keccak.NewKeccak256(api, (circuit.BatchSize+2)*256+32) - - var bits []frontend.Variable - var err error - - // We convert all the inputs to the keccak hash to use big-endian (network) byte - // ordering so that it agrees with Solidity. This ensures that we don't have to - // perform the conversion inside the contract and hence save on gas. - bits, err = ToBinaryBigEndian(circuit.StartIndex, 32, api) - if err != nil { - return err - } - kh.Write(bits...) - - bits, err = ToBinaryBigEndian(circuit.PreRoot, 256, api) - if err != nil { - return err +func toBytesLE(b []byte) []byte { + for i := 0; i < len(b)/2; i++ { + b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i] } - kh.Write(bits...) - - bits, err = ToBinaryBigEndian(circuit.PostRoot, 256, api) - if err != nil { - return err - } - kh.Write(bits...) - - for i := 0; i < circuit.BatchSize; i++ { - bits, err = ToBinaryBigEndian(circuit.IdComms[i], 256, api) - if err != nil { - return err - } - kh.Write(bits...) - } - - var sum frontend.Variable - sum, err = FromBinaryBigEndian(kh.Sum(), api) - if err != nil { - return err - } - - // The same endianness conversion has been performed in the hash generation - // externally, so we can safely assert their equality here. - api.AssertIsEqual(circuit.InputHash, sum) - - // Actual batch merkle proof verification. - var root frontend.Variable - - prevRoot := circuit.PreRoot - - // Individual insertions. - for i := 0; i < circuit.BatchSize; i += 1 { - currentIndex := api.Add(circuit.StartIndex, i) - currentPath := api.ToBinary(currentIndex, circuit.Depth) - - // Verify proof for empty leaf. - root = VerifyProof(api, append([]frontend.Variable{emptyLeaf}, circuit.MerkleProofs[i][:]...), currentPath) - api.AssertIsEqual(root, prevRoot) - - // Verify proof for idComm. - root = VerifyProof(api, append([]frontend.Variable{circuit.IdComms[i]}, circuit.MerkleProofs[i][:]...), currentPath) - - // Set root for next iteration. - prevRoot = root - } - - // Final root needs to match. - api.AssertIsEqual(root, circuit.PostRoot) + return b +} - return nil +func (ps *ProvingSystem) ExportSolidity(writer io.Writer) error { + return ps.VerifyingKey.ExportSolidity(writer) } diff --git a/prover/deletion_circuit.go b/prover/deletion_circuit.go new file mode 100644 index 0000000..7b81b86 --- /dev/null +++ b/prover/deletion_circuit.go @@ -0,0 +1,90 @@ +package prover + +import ( + "worldcoin/gnark-mbu/prover/keccak" + + "github.com/consensys/gnark/frontend" +) + +type DeletionMbuCircuit struct { + // single public input + InputHash frontend.Variable `gnark:",public"` + + // private inputs, but used as public inputs + DeletionIndices []frontend.Variable `gnark:"input"` + PreRoot frontend.Variable `gnark:"input"` + PostRoot frontend.Variable `gnark:"input"` + + // private inputs + IdComms []frontend.Variable `gnark:"input"` + MerkleProofs [][]frontend.Variable `gnark:"input"` + + BatchSize int + Depth int +} + +func (circuit *DeletionMbuCircuit) Define(api frontend.API) error { + // Hash private inputs. + // We keccak hash all input to save verification gas. Inputs are arranged as follows: + // deletionIndices[0] || deletionIndices[1] || ... || deletionIndices[batchSize-1] || PreRoot || PostRoot + // 32 || 32 || ... || 32 || 256 || 256 + kh := keccak.NewKeccak256(api, circuit.BatchSize*32+2*256) + + var bits []frontend.Variable + var err error + + for i := 0; i < circuit.BatchSize; i++ { + bits, err = ToBinaryBigEndian(circuit.DeletionIndices[i], 32, api) + if err != nil { + return err + } + kh.Write(bits...) + } + + bits, err = ToBinaryBigEndian(circuit.PreRoot, 256, api) + if err != nil { + return err + } + kh.Write(bits...) + + bits, err = ToBinaryBigEndian(circuit.PostRoot, 256, api) + if err != nil { + return err + } + kh.Write(bits...) + + var sum frontend.Variable + sum, err = FromBinaryBigEndian(kh.Sum(), api) + if err != nil { + return err + } + + // The same endianness conversion has been performed in the hash generation + // externally, so we can safely assert their equality here. + api.AssertIsEqual(circuit.InputHash, sum) + + // Actual batch merkle proof verification. + var root frontend.Variable + + prevRoot := circuit.PreRoot + + // Individual insertions. + for i := 0; i < circuit.BatchSize; i += 1 { + currentPath := api.ToBinary(circuit.DeletionIndices[i], circuit.Depth) + + // Verify proof for idComm. + root = VerifyProof(api, append([]frontend.Variable{circuit.IdComms[i]}, circuit.MerkleProofs[i][:]...), currentPath) + api.AssertIsEqual(root, prevRoot) + + // Verify proof for empty leaf. + root = VerifyProof(api, append([]frontend.Variable{emptyLeaf}, circuit.MerkleProofs[i][:]...), currentPath) + + // Set root for next iteration. + prevRoot = root + } + + // Final root needs to match. + api.AssertIsEqual(root, circuit.PostRoot) + + return nil +} diff --git a/prover/deletion_proving_system.go b/prover/deletion_proving_system.go new file mode 100644 index 0000000..a1a7e8d --- /dev/null +++ b/prover/deletion_proving_system.go @@ -0,0 +1,145 @@ +package prover + +import ( + "bytes" + "encoding/binary" + "fmt" + "math/big" + "worldcoin/gnark-mbu/logging" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" + "github.com/iden3/go-iden3-crypto/keccak256" +) + +type DeletionParameters struct { + InputHash big.Int + PreRoot big.Int + PostRoot big.Int + DeletionIndices []uint32 + IdComms []big.Int + MerkleProofs [][]big.Int +} + +func (p *DeletionParameters) ValidateShape(treeDepth uint32, batchSize uint32) error { + if len(p.IdComms) != int(batchSize) { + return fmt.Errorf("wrong number of identity commitments: %d", len(p.IdComms)) + } + if len(p.MerkleProofs) != int(batchSize) { + return fmt.Errorf("wrong number of merkle proofs: %d", len(p.MerkleProofs)) + } + if len(p.DeletionIndices) != int(batchSize) { + return fmt.Errorf("wrong number of deletion indices: %d", len(p.DeletionIndices)) + } + for i, proof := range p.MerkleProofs { + if len(proof) != int(treeDepth) { + return fmt.Errorf("wrong size of merkle proof for proof %d: %d", i, len(proof)) + } + } + return nil +} + +// ComputeInputHashDeletion computes the input hash to the prover and verifier. +// +// It uses big-endian byte ordering (network ordering) in order to agree with +// Solidity and avoid the need to perform the byte swapping operations on-chain +// where they would increase our gas cost. +func (p *DeletionParameters) ComputeInputHashDeletion() error { + var data []byte + buf := new(bytes.Buffer) + err := binary.Write(buf, binary.BigEndian, p.DeletionIndices) + if err != nil { + return err + } + data = append(data, buf.Bytes()...) + data = append(data, p.PreRoot.Bytes()...) + data = append(data, p.PostRoot.Bytes()...) + + hashBytes := keccak256.Hash(data) + p.InputHash.SetBytes(hashBytes) + return nil +} + +func BuildR1CSDeletion(treeDepth uint32, batchSize uint32) (constraint.ConstraintSystem, error) { + proofs := make([][]frontend.Variable, batchSize) + for i := 0; i < int(batchSize); i++ { + proofs[i] = make([]frontend.Variable, treeDepth) + } + circuit := DeletionMbuCircuit{ + Depth: int(treeDepth), + BatchSize: int(batchSize), + DeletionIndices: make([]frontend.Variable, batchSize), + IdComms: make([]frontend.Variable, batchSize), + MerkleProofs: proofs, + } + return frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) +} + +func SetupDeletion(treeDepth uint32, batchSize uint32) (*ProvingSystem, error) { + ccs, err := BuildR1CSDeletion(treeDepth, batchSize) + if err != nil { + return nil, err + } + pk, vk, err := groth16.Setup(ccs) + if err != nil { + return nil, err + } + return &ProvingSystem{treeDepth, batchSize, pk, vk, ccs}, nil +} + +func (ps *ProvingSystem) ProveDeletion(params *DeletionParameters) (*Proof, error) { + if err := params.ValidateShape(ps.TreeDepth, ps.BatchSize); err != nil { + return nil, err + } + + deletionIndices := make([]frontend.Variable, ps.BatchSize) + for i := 0; i < int(ps.BatchSize); i++ { + deletionIndices[i] = params.DeletionIndices[i] + } + + idComms := make([]frontend.Variable, ps.BatchSize) + for i := 0; i < int(ps.BatchSize); i++ { + idComms[i] = params.IdComms[i] + } + proofs := make([][]frontend.Variable, ps.BatchSize) + for i := 0; i < int(ps.BatchSize); i++ { + proofs[i] = make([]frontend.Variable, ps.TreeDepth) + for j := 0; j < int(ps.TreeDepth); j++ { + proofs[i][j] = params.MerkleProofs[i][j] + } + } + assignment := DeletionMbuCircuit{ + InputHash: params.InputHash, + DeletionIndices: deletionIndices, + PreRoot: params.PreRoot, + PostRoot: params.PostRoot, + IdComms: idComms, + MerkleProofs: proofs, + } + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + if err != nil { + return nil, err + } + logging.Logger().Info().Msg("generating proof") + proof, err := groth16.Prove(ps.ConstraintSystem, ps.ProvingKey, witness) + if err != nil { + return nil, err + } + logging.Logger().Info().Msg("proof generated successfully") + return &Proof{proof}, nil +} + +func (ps *ProvingSystem) VerifyDeletion(inputHash big.Int, proof *Proof) error { + publicAssignment := DeletionMbuCircuit{ + InputHash: inputHash, + DeletionIndices: make([]frontend.Variable, ps.BatchSize), + } + witness, err := frontend.NewWitness(&publicAssignment, ecc.BN254.ScalarField(), frontend.PublicOnly()) + if err != nil { + return err + } + return groth16.Verify(proof.Proof, ps.VerifyingKey, witness) +} diff --git a/prover/insertion_circuit.go b/prover/insertion_circuit.go new file mode 100644 index 0000000..5a1f5ff --- /dev/null +++ b/prover/insertion_circuit.go @@ -0,0 +1,101 @@ +package prover + +import ( + "worldcoin/gnark-mbu/prover/keccak" + + "github.com/consensys/gnark/frontend" +) + +type InsertionMbuCircuit struct { + // single public input + InputHash frontend.Variable `gnark:",public"` + + // private inputs, but used as public inputs + StartIndex frontend.Variable `gnark:"input"` + PreRoot frontend.Variable `gnark:"input"` + PostRoot frontend.Variable `gnark:"input"` + IdComms []frontend.Variable `gnark:"input"` + + // private inputs + MerkleProofs [][]frontend.Variable `gnark:"input"` + + BatchSize int + Depth int +} + +func (circuit *InsertionMbuCircuit) Define(api frontend.API) error { + // Hash private inputs. + // We keccak hash all input to save verification gas. Inputs are arranged as follows: + // StartIndex || PreRoot || PostRoot || IdComms[0] || IdComms[1] || ... || IdComms[batchSize-1] + // 32 || 256 || 256 || 256 || 256 || ... || 256 bits + + kh := keccak.NewKeccak256(api, (circuit.BatchSize+2)*256+32) + + var bits []frontend.Variable + var err error + + // We convert all the inputs to the keccak hash to use big-endian (network) byte + // ordering so that it agrees with Solidity. This ensures that we don't have to + // perform the conversion inside the contract and hence save on gas. + bits, err = ToBinaryBigEndian(circuit.StartIndex, 32, api) + if err != nil { + return err + } + kh.Write(bits...) + + bits, err = ToBinaryBigEndian(circuit.PreRoot, 256, api) + if err != nil { + return err + } + kh.Write(bits...) + + bits, err = ToBinaryBigEndian(circuit.PostRoot, 256, api) + if err != nil { + return err + } + kh.Write(bits...) + + for i := 0; i < circuit.BatchSize; i++ { + bits, err = ToBinaryBigEndian(circuit.IdComms[i], 256, api) + if err != nil { + return err + } + kh.Write(bits...) + } + + var sum frontend.Variable + sum, err = FromBinaryBigEndian(kh.Sum(), api) + if err != nil { + return err + } + + // The same endianness conversion has been performed in the hash generation + // externally, so we can safely assert their equality here. + api.AssertIsEqual(circuit.InputHash, sum) + + // Actual batch merkle proof verification. + var root frontend.Variable + + prevRoot := circuit.PreRoot + + // Individual insertions. + for i := 0; i < circuit.BatchSize; i += 1 { + currentIndex := api.Add(circuit.StartIndex, i) + currentPath := api.ToBinary(currentIndex, circuit.Depth) + + // Verify proof for empty leaf. + root = VerifyProof(api, append([]frontend.Variable{emptyLeaf}, circuit.MerkleProofs[i][:]...), currentPath) + api.AssertIsEqual(root, prevRoot) + + // Verify proof for idComm. + root = VerifyProof(api, append([]frontend.Variable{circuit.IdComms[i]}, circuit.MerkleProofs[i][:]...), currentPath) + + // Set root for next iteration. + prevRoot = root + } + + // Final root needs to match. + api.AssertIsEqual(root, circuit.PostRoot) + + return nil +} diff --git a/prover/proving_system.go b/prover/insertion_proving_system.go similarity index 78% rename from prover/proving_system.go rename to prover/insertion_proving_system.go index 76fdd2e..639bc2f 100644 --- a/prover/proving_system.go +++ b/prover/insertion_proving_system.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "fmt" - "io" "math/big" "worldcoin/gnark-mbu/logging" "worldcoin/gnark-mbu/prover/poseidon" @@ -18,7 +17,7 @@ import ( "github.com/reilabs/gnark-lean-extractor/extractor" ) -type Parameters struct { +type InsertionParameters struct { InputHash big.Int StartIndex uint32 PreRoot big.Int @@ -27,19 +26,7 @@ type Parameters struct { MerkleProofs [][]big.Int } -type Proof struct { - Proof groth16.Proof -} - -type ProvingSystem struct { - TreeDepth uint32 - BatchSize uint32 - ProvingKey groth16.ProvingKey - VerifyingKey groth16.VerifyingKey - ConstraintSystem constraint.ConstraintSystem -} - -func (p *Parameters) ValidateShape(treeDepth uint32, batchSize uint32) error { +func (p *InsertionParameters) ValidateShape(treeDepth uint32, batchSize uint32) error { if len(p.IdComms) != int(batchSize) { return fmt.Errorf("wrong number of identity commitments: %d", len(p.IdComms)) } @@ -54,19 +41,12 @@ func (p *Parameters) ValidateShape(treeDepth uint32, batchSize uint32) error { return nil } -func toBytesLE(b []byte) []byte { - for i := 0; i < len(b)/2; i++ { - b[i], b[len(b)-i-1] = b[len(b)-i-1], b[i] - } - return b -} - // ComputeInputHash computes the input hash to the prover and verifier. // // It uses big-endian byte ordering (network ordering) in order to agree with // Solidity and avoid the need to perform the byte swapping operations on-chain // where they would increase our gas cost. -func (p *Parameters) ComputeInputHash() error { +func (p *InsertionParameters) ComputeInputHashInsertion() error { var data []byte buf := new(bytes.Buffer) err := binary.Write(buf, binary.BigEndian, p.StartIndex) @@ -89,12 +69,12 @@ func (p *Parameters) ComputeInputHash() error { return nil } -func BuildR1CS(treeDepth uint32, batchSize uint32) (constraint.ConstraintSystem, error) { +func BuildR1CSInsertion(treeDepth uint32, batchSize uint32) (constraint.ConstraintSystem, error) { proofs := make([][]frontend.Variable, batchSize) for i := 0; i < int(batchSize); i++ { proofs[i] = make([]frontend.Variable, treeDepth) } - circuit := MbuCircuit{ + circuit := InsertionMbuCircuit{ Depth: int(treeDepth), BatchSize: int(batchSize), IdComms: make([]frontend.Variable, batchSize), @@ -103,8 +83,8 @@ func BuildR1CS(treeDepth uint32, batchSize uint32) (constraint.ConstraintSystem, return frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit) } -func Setup(treeDepth uint32, batchSize uint32) (*ProvingSystem, error) { - ccs, err := BuildR1CS(treeDepth, batchSize) +func SetupInsertion(treeDepth uint32, batchSize uint32) (*ProvingSystem, error) { + ccs, err := BuildR1CSInsertion(treeDepth, batchSize) if err != nil { return nil, err } @@ -120,11 +100,7 @@ func ExtractLean() (string, error) { return extractor.GadgetToLean(&assignment, ecc.BN254) } -func (ps *ProvingSystem) ExportSolidity(writer io.Writer) error { - return ps.VerifyingKey.ExportSolidity(writer) -} - -func (ps *ProvingSystem) Prove(params *Parameters) (*Proof, error) { +func (ps *ProvingSystem) ProveInsertion(params *InsertionParameters) (*Proof, error) { if err := params.ValidateShape(ps.TreeDepth, ps.BatchSize); err != nil { return nil, err } @@ -139,7 +115,7 @@ func (ps *ProvingSystem) Prove(params *Parameters) (*Proof, error) { proofs[i][j] = params.MerkleProofs[i][j] } } - assignment := MbuCircuit{ + assignment := InsertionMbuCircuit{ InputHash: params.InputHash, StartIndex: params.StartIndex, PreRoot: params.PreRoot, @@ -160,8 +136,8 @@ func (ps *ProvingSystem) Prove(params *Parameters) (*Proof, error) { return &Proof{proof}, nil } -func (ps *ProvingSystem) Verify(inputHash big.Int, proof *Proof) error { - publicAssignment := MbuCircuit{ +func (ps *ProvingSystem) VerifyInsertion(inputHash big.Int, proof *Proof) error { + publicAssignment := InsertionMbuCircuit{ InputHash: inputHash, IdComms: make([]frontend.Variable, ps.BatchSize), } diff --git a/prover/marshal.go b/prover/marshal.go index 9680fdb..4f26cfc 100644 --- a/prover/marshal.go +++ b/prover/marshal.go @@ -5,11 +5,12 @@ import ( "encoding/binary" "encoding/json" "fmt" - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark/backend/groth16" "io" "math/big" "os" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" ) func fromHex(i *big.Int, s string) error { @@ -24,7 +25,7 @@ func toHex(i *big.Int) string { return fmt.Sprintf("0x%s", i.Text(16)) } -type ParametersJSON struct { +type InsertionParametersJSON struct { InputHash string `json:"inputHash"` StartIndex uint32 `json:"startIndex"` PreRoot string `json:"preRoot"` @@ -32,9 +33,17 @@ type ParametersJSON struct { IdComms []string `json:"identityCommitments"` MerkleProofs [][]string `json:"merkleProofs"` } +type DeletionParametersJSON struct { + InputHash string `json:"inputHash"` + DeletionIndices []uint32 `json:"deletionIndices"` + PreRoot string `json:"preRoot"` + PostRoot string `json:"postRoot"` + IdComms []string `json:"identityCommitments"` + MerkleProofs [][]string `json:"merkleProofs"` +} -func (p *Parameters) MarshalJSON() ([]byte, error) { - paramsJson := ParametersJSON{} +func (p *InsertionParameters) MarshalJSON() ([]byte, error) { + paramsJson := InsertionParametersJSON{} paramsJson.InputHash = toHex(&p.InputHash) paramsJson.StartIndex = p.StartIndex paramsJson.PreRoot = toHex(&p.PreRoot) @@ -53,9 +62,9 @@ func (p *Parameters) MarshalJSON() ([]byte, error) { return json.Marshal(paramsJson) } -func (p *Parameters) UnmarshalJSON(data []byte) error { +func (p *InsertionParameters) UnmarshalJSON(data []byte) error { - var params ParametersJSON + var params InsertionParametersJSON err := json.Unmarshal(data, ¶ms) if err != nil { @@ -101,6 +110,74 @@ func (p *Parameters) UnmarshalJSON(data []byte) error { return nil } +func (p *DeletionParameters) MarshalJSON() ([]byte, error) { + paramsJson := DeletionParametersJSON{} + paramsJson.InputHash = toHex(&p.InputHash) + paramsJson.DeletionIndices = p.DeletionIndices + paramsJson.PreRoot = toHex(&p.PreRoot) + paramsJson.PostRoot = toHex(&p.PostRoot) + paramsJson.IdComms = make([]string, len(p.IdComms)) + for i := 0; i < len(p.IdComms); i++ { + paramsJson.IdComms[i] = toHex(&p.IdComms[i]) + } + paramsJson.MerkleProofs = make([][]string, len(p.MerkleProofs)) + for i := 0; i < len(p.MerkleProofs); i++ { + paramsJson.MerkleProofs[i] = make([]string, len(p.MerkleProofs[i])) + for j := 0; j < len(p.MerkleProofs[i]); j++ { + paramsJson.MerkleProofs[i][j] = toHex(&p.MerkleProofs[i][j]) + } + } + return json.Marshal(paramsJson) +} + +func (p *DeletionParameters) UnmarshalJSON(data []byte) error { + + var params DeletionParametersJSON + + err := json.Unmarshal(data, ¶ms) + if err != nil { + return err + } + + err = fromHex(&p.InputHash, params.InputHash) + if err != nil { + return err + } + + p.DeletionIndices = params.DeletionIndices + + err = fromHex(&p.PreRoot, params.PreRoot) + if err != nil { + return err + } + + err = fromHex(&p.PostRoot, params.PostRoot) + if err != nil { + return err + } + + p.IdComms = make([]big.Int, len(params.IdComms)) + for i := 0; i < len(params.IdComms); i++ { + err = fromHex(&p.IdComms[i], params.IdComms[i]) + if err != nil { + return err + } + } + + p.MerkleProofs = make([][]big.Int, len(params.MerkleProofs)) + for i := 0; i < len(params.MerkleProofs); i++ { + p.MerkleProofs[i] = make([]big.Int, len(params.MerkleProofs[i])) + for j := 0; j < len(params.MerkleProofs[i]); j++ { + err = fromHex(&p.MerkleProofs[i][j], params.MerkleProofs[i][j]) + if err != nil { + return err + } + } + } + + return nil +} + type ProofJSON struct { Ar [2]string `json:"ar"` Bs [2][2]string `json:"bs"` diff --git a/server/server.go b/server/server.go index 244c7ef..f05e312 100644 --- a/server/server.go +++ b/server/server.go @@ -8,8 +8,9 @@ import ( "net/http" "worldcoin/gnark-mbu/logging" - "github.com/prometheus/client_golang/prometheus/promhttp" "worldcoin/gnark-mbu/prover" + + "github.com/prometheus/client_golang/prometheus/promhttp" ) type Error struct { @@ -18,6 +19,9 @@ type Error struct { Message string } +const DeletionMode = "deletion" +const InsertionMode = "insertion" + func malformedBodyError(err error) *Error { return &Error{StatusCode: http.StatusBadRequest, Code: "malformed_body", Message: err.Error()} } @@ -52,6 +56,7 @@ func (error *Error) send(w http.ResponseWriter) { type Config struct { ProverAddress string MetricsAddress string + Mode string } func spawnServerJob(server *http.Server, label string) RunningJob { @@ -80,7 +85,7 @@ func Run(config *Config, provingSystem *prover.ProvingSystem) RunningJob { logging.Logger().Info().Str("addr", config.MetricsAddress).Msg("metrics server started") proverMux := http.NewServeMux() - proverMux.Handle("/prove", proveHandler{provingSystem: provingSystem}) + proverMux.Handle("/prove", proveHandler{provingSystem: provingSystem, mode: config.Mode}) proverServer := &http.Server{Addr: config.ProverAddress, Handler: proverMux} proverJob := spawnServerJob(proverServer, "prover server") logging.Logger().Info().Str("addr", config.ProverAddress).Msg("app server started") @@ -89,6 +94,7 @@ func Run(config *Config, provingSystem *prover.ProvingSystem) RunningJob { } type proveHandler struct { + mode string provingSystem *prover.ProvingSystem } @@ -103,22 +109,45 @@ func (handler proveHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { malformedBodyError(err).send(w) return } - var params prover.Parameters - err = json.Unmarshal(buf, ¶ms) - if err != nil { - malformedBodyError(err).send(w) - return + + var proof *prover.Proof + if handler.mode == InsertionMode { + var params prover.InsertionParameters + + err = json.Unmarshal(buf, ¶ms) + if err != nil { + malformedBodyError(err).send(w) + return + } + + proof, err = handler.provingSystem.ProveInsertion(¶ms) + } else if handler.mode == DeletionMode { + var params prover.DeletionParameters + + err = json.Unmarshal(buf, ¶ms) + if err != nil { + malformedBodyError(err).send(w) + return + } + + proof, err = handler.provingSystem.ProveDeletion(¶ms) } - proof, err := handler.provingSystem.Prove(¶ms) + if err != nil { provingError(err).send(w) return } + responseBytes, err := json.Marshal(&proof) if err != nil { unexpectedError(err).send(w) return } + w.WriteHeader(http.StatusOK) _, err = w.Write(responseBytes) + + if err != nil { + logging.Logger().Error().Err(err).Msg("error writing response") + } }