Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IMT optimization #154

Draft
wants to merge 19 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 62 additions & 43 deletions contracts/data/IncrementalMerkleTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,26 @@ library IncrementalMerkleTree {
* @return treeSize size of tree
*/
function size(Tree storage t) internal view returns (uint256 treeSize) {
if (t.height() > 0) {
treeSize = t.nodes[0].length;
bytes32[][] storage nodes = t.nodes;

assembly {
mstore(0x00, nodes.slot)
treeSize := sload(keccak256(0x00, 0x20))
}
}

/**
* @notice query one-indexed height of tree
* @dev conventional zero-indexed height would require the use of signed integers, so height is one-indexed instead
* @param t Tree struct storage reference
* @return one-indexed height of tree
* @return treeHeight one-indexed height of tree
*/
function height(Tree storage t) internal view returns (uint256) {
return t.nodes.length;
function height(Tree storage t) internal view returns (uint256 treeHeight) {
bytes32[][] storage nodes = t.nodes;

assembly {
treeHeight := sload(nodes.slot)
}
}

/**
Expand All @@ -36,11 +43,15 @@ library IncrementalMerkleTree {
* @return hash root hash
*/
function root(Tree storage t) internal view returns (bytes32 hash) {
bytes32[][] storage nodes = t.nodes;

uint256 treeHeight = t.height();

if (treeHeight > 0) {
unchecked {
hash = t.nodes[treeHeight - 1][0];
assembly {
mstore(0x00, nodes.slot)
mstore(0x00, add(keccak256(0x00, 0x20), sub(treeHeight, 1)))
hash := sload(keccak256(0x00, 0x20))
}
}
}
Expand All @@ -59,59 +70,58 @@ library IncrementalMerkleTree {
*/
function push(Tree storage t, bytes32 hash) internal {
unchecked {
uint256 treeHeight = t.height();
uint256 treeSize = t.size();
// index to add to tree
uint256 updateIndex = t.size();

// add new layer if tree is at capacity

if (treeSize == (1 << treeHeight) >> 1) {
if (updateIndex == (1 << t.height()) >> 1) {
t.nodes.push();
treeHeight++;
}

// add new columns if rows are full

uint256 row;
uint256 col = treeSize;
uint256 col = updateIndex;

while (row < treeHeight && t.nodes[row].length <= col) {
while (col == t.nodes[row].length) {
t.nodes[row].push();
row++;
if (col == 0) break;
col >>= 1;
}

// add hash to tree

t.set(treeSize, hash);
t.set(updateIndex, hash);
}
}

function pop(Tree storage t) internal {
uint256 treeHeight = t.height();
uint256 treeSize = t.size() - 1;

// remove layer if tree has excess capacity

if (treeSize == (1 << treeHeight) >> 2) {
treeHeight--;
t.nodes.pop();
}
unchecked {
// index to remove from tree
uint256 updateIndex = t.size() - 1;

// remove columns if rows are too long
// remove columns if rows are too long

uint256 row;
uint256 col = treeSize;
uint256 row;
uint256 col = updateIndex;

while (row < treeHeight && t.nodes[row].length > col) {
t.nodes[row].pop();
row++;
col = (col + 1) >> 1;
}
while (col != t.nodes[row].length) {
t.nodes[row].pop();
row++;
col >>= 1;
if (col == 0) break;
}

// recalculate hashes
// if new tree is full, remove excess layer
// if no layer is removed, recalculate hashes

if (treeSize > 0) {
t.set(treeSize - 1, t.at(treeSize - 1));
if (updateIndex == (1 << t.height()) >> 2) {
t.nodes.pop();
} else {
t.set(updateIndex - 1, t.at(updateIndex - 1));
}
}
}

Expand All @@ -122,31 +132,34 @@ library IncrementalMerkleTree {
* @param hash new hash to add
*/
function set(Tree storage t, uint256 index, bytes32 hash) internal {
unchecked {
_set(t.nodes, 0, index, t.height() - 1, hash);
}
_set(t.nodes, 0, index, t.size(), hash);
}

/**
* @notice update element in tree and recursively recalculate hashes
* @param nodes internal tree structure storage reference
* @param rowIndex index of current row to update
* @param colIndex index of current column to update
* @param rootIndex index of root row
* @param rowLength length of row at rowIndex
* @param hash hash to store at current position
*/
function _set(
bytes32[][] storage nodes,
uint256 rowIndex,
uint256 colIndex,
uint256 rootIndex,
uint256 rowLength,
bytes32 hash
) private {
bytes32[] storage row = nodes[rowIndex];

row[colIndex] = hash;
// store hash in array via assembly to avoid array length sload

assembly {
mstore(0x00, row.slot)
sstore(add(keccak256(0x00, 0x20), colIndex), hash)
}

if (rowIndex == rootIndex) return;
if (rowLength == 1) return;

unchecked {
if (colIndex & 1 == 1) {
Expand All @@ -160,7 +173,7 @@ library IncrementalMerkleTree {
mstore(0x20, hash)
hash := keccak256(0x00, 0x40)
}
} else if (colIndex + 1 < row.length) {
} else if (colIndex < rowLength - 1) {
// sibling is on the right (and sibling exists)
assembly {
mstore(0x00, row.slot)
Expand All @@ -173,7 +186,13 @@ library IncrementalMerkleTree {
}
}

_set(nodes, rowIndex + 1, colIndex >> 1, rootIndex, hash);
_set(
nodes,
rowIndex + 1,
colIndex >> 1,
(rowLength + 1) >> 1,
hash
);
}
}
}
12 changes: 10 additions & 2 deletions test/data/IncrementalMerkleTree.ts
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,19 @@ describe('IncrementalMerkleTree', function () {
});

describe('reverts if', () => {
it('index is out of bounds', async () => {
it('tree is size zero', async () => {
await expect(instance.callStatic.at(0)).to.be.revertedWithPanic(
PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS,
);
});

it('index is out of bounds', async () => {
await instance.push(randomHash());

await expect(instance.callStatic.at(1)).to.be.revertedWithPanic(
PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS,
);
});
});
});

Expand Down Expand Up @@ -186,7 +194,7 @@ describe('IncrementalMerkleTree', function () {
describe('reverts if', () => {
it('tree is size zero', async () => {
await expect(instance.pop()).to.be.revertedWithPanic(
PANIC_CODES.ARITHMETIC_UNDER_OR_OVERFLOW,
PANIC_CODES.ARRAY_ACCESS_OUT_OF_BOUNDS,
);
});
});
Expand Down