Skip to content

Commit

Permalink
add tests, fix logic
Browse files Browse the repository at this point in the history
  • Loading branch information
kustosz committed Sep 4, 2023
1 parent 61e1187 commit d50f6b4
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 91 deletions.
113 changes: 84 additions & 29 deletions integration_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main
package main_test

import (
gnarkLogger "github.com/consensys/gnark/logger"
"io"
"net/http"
"strings"
Expand All @@ -13,7 +14,10 @@ import (
const ProverAddress = "localhost:8080"
const MetricsAddress = "localhost:9999"

var mode string

func TestMain(m *testing.M) {
gnarkLogger.Set(*logging.Logger())
logging.Logger().Info().Msg("Setting up the prover")
ps, err := prover.SetupInsertion(3, 2)
if err != nil {
Expand All @@ -24,36 +28,27 @@ func TestMain(m *testing.M) {
MetricsAddress: MetricsAddress,
Mode: server.InsertionMode,
}
logging.Logger().Info().Msg("Starting the server")
logging.Logger().Info().Msg("Starting the insertion server")
instance := server.Run(&cfg, ps)
logging.Logger().Info().Msg("Running the tests")
defer func() {
instance.RequestStop()
instance.AwaitStop()
}()
logging.Logger().Info().Msg("Running the insertion tests")
mode = server.InsertionMode
m.Run()
instance.RequestStop()
instance.AwaitStop()
cfg.Mode = server.DeletionMode
ps, err = prover.SetupDeletion(3, 2)
if err != nil {
panic(err)
}
logging.Logger().Info().Msg("Starting the deletion server")
instance = server.Run(&cfg, ps)
logging.Logger().Info().Msg("Running the deletion tests")
mode = server.DeletionMode
m.Run()
instance.RequestStop()
instance.AwaitStop()
}

// func TestMainDeletion(m *testing.M) {
// logging.Logger().Info().Msg("Setting up the prover")
// ps, err := prover.SetupDeletion(3, 2)
// if err != nil {
// panic(err)
// }
// cfg := server.Config{
// ProverAddress: ProverAddress,
// MetricsAddress: MetricsAddress,
// }
// logging.Logger().Info().Msg("Starting the server")
// instance := server.Run(&cfg, ps)
// logging.Logger().Info().Msg("Running the tests")
// defer func() {
// instance.RequestStop()
// instance.AwaitStop()
// }()
// m.Run()
// }

func TestWrongMethod(t *testing.T) {
response, err := http.Get("http://localhost:8080/prove")
if err != nil {
Expand All @@ -64,7 +59,10 @@ func TestWrongMethod(t *testing.T) {
}
}

func TestHappyPath(t *testing.T) {
func TestInsertionHappyPath(t *testing.T) {
if mode != server.InsertionMode {
return
}
body := `{
"inputHash":"0x5057a31740d54d42ac70c05e0768fb770c682cb2c559bdd03fe4099f7e584e4f",
"startIndex":0,
Expand All @@ -84,7 +82,33 @@ func TestHappyPath(t *testing.T) {
}
}

func TestWrongInput(t *testing.T) {
func TestDeletionHappyPath(t *testing.T) {
if mode != server.DeletionMode {
return
}
body := `{
"inputHash":"0xdcd389a94b549222fadc9e335c358a3fe4d534155182f46927f82ea8491c7480",
"deletionIndices":[0,2],
"preRoot":"0xd11eefe87b985333c0d327b0cdd39a9641b5ac32c35c2bda84301ef3231a8ac",
"postRoot":"0x1912415186579e1d9ff6282b76d081f0acd527d8549ea803385b1382d9498f35",
"identityCommitments":["0x1","0x3"],
"merkleProofs":[
["0x2","0x20a3af0435914ccd84b806164531b0cd36e37d4efb93efab76913a93e1f30996","0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1"],
["0x4","0x65e2c6cc08a36c4a943286bc91c216054a1981eb4f7570f67394ef8937a21b8","0x1069673dcdb12263df301a6ff584a7ec261a44cb9dc68df067a4774460b1f1e1"]
]}`
response, err := http.Post("http://localhost:8080/prove", "application/json", strings.NewReader(body))
if err != nil {
t.Fatal(err)
}
if response.StatusCode != http.StatusOK {
t.Fatalf("Expected status code %d, got %d", http.StatusOK, response.StatusCode)
}
}

func TestInsertionWrongInput(t *testing.T) {
if mode != server.InsertionMode {
return
}
body := `{
"inputHash":"0x5057a31740d54d42ac70c05e0768fb770c682cb2c559bdd03fe4099f7e584e4f",
"startIndex":0,
Expand All @@ -109,4 +133,35 @@ func TestWrongInput(t *testing.T) {
if !strings.Contains(string(responseBody), "proving_error") {
t.Fatalf("Expected error message to be tagged with 'proving_error', got %s", string(responseBody))
}

}

func TestDeletionWrongInput(t *testing.T) {
if mode != server.DeletionMode {
return
}
body := `{
"inputHash":"0xdcd389a94b549222fadc9e335c358a3fe4d534155182f46927f82ea8491c7480",
"deletionIndices":[0,2],
"preRoot":"0xd11eefe87b985333c0d327b0cdd39a9641b5ac32c35c2bda84301ef3231a8ac",
"postRoot":"0x1912415186579e1d9ff6282b76d081f0acd527d8549ea803385b1382d9498f35",
"identityCommitments":["0x1","0x3"],
"merkleProofs":[
["0x2","0xD","0xD"],
["0x4","0xD","0xD"]
]}`
response, err := http.Post("http://localhost:8080/prove", "application/json", strings.NewReader(body))
if err != nil {
t.Fatal(err)
}
if response.StatusCode != http.StatusBadRequest {
t.Fatalf("Expected status code %d, got %d", http.StatusBadRequest, response.StatusCode)
}
responseBody, err := io.ReadAll(response.Body)
if err != nil {
t.Fatal(err)
}
if !strings.Contains(string(responseBody), "proving_error") {
t.Fatalf("Expected error message to be tagged with 'proving_error', got %s", string(responseBody))
}
}
12 changes: 8 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,13 +168,17 @@ func main() {
params := prover.DeletionParameters{}
tree := NewTree(treeDepth)

params.DeletionIndices = make([]big.Int, batchSize)
params.PreRoot = tree.Root()
params.DeletionIndices = make([]uint32, batchSize)
params.IdComms = make([]big.Int, batchSize)
params.MerkleProofs = make([][]big.Int, batchSize)
for i := 0; i < int(batchSize*2); i++ {
tree.Update(i, *new(big.Int).SetUint64(uint64(i + 1)))
}
params.PreRoot = tree.Root()
for i := 0; i < int(batchSize); i++ {
params.IdComms[i] = *new(big.Int).SetUint64(uint64(i + 1))
params.MerkleProofs[i] = tree.Update(i, params.IdComms[i])
params.DeletionIndices[i] = uint32(2 * i)
params.IdComms[i] = *new(big.Int).SetUint64(uint64(2*i + 1))
params.MerkleProofs[i] = tree.Update(2*i, *big.NewInt(0))
}
params.PostRoot = tree.Root()
params.ComputeInputHashDeletion()
Expand Down
1 change: 0 additions & 1 deletion prover/deletion_circuit.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ func (circuit *DeletionMbuCircuit) Define(api frontend.API) error {
// We keccak hash all input to save verification gas. Inputs are arranged as follows:
// deletionIndices[0] || deletionIndices[1] || ... || deletionIndices[batchSize-1] || PreRoot || PostRoot
// 32 || 32 || ... || 32 || 256 || 256

kh := keccak.NewKeccak256(api, circuit.BatchSize*32+2*256)

var bits []frontend.Variable
Expand Down
30 changes: 17 additions & 13 deletions prover/deletion_proving_system.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package prover

import (
"bytes"
"encoding/binary"
"fmt"
"math/big"
"worldcoin/gnark-mbu/logging"
Expand All @@ -17,7 +19,7 @@ type DeletionParameters struct {
InputHash big.Int
PreRoot big.Int
PostRoot big.Int
DeletionIndices []big.Int
DeletionIndices []uint32
IdComms []big.Int
MerkleProofs [][]big.Int
}
Expand All @@ -29,6 +31,9 @@ func (p *DeletionParameters) ValidateShape(treeDepth uint32, batchSize uint32) e
if len(p.MerkleProofs) != int(batchSize) {
return fmt.Errorf("wrong number of merkle proofs: %d", len(p.MerkleProofs))
}
if len(p.DeletionIndices) != int(batchSize) {
return fmt.Errorf("wrong number of deletion indices: %d", len(p.DeletionIndices))
}
for i, proof := range p.MerkleProofs {
if len(proof) != int(treeDepth) {
return fmt.Errorf("wrong size of merkle proof for proof %d: %d", i, len(proof))
Expand All @@ -37,21 +42,19 @@ func (p *DeletionParameters) ValidateShape(treeDepth uint32, batchSize uint32) e
return nil
}

// ComputeInputHash computes the input hash to the prover and verifier.
// ComputeInputHashDeletion computes the input hash to the prover and verifier.
//
// It uses big-endian byte ordering (network ordering) in order to agree with
// Solidity and avoid the need to perform the byte swapping operations on-chain
// where they would increase our gas cost.
func (p *DeletionParameters) ComputeInputHashDeletion() error {
var data []byte
for _, v := range p.IdComms {
idBytes := v.Bytes()
// extend to 32 bytes if necessary, maintaining big-endian ordering
if len(idBytes) < 32 {
idBytes = append(make([]byte, 32-len(idBytes)), idBytes...)
}
data = append(data, idBytes...)
buf := new(bytes.Buffer)
err := binary.Write(buf, binary.BigEndian, p.DeletionIndices)
if err != nil {
return err
}
data = append(data, buf.Bytes()...)
data = append(data, p.PreRoot.Bytes()...)
data = append(data, p.PostRoot.Bytes()...)

Expand All @@ -66,10 +69,11 @@ func BuildR1CSDeletion(treeDepth uint32, batchSize uint32) (constraint.Constrain
proofs[i] = make([]frontend.Variable, treeDepth)
}
circuit := DeletionMbuCircuit{
Depth: int(treeDepth),
BatchSize: int(batchSize),
IdComms: make([]frontend.Variable, batchSize),
MerkleProofs: proofs,
Depth: int(treeDepth),
BatchSize: int(batchSize),
DeletionIndices: make([]frontend.Variable, batchSize),
IdComms: make([]frontend.Variable, batchSize),
MerkleProofs: proofs,
}
return frontend.Compile(ecc.BN254.ScalarField(), r1cs.NewBuilder, &circuit)
}
Expand Down
88 changes: 44 additions & 44 deletions prover/marshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type InsertionParametersJSON struct {
}
type DeletionParametersJSON struct {
InputHash string `json:"inputHash"`
DeletionIndices []string `json:"deletionIndices"`
DeletionIndices []uint32 `json:"deletionIndices"`
PreRoot string `json:"preRoot"`
PostRoot string `json:"postRoot"`
IdComms []string `json:"identityCommitments"`
Expand Down Expand Up @@ -113,7 +113,7 @@ func (p *InsertionParameters) UnmarshalJSON(data []byte) error {
func (p *DeletionParameters) MarshalJSON() ([]byte, error) {
paramsJson := DeletionParametersJSON{}
paramsJson.InputHash = toHex(&p.InputHash)
paramsJson.DeletionIndices = make([]string, len(p.DeletionIndices))
paramsJson.DeletionIndices = p.DeletionIndices
paramsJson.PreRoot = toHex(&p.PreRoot)
paramsJson.PostRoot = toHex(&p.PostRoot)
paramsJson.IdComms = make([]string, len(p.IdComms))
Expand All @@ -132,48 +132,48 @@ func (p *DeletionParameters) MarshalJSON() ([]byte, error) {

func (p *DeletionParameters) UnmarshalJSON(data []byte) error {

//var params InsertionParametersJSON
//
//err := json.Unmarshal(data, &params)
//if err != nil {
// return err
//}
//
//err = fromHex(&p.InputHash, params.InputHash)
//if err != nil {
// return err
//}
//
//p.StartIndex = params.StartIndex
//
//err = fromHex(&p.PreRoot, params.PreRoot)
//if err != nil {
// return err
//}
//
//err = fromHex(&p.PostRoot, params.PostRoot)
//if err != nil {
// return err
//}
//
//p.IdComms = make([]big.Int, len(params.IdComms))
//for i := 0; i < len(params.IdComms); i++ {
// err = fromHex(&p.IdComms[i], params.IdComms[i])
// if err != nil {
// return err
// }
//}
//
//p.MerkleProofs = make([][]big.Int, len(params.MerkleProofs))
//for i := 0; i < len(params.MerkleProofs); i++ {
// p.MerkleProofs[i] = make([]big.Int, len(params.MerkleProofs[i]))
// for j := 0; j < len(params.MerkleProofs[i]); j++ {
// err = fromHex(&p.MerkleProofs[i][j], params.MerkleProofs[i][j])
// if err != nil {
// return err
// }
// }
//}
var params DeletionParametersJSON

err := json.Unmarshal(data, &params)
if err != nil {
return err
}

err = fromHex(&p.InputHash, params.InputHash)
if err != nil {
return err
}

p.DeletionIndices = params.DeletionIndices

err = fromHex(&p.PreRoot, params.PreRoot)
if err != nil {
return err
}

err = fromHex(&p.PostRoot, params.PostRoot)
if err != nil {
return err
}

p.IdComms = make([]big.Int, len(params.IdComms))
for i := 0; i < len(params.IdComms); i++ {
err = fromHex(&p.IdComms[i], params.IdComms[i])
if err != nil {
return err
}
}

p.MerkleProofs = make([][]big.Int, len(params.MerkleProofs))
for i := 0; i < len(params.MerkleProofs); i++ {
p.MerkleProofs[i] = make([]big.Int, len(params.MerkleProofs[i]))
for j := 0; j < len(params.MerkleProofs[i]); j++ {
err = fromHex(&p.MerkleProofs[i][j], params.MerkleProofs[i][j])
if err != nil {
return err
}
}
}

return nil
}
Expand Down

0 comments on commit d50f6b4

Please sign in to comment.