diff --git a/main.go b/main.go index b5a81a4..9bfed7c 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "worldcoin/gnark-mbu/prover" "worldcoin/gnark-mbu/server" + "github.com/consensys/gnark/constraint" gnarkLogger "github.com/consensys/gnark/logger" "github.com/urfave/cli/v2" ) @@ -21,47 +22,31 @@ func main() { EnableBashCompletion: true, Commands: []*cli.Command{ { - Name: "setup-insertion", + Name: "setup", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "output", Usage: "Output file", Required: true}, &cli.UintFlag{Name: "tree-depth", Usage: "Merkle tree depth", Required: true}, &cli.UintFlag{Name: "batch-size", Usage: "Batch size", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + path := context.String("output") treeDepth := uint32(context.Uint("tree-depth")) batchSize := uint32(context.Uint("batch-size")) logging.Logger().Info().Msg("Running setup") - system, err := prover.SetupInsertion(treeDepth, batchSize) - if err != nil { - return err - } - file, err := os.Create(path) - defer file.Close() - if err != nil { - return err - } - written, err := system.WriteTo(file) - if err != nil { - return err + + var system *prover.ProvingSystem + var err error + if mode == server.InsertionMode { + system, err = prover.SetupInsertion(treeDepth, batchSize) + } else if mode == server.DeletionMode { + system, err = prover.SetupDeletion(treeDepth, batchSize) + } else { + return fmt.Errorf("Invalid mode: %s", mode) } - logging.Logger().Info().Int64("bytesWritten", written).Msg("proving system written to file") - return nil - }, - }, - { - Name: "setup-deletion", - Flags: []cli.Flag{ - &cli.StringFlag{Name: "output", Usage: "Output file", Required: true}, - &cli.UintFlag{Name: "tree-depth", Usage: "Merkle tree depth", Required: true}, - &cli.UintFlag{Name: "batch-size", Usage: "Batch size", Required: true}, - }, - Action: func(context *cli.Context) error { - path := context.String("output") - treeDepth := uint32(context.Uint("tree-depth")) - batchSize := uint32(context.Uint("batch-size")) - logging.Logger().Info().Msg("Running setup") - system, err := prover.SetupDeletion(treeDepth, batchSize) + if err != nil { return err } @@ -79,47 +64,32 @@ func main() { }, }, { - Name: "r1cs-insertion", + Name: "r1cs", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "output", Usage: "Output file", Required: true}, &cli.UintFlag{Name: "tree-depth", Usage: "Merkle tree depth", Required: true}, &cli.UintFlag{Name: "batch-size", Usage: "Batch size", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + path := context.String("output") treeDepth := uint32(context.Uint("tree-depth")) batchSize := uint32(context.Uint("batch-size")) logging.Logger().Info().Msg("Building R1CS") - cs, err := prover.BuildR1CSInsertion(treeDepth, batchSize) - if err != nil { - return err - } - file, err := os.Create(path) - defer file.Close() - if err != nil { - return err - } - written, err := cs.WriteTo(file) - if err != nil { - return err + + var cs constraint.ConstraintSystem + var err error + + if mode == server.InsertionMode { + cs, err = prover.BuildR1CSInsertion(treeDepth, batchSize) + } else if mode == server.DeletionMode { + cs, err = prover.BuildR1CSDeletion(treeDepth, batchSize) + } else { + return fmt.Errorf("Invalid mode: %s", mode) } - logging.Logger().Info().Int64("bytesWritten", written).Msg("R1CS written to file") - return nil - }, - }, - { - Name: "r1cs-deletion", - Flags: []cli.Flag{ - &cli.StringFlag{Name: "output", Usage: "Output file", Required: true}, - &cli.UintFlag{Name: "tree-depth", Usage: "Merkle tree depth", Required: true}, - &cli.UintFlag{Name: "batch-size", Usage: "Batch size", Required: true}, - }, - Action: func(context *cli.Context) error { - path := context.String("output") - treeDepth := uint32(context.Uint("tree-depth")) - batchSize := uint32(context.Uint("batch-size")) - logging.Logger().Info().Msg("Building R1CS") - cs, err := prover.BuildR1CSDeletion(treeDepth, batchSize) + if err != nil { return err } @@ -163,59 +133,60 @@ func main() { }, }, { - Name: "gen-test-params-insertion", + Name: "gen-test-params", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.UintFlag{Name: "tree-depth", Usage: "depth of the mock tree", Required: true}, &cli.UintFlag{Name: "batch-size", Usage: "batch size", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + treeDepth := context.Int("tree-depth") batchSize := uint32(context.Uint("batch-size")) logging.Logger().Info().Msg("Generating test params for the insertion circuit") - params := prover.InsertionParameters{} - tree := NewTree(treeDepth) + var r []byte + var err error - params.StartIndex = 0 - params.PreRoot = tree.Root() - params.IdComms = make([]big.Int, batchSize) - params.MerkleProofs = make([][]big.Int, batchSize) - for i := 0; i < int(batchSize); i++ { - params.IdComms[i] = *new(big.Int).SetUint64(uint64(i + 1)) - params.MerkleProofs[i] = tree.Update(i, params.IdComms[i]) - } - params.PostRoot = tree.Root() - params.ComputeInputHashInsertion() - r, _ := json.Marshal(¶ms) - fmt.Println(string(r)) - return nil - }, - }, - { - Name: "gen-test-params-deletion", - Flags: []cli.Flag{ - &cli.UintFlag{Name: "tree-depth", Usage: "depth of the mock tree", Required: true}, - &cli.UintFlag{Name: "batch-size", Usage: "batch size", Required: true}, - }, - Action: func(context *cli.Context) error { - treeDepth := context.Int("tree-depth") - batchSize := uint32(context.Uint("batch-size")) - logging.Logger().Info().Msg("Generating test params for the deletion circuit") + if mode == server.InsertionMode { + params := prover.InsertionParameters{} + tree := NewTree(treeDepth) + + params.StartIndex = 0 + params.PreRoot = tree.Root() + params.IdComms = make([]big.Int, batchSize) + params.MerkleProofs = make([][]big.Int, batchSize) + for i := 0; i < int(batchSize); i++ { + params.IdComms[i] = *new(big.Int).SetUint64(uint64(i + 1)) + params.MerkleProofs[i] = tree.Update(i, params.IdComms[i]) + } + params.PostRoot = tree.Root() + params.ComputeInputHashInsertion() + r, err = json.Marshal(¶ms) + } else if mode == server.DeletionMode { + params := prover.DeletionParameters{} + tree := NewTree(treeDepth) - params := prover.DeletionParameters{} - tree := NewTree(treeDepth) + params.DeletionIndices = make([]big.Int, batchSize) + params.PreRoot = tree.Root() + params.IdComms = make([]big.Int, batchSize) + params.MerkleProofs = make([][]big.Int, batchSize) + for i := 0; i < int(batchSize); i++ { + params.IdComms[i] = *new(big.Int).SetUint64(uint64(i + 1)) + params.MerkleProofs[i] = tree.Update(i, params.IdComms[i]) + } + params.PostRoot = tree.Root() + params.ComputeInputHashDeletion() + r, err = json.Marshal(¶ms) + } else { + return fmt.Errorf("Invalid mode: %s", mode) + } - params.DeletionIndices = make([]big.Int, batchSize) - params.PreRoot = tree.Root() - params.IdComms = make([]big.Int, batchSize) - params.MerkleProofs = make([][]big.Int, batchSize) - for i := 0; i < int(batchSize); i++ { - params.IdComms[i] = *new(big.Int).SetUint64(uint64(i + 1)) - params.MerkleProofs[i] = tree.Update(i, params.IdComms[i]) + if err != nil { + return err } - params.PostRoot = tree.Root() - params.ComputeInputHashDeletion() - r, _ := json.Marshal(¶ms) + fmt.Println(string(r)) return nil }, @@ -263,43 +234,14 @@ func main() { }, }, { - Name: "prove-insertion", - Flags: []cli.Flag{ - &cli.StringFlag{Name: "keys-file", Usage: "proving system file", Required: true}, - }, - Action: func(context *cli.Context) error { - keys := context.String("keys-file") - ps, err := prover.ReadSystemFromFile(keys) - if err != nil { - return err - } - logging.Logger().Info().Uint32("treeDepth", ps.TreeDepth).Uint32("batchSize", ps.BatchSize).Msg("Read proving system") - logging.Logger().Info().Msg("reading params from stdin") - bytes, err := io.ReadAll(os.Stdin) - if err != nil { - return err - } - var params prover.InsertionParameters - err = json.Unmarshal(bytes, ¶ms) - if err != nil { - return err - } - logging.Logger().Info().Msg("params read successfully") - proof, err := ps.ProveInsertion(¶ms) - if err != nil { - return err - } - r, _ := json.Marshal(&proof) - fmt.Println(string(r)) - return nil - }, - }, - { - Name: "prove-deletion", + Name: "prove", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "keys-file", Usage: "proving system file", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + keys := context.String("keys-file") ps, err := prover.ReadSystemFromFile(keys) if err != nil { @@ -311,13 +253,28 @@ func main() { if err != nil { return err } - var params prover.DeletionParameters - err = json.Unmarshal(bytes, ¶ms) - if err != nil { - return err + + var proof *prover.Proof + if mode == server.InsertionMode { + var params prover.InsertionParameters + err = json.Unmarshal(bytes, ¶ms) + if err != nil { + return err + } + logging.Logger().Info().Msg("params read successfully") + proof, err = ps.ProveInsertion(¶ms) + } else if mode == server.DeletionMode { + var params prover.DeletionParameters + err = json.Unmarshal(bytes, ¶ms) + if err != nil { + return err + } + logging.Logger().Info().Msg("params read successfully") + proof, err = ps.ProveDeletion(¶ms) + } else { + return fmt.Errorf("Invalid mode: %s", mode) } - logging.Logger().Info().Msg("params read successfully") - proof, err := ps.ProveDeletion(¶ms) + if err != nil { return err } @@ -327,12 +284,15 @@ func main() { }, }, { - Name: "verify-insertion", + Name: "verify", Flags: []cli.Flag{ + &cli.StringFlag{Name: "mode", Usage: "insertion/deletion", DefaultText: "insertion"}, &cli.StringFlag{Name: "keys-file", Usage: "proving system file", Required: true}, &cli.StringFlag{Name: "input-hash", Usage: "the hash of all public inputs", Required: true}, }, Action: func(context *cli.Context) error { + mode := context.String("mode") + keys := context.String("keys-file") var inputHash big.Int _, ok := inputHash.SetString(context.String("input-hash"), 0) @@ -355,44 +315,15 @@ func main() { return err } logging.Logger().Info().Msg("proof read successfully") - err = ps.VerifyInsertion(inputHash, &proof) - if err != nil { - return err - } - logging.Logger().Info().Msg("verification complete") - return nil - }, - }, - { - Name: "verify-deletion", - Flags: []cli.Flag{ - &cli.StringFlag{Name: "keys-file", Usage: "proving system file", Required: true}, - &cli.StringFlag{Name: "input-hash", Usage: "the hash of all public inputs", Required: true}, - }, - Action: func(context *cli.Context) error { - keys := context.String("keys-file") - var inputHash big.Int - _, ok := inputHash.SetString(context.String("input-hash"), 0) - if !ok { - return fmt.Errorf("invalid number: %s", context.String("input-hash")) - } - ps, err := prover.ReadSystemFromFile(keys) - if err != nil { - return err - } - logging.Logger().Info().Uint32("treeDepth", ps.TreeDepth).Uint32("batchSize", ps.BatchSize).Msg("Read proving system") - logging.Logger().Info().Msg("reading proof from stdin") - bytes, err := io.ReadAll(os.Stdin) - if err != nil { - return err - } - var proof prover.Proof - err = json.Unmarshal(bytes, &proof) - if err != nil { - return err + + if mode == server.InsertionMode { + err = ps.VerifyInsertion(inputHash, &proof) + } else if mode == server.DeletionMode { + err = ps.VerifyDeletion(inputHash, &proof) + } else { + return fmt.Errorf("Invalid mode: %s", mode) } - logging.Logger().Info().Msg("proof read successfully") - err = ps.VerifyDeletion(inputHash, &proof) + if err != nil { return err } diff --git a/server/server.go b/server/server.go index b47c136..444cfe3 100644 --- a/server/server.go +++ b/server/server.go @@ -19,6 +19,9 @@ type Error struct { Message string } +const DeletionMode = "deletion" +const InsertionMode = "insertion" + func malformedBodyError(err error) *Error { return &Error{StatusCode: http.StatusBadRequest, Code: "malformed_body", Message: err.Error()} }