From 0da4ef3b88be950e8518e7a7b7cfba5e3f62b845 Mon Sep 17 00:00:00 2001 From: Sanjit Bhat Date: Mon, 2 Dec 2024 13:34:01 -0500 Subject: [PATCH] merkle: restore 3D proof. fewer memmoves bc copying hash refs instead of vals --- kt/serde.go | 6 +-- kt/serde.out.go | 12 +++--- marshalutil/marshalutil.go | 81 +++++++++++++++++++------------------- merkle/merkle.go | 72 +++++++++++++++++++-------------- 4 files changed, 93 insertions(+), 78 deletions(-) diff --git a/kt/serde.go b/kt/serde.go index 795deeb..f548820 100644 --- a/kt/serde.go +++ b/kt/serde.go @@ -30,18 +30,18 @@ type Memb struct { LabelProof []byte EpochAdded uint64 PkOpen *CommitOpen - MerkProof []byte + MerkProof [][][]byte } type MembHide struct { LabelProof []byte MapVal []byte - MerkProof []byte + MerkProof [][][]byte } type NonMemb struct { LabelProof []byte - MerkProof []byte + MerkProof [][][]byte } type ServerPutArg struct { diff --git a/kt/serde.out.go b/kt/serde.out.go index ce5cb33..d1f1c4a 100644 --- a/kt/serde.out.go +++ b/kt/serde.out.go @@ -102,7 +102,7 @@ func MembEncode(b0 []byte, o *Memb) []byte { b = marshalutil.WriteSlice1D(b, o.LabelProof) b = marshal.WriteInt(b, o.EpochAdded) b = CommitOpenEncode(b, o.PkOpen) - b = marshalutil.WriteSlice1D(b, o.MerkProof) + b = marshalutil.WriteSlice3D(b, o.MerkProof) return b } func MembDecode(b0 []byte) (*Memb, []byte, bool) { @@ -118,7 +118,7 @@ func MembDecode(b0 []byte) (*Memb, []byte, bool) { if err3 { return nil, nil, true } - a4, b4, err4 := marshalutil.ReadSlice1D(b3) + a4, b4, err4 := marshalutil.ReadSlice3D(b3) if err4 { return nil, nil, true } @@ -128,7 +128,7 @@ func MembHideEncode(b0 []byte, o *MembHide) []byte { var b = b0 b = marshalutil.WriteSlice1D(b, o.LabelProof) b = marshalutil.WriteSlice1D(b, o.MapVal) - b = marshalutil.WriteSlice1D(b, o.MerkProof) + b = marshalutil.WriteSlice3D(b, o.MerkProof) return b } func MembHideDecode(b0 []byte) (*MembHide, []byte, bool) { @@ -140,7 +140,7 @@ func MembHideDecode(b0 []byte) (*MembHide, []byte, bool) { if err2 { return nil, nil, true } - a3, b3, err3 := marshalutil.ReadSlice1D(b2) + a3, b3, err3 := marshalutil.ReadSlice3D(b2) if err3 { return nil, nil, true } @@ -149,7 +149,7 @@ func MembHideDecode(b0 []byte) (*MembHide, []byte, bool) { func NonMembEncode(b0 []byte, o *NonMemb) []byte { var b = b0 b = marshalutil.WriteSlice1D(b, o.LabelProof) - b = marshalutil.WriteSlice1D(b, o.MerkProof) + b = marshalutil.WriteSlice3D(b, o.MerkProof) return b } func NonMembDecode(b0 []byte) (*NonMemb, []byte, bool) { @@ -157,7 +157,7 @@ func NonMembDecode(b0 []byte) (*NonMemb, []byte, bool) { if err1 { return nil, nil, true } - a2, b2, err2 := marshalutil.ReadSlice1D(b1) + a2, b2, err2 := marshalutil.ReadSlice3D(b1) if err2 { return nil, nil, true } diff --git a/marshalutil/marshalutil.go b/marshalutil/marshalutil.go index 19eb31f..aaebfc7 100644 --- a/marshalutil/marshalutil.go +++ b/marshalutil/marshalutil.go @@ -4,74 +4,75 @@ import ( "github.com/tchajed/marshal" ) -type errorTy = bool - -const ( - errNone errorTy = false - errSome errorTy = true -) +func WriteBytes2D(b0 []byte, data [][]byte) []byte { + var b = b0 + for _, x := range data { + b = marshal.WriteBytes(b, x) + } + return b +} -func ReadBool(b0 []byte) (bool, []byte, errorTy) { +func ReadBool(b0 []byte) (bool, []byte, bool) { var b = b0 if uint64(len(b)) < 1 { - return false, nil, errSome + return false, nil, true } data, b := marshal.ReadBool(b) - return data, b, errNone + return data, b, false } -func ReadConstBool(b0 []byte, cst bool) ([]byte, errorTy) { +func ReadConstBool(b0 []byte, cst bool) ([]byte, bool) { var b = b0 res, b, err := ReadBool(b) if err { - return nil, errSome + return nil, true } if res != cst { - return nil, errSome + return nil, true } - return b, errNone + return b, false } -func ReadInt(b0 []byte) (uint64, []byte, errorTy) { +func ReadInt(b0 []byte) (uint64, []byte, bool) { var b = b0 if uint64(len(b)) < 8 { - return 0, nil, errSome + return 0, nil, true } data, b := marshal.ReadInt(b) - return data, b, errNone + return data, b, false } -func ReadConstInt(b0 []byte, cst uint64) ([]byte, errorTy) { +func ReadConstInt(b0 []byte, cst uint64) ([]byte, bool) { var b = b0 res, b, err := ReadInt(b) if err { - return nil, errSome + return nil, true } if res != cst { - return nil, errSome + return nil, true } - return b, errNone + return b, false } -func ReadByte(b0 []byte) (byte, []byte, errorTy) { +func ReadByte(b0 []byte) (byte, []byte, bool) { var b = b0 if uint64(len(b)) < 1 { - return 0, nil, errSome + return 0, nil, true } data, b := marshal.ReadBytes(b, 1) - return data[0], b, errNone + return data[0], b, false } -func ReadConstByte(b0 []byte, cst byte) ([]byte, errorTy) { +func ReadConstByte(b0 []byte, cst byte) ([]byte, bool) { var b = b0 res, b, err := ReadByte(b) if err { - return nil, errSome + return nil, true } if res != cst { - return nil, errSome + return nil, true } - return b, errNone + return b, false } func WriteByte(b0 []byte, data byte) []byte { @@ -80,16 +81,16 @@ func WriteByte(b0 []byte, data byte) []byte { return b } -func ReadBytes(b0 []byte, length uint64) ([]byte, []byte, errorTy) { +func ReadBytes(b0 []byte, length uint64) ([]byte, []byte, bool) { var b = b0 if uint64(len(b)) < length { - return nil, nil, errSome + return nil, nil, true } data, b := marshal.ReadBytes(b, length) - return data, b, errNone + return data, b, false } -func ReadSlice1D(b0 []byte) ([]byte, []byte, errorTy) { +func ReadSlice1D(b0 []byte) ([]byte, []byte, bool) { var b = b0 length, b, err := ReadInt(b) if err { @@ -99,7 +100,7 @@ func ReadSlice1D(b0 []byte) ([]byte, []byte, errorTy) { if err { return nil, nil, err } - return data, b, errNone + return data, b, false } func WriteSlice1D(b0 []byte, data []byte) []byte { @@ -109,18 +110,18 @@ func WriteSlice1D(b0 []byte, data []byte) []byte { return b } -func ReadSlice2D(b0 []byte) ([][]byte, []byte, errorTy) { +func ReadSlice2D(b0 []byte) ([][]byte, []byte, bool) { var b = b0 length, b, err := ReadInt(b) if err { return nil, nil, err } var data0 [][]byte - var err0 errorTy + var err0 bool var i uint64 for ; i < length; i++ { var data1 []byte - var err1 errorTy + var err1 bool data1, b, err1 = ReadSlice1D(b) if err1 { err0 = err1 @@ -131,7 +132,7 @@ func ReadSlice2D(b0 []byte) ([][]byte, []byte, errorTy) { if err0 { return nil, nil, err0 } - return data0, b, errNone + return data0, b, false } func WriteSlice2D(b0 []byte, data [][]byte) []byte { @@ -143,18 +144,18 @@ func WriteSlice2D(b0 []byte, data [][]byte) []byte { return b } -func ReadSlice3D(b0 []byte) ([][][]byte, []byte, errorTy) { +func ReadSlice3D(b0 []byte) ([][][]byte, []byte, bool) { var b = b0 length, b, err := ReadInt(b) if err { return nil, nil, err } var data0 [][][]byte - var err0 errorTy + var err0 bool var i uint64 for ; i < length; i++ { var data1 [][]byte - var err1 errorTy + var err1 bool data1, b, err1 = ReadSlice2D(b) if err1 { err0 = err1 @@ -165,7 +166,7 @@ func ReadSlice3D(b0 []byte) ([][][]byte, []byte, errorTy) { if err0 { return nil, nil, err0 } - return data0, b, errNone + return data0, b, false } func WriteSlice3D(b0 []byte, data [][][]byte) []byte { diff --git a/merkle/merkle.go b/merkle/merkle.go index 5696253..0081dbe 100644 --- a/merkle/merkle.go +++ b/merkle/merkle.go @@ -3,18 +3,18 @@ package merkle import ( "github.com/goose-lang/std" "github.com/mit-pdos/pav/cryptoffi" + "github.com/mit-pdos/pav/marshalutil" "github.com/tchajed/marshal" ) const ( // Branch on a byte. 2 ** 8 (bits in byte) = 256. - numChildren uint64 = 256 - hashesPerProofDepth uint64 = (numChildren - 1) * cryptoffi.HashLen - emptyNodeTag byte = 0 - leafNodeTag byte = 1 - interiorNodeTag byte = 2 - NonmembProofTy bool = false - MembProofTy bool = true + numChildren uint64 = 256 + emptyNodeTag byte = 0 + leafNodeTag byte = 1 + interiorNodeTag byte = 2 + NonmembProofTy bool = false + MembProofTy bool = true ) type Tree struct { @@ -41,7 +41,7 @@ func (t *Tree) Digest() []byte { } // Put returns the digest, proof, and error. -func (t *Tree) Put(label []byte, mapVal []byte) ([]byte, []byte, bool) { +func (t *Tree) Put(label []byte, mapVal []byte) ([]byte, [][][]byte, bool) { if uint64(len(label)) != cryptoffi.HashLen { return nil, nil, true } @@ -82,7 +82,7 @@ func (t *Tree) Put(label []byte, mapVal []byte) ([]byte, []byte, bool) { // Get returns the mapVal, digest, proofTy, proof, and error. // return ProofTy vs. having sep funcs bc regardless, would want a proof. -func (t *Tree) Get(label []byte) ([]byte, []byte, bool, []byte, bool) { +func (t *Tree) Get(label []byte) ([]byte, []byte, bool, [][][]byte, bool) { if uint64(len(label)) != cryptoffi.HashLen { return nil, nil, false, nil, true } @@ -104,20 +104,14 @@ func NewTree() *Tree { } // CheckProof returns an error if the proof is invalid. -func CheckProof(proofTy bool, proof []byte, label []byte, mapVal []byte, dig []byte) bool { +func CheckProof(proofTy bool, proof [][][]byte, label []byte, mapVal []byte, dig []byte) bool { proofLen := uint64(len(proof)) - if proofLen%hashesPerProofDepth != 0 { - return true - } - proofDepth := proofLen / hashesPerProofDepth - if proofDepth > cryptoffi.HashLen { + if proofLen > cryptoffi.HashLen { return true } if uint64(len(label)) != cryptoffi.HashLen { return true } - // NonmembProof has original label. slice it down to match proof. - labelPref := label[:proofDepth] var nodeHash []byte if proofTy { nodeHash = compLeafNodeHash(mapVal) @@ -129,15 +123,24 @@ func CheckProof(proofTy bool, proof []byte, label []byte, mapVal []byte, dig []b var loopCurrHash []byte = nodeHash var loopBuf = make([]byte, 0, numChildren*cryptoffi.HashLen+1) var loopIdx = uint64(0) - for ; loopIdx < proofDepth; loopIdx++ { - depth := proofDepth - 1 - loopIdx - begin := depth * hashesPerProofDepth - middle := begin + uint64(labelPref[depth])*cryptoffi.HashLen - end := (depth + 1) * hashesPerProofDepth + for ; loopIdx < proofLen; loopIdx++ { + depth := proofLen - 1 - loopIdx + children := proof[depth] + if uint64(len(children)) != numChildren-1 { + loopErr = true + continue + } + if !checkValidHashes(children) { + loopErr = true + continue + } - loopBuf = marshal.WriteBytes(loopBuf, proof[begin:middle]) + pos := label[depth] + before := children[:pos] + after := children[pos:] + loopBuf = marshalutil.WriteBytes2D(loopBuf, before) loopBuf = marshal.WriteBytes(loopBuf, loopCurrHash) - loopBuf = marshal.WriteBytes(loopBuf, proof[middle:end]) + loopBuf = marshalutil.WriteBytes2D(loopBuf, after) loopBuf = append(loopBuf, interiorNodeTag) loopCurrHash = cryptoffi.Hash(loopBuf) loopBuf = loopBuf[:0] @@ -204,19 +207,20 @@ func getPath(root *node, label []byte) []*node { return nodePath } -func (ctx *context) getProof(interiors []*node, label []byte) []byte { +func (ctx *context) getProof(interiors []*node, label []byte) [][][]byte { interiorsLen := uint64(len(interiors)) - var proof = make([]byte, 0, interiorsLen*hashesPerProofDepth) + proof := make([][][]byte, 0, interiorsLen) for depth := uint64(0); depth < interiorsLen; depth++ { children := interiors[depth].children - // convert to uint64 bc otherwise pos+1 might overflow. pos := uint64(label[depth]) + oneProof := make([][]byte, 0, numChildren-1) for _, n := range children[:pos] { - proof = marshal.WriteBytes(proof, ctx.getHash(n)) + oneProof = append(oneProof, ctx.getHash(n)) } for _, n := range children[pos+1:] { - proof = marshal.WriteBytes(proof, ctx.getHash(n)) + oneProof = append(oneProof, ctx.getHash(n)) } + proof = append(proof, oneProof) } return proof } @@ -229,3 +233,13 @@ func newInteriorNode() *node { func newCtx() *context { return &context{emptyHash: compEmptyNodeHash()} } + +func checkValidHashes(hashes [][]byte) bool { + var ok = true + for _, hash := range hashes { + if uint64(len(hash)) != cryptoffi.HashLen { + ok = false + } + } + return ok +}