Skip to content

Commit

Permalink
feat: implement batch hash utils (#384)
Browse files Browse the repository at this point in the history
* feat: implement batch hash utils

* fix: ssz perf test check-types

* fix: avoid Array<HashComputation[]> and remove hashtree peerDependencies
  • Loading branch information
twoeths authored Jul 17, 2024
1 parent ccadf43 commit 1578883
Show file tree
Hide file tree
Showing 28 changed files with 1,210 additions and 106 deletions.
5 changes: 4 additions & 1 deletion packages/persistent-merkle-tree/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
"clean": "rm -rf lib",
"build": "tsc",
"lint": "eslint --color --ext .ts src/",
"benchmark": "node --max-old-space-size=4096 --expose-gc -r ts-node/register ./node_modules/.bin/benchmark 'test/perf/*.perf.ts'",
"lint:fix": "yarn run lint --fix",
"benchmark:files": "node --max-old-space-size=4096 --expose-gc -r ts-node/register ../../node_modules/.bin/benchmark",
"benchmark": "yarn benchmark:files 'test/perf/*.test.ts'",
"benchmark:local": "yarn benchmark --local",
"test": "mocha -r ts-node/register 'test/unit/**/*.test.ts'"
},
Expand All @@ -45,6 +47,7 @@
"homepage": "https://github.com/ChainSafe/persistent-merkle-tree#readme",
"dependencies": {
"@chainsafe/as-sha256": "0.4.2",
"@chainsafe/hashtree": "1.0.1",
"@noble/hashes": "^1.3.0"
}
}
138 changes: 136 additions & 2 deletions packages/persistent-merkle-tree/src/hasher/as-sha256.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,141 @@
import {digest2Bytes32, digest64HashObjects} from "@chainsafe/as-sha256";
import {
digest2Bytes32,
digest64HashObjectsInto,
digest64HashObjects,
batchHash4HashObjectInputs,
hashInto,
} from "@chainsafe/as-sha256";
import type {Hasher} from "./types";
import {HashComputation, Node} from "../node";
import {doDigestNLevel, doMerkleizeInto} from "./util";

export const hasher: Hasher = {
  name: "as-sha256",
  digest64: digest2Bytes32,
  // The Hasher interface takes the 3-arg "into parent" variant (no allocation per hash).
  // NOTE(fix): a stale shorthand `digest64HashObjects,` entry was removed here — an
  // object literal cannot declare the same property twice. The 2-arg allocating form
  // is still used below (via the bare import) for the leftover (< 4) computations.
  digest64HashObjects: digest64HashObjectsInto,
  merkleizeInto(data: Uint8Array, padFor: number, output: Uint8Array, offset: number): void {
    return doMerkleizeInto(data, padFor, output, offset, hashInto);
  },
  digestNLevel(data: Uint8Array, nLevel: number): Uint8Array {
    return doDigestNLevel(data, nLevel, hashInto);
  },
  /**
   * Execute hash computations level by level, bottom-up, so children are hashed
   * before their parents. Computations within one level are independent, so they
   * are fed 4 at a time to batchHash4HashObjectInputs; any trailing partial batch
   * (1-3 pairs) is hashed one by one with digest64HashObjects.
   */
  executeHashComputations: (hashComputations: HashComputation[][]) => {
    for (let level = hashComputations.length - 1; level >= 0; level--) {
      const hcArr = hashComputations[level];
      if (!hcArr) {
        // should not happen
        throw Error(`no hash computations for level ${level}`);
      }

      if (hcArr.length === 0) {
        // nothing to hash
        continue;
      }

      // HashComputations of the same level are safe to batch.
      // Slots for one batch of 4: (src0_k, src1_k) -> dest_k. All-null means "slot empty".
      let src0_0: Node | null = null;
      let src1_0: Node | null = null;
      let dest0: Node | null = null;
      let src0_1: Node | null = null;
      let src1_1: Node | null = null;
      let dest1: Node | null = null;
      let src0_2: Node | null = null;
      let src1_2: Node | null = null;
      let dest2: Node | null = null;
      let src0_3: Node | null = null;
      let src1_3: Node | null = null;
      let dest3: Node | null = null;

      for (const [i, hc] of hcArr.entries()) {
        const indexInBatch = i % 4;

        switch (indexInBatch) {
          case 0:
            src0_0 = hc.src0;
            src1_0 = hc.src1;
            dest0 = hc.dest;
            break;
          case 1:
            src0_1 = hc.src0;
            src1_1 = hc.src1;
            dest1 = hc.dest;
            break;
          case 2:
            src0_2 = hc.src0;
            src1_2 = hc.src1;
            dest2 = hc.dest;
            break;
          case 3:
            // 4th slot filled: flush the whole batch
            src0_3 = hc.src0;
            src1_3 = hc.src1;
            dest3 = hc.dest;

            if (
              src0_0 !== null &&
              src1_0 !== null &&
              dest0 !== null &&
              src0_1 !== null &&
              src1_1 !== null &&
              dest1 !== null &&
              src0_2 !== null &&
              src1_2 !== null &&
              dest2 !== null &&
              src0_3 !== null &&
              src1_3 !== null &&
              dest3 !== null
            ) {
              // TODO - batch: find a way not allocate here
              const [o0, o1, o2, o3] = batchHash4HashObjectInputs([
                src0_0,
                src1_0,
                src0_1,
                src1_1,
                src0_2,
                src1_2,
                src0_3,
                src1_3,
              ]);
              if (o0 == null || o1 == null || o2 == null || o3 == null) {
                throw Error(`batchHash4HashObjectInputs return null or undefined at batch ${i} level ${level}`);
              }
              dest0.applyHash(o0);
              dest1.applyHash(o1);
              dest2.applyHash(o2);
              dest3.applyHash(o3);

              // reset for next batch
              src0_0 = null;
              src1_0 = null;
              dest0 = null;
              src0_1 = null;
              src1_1 = null;
              dest1 = null;
              src0_2 = null;
              src1_2 = null;
              dest2 = null;
              src0_3 = null;
              src1_3 = null;
              dest3 = null;
            }
            break;
          default:
            // unreachable: i % 4 is always 0..3
            throw Error(`Unexpected indexInBatch ${indexInBatch}`);
        }
      }

      // remaining partial batch (slots that were filled but never flushed)
      if (src0_0 !== null && src1_0 !== null && dest0 !== null) {
        dest0.applyHash(digest64HashObjects(src0_0, src1_0));
      }
      if (src0_1 !== null && src1_1 !== null && dest1 !== null) {
        dest1.applyHash(digest64HashObjects(src0_1, src1_1));
      }
      if (src0_2 !== null && src1_2 !== null && dest2 !== null) {
        dest2.applyHash(digest64HashObjects(src0_2, src1_2));
      }
      if (src0_3 !== null && src1_3 !== null && dest3 !== null) {
        dest3.applyHash(digest64HashObjects(src0_3, src1_3));
      }
    }
  },
};
124 changes: 124 additions & 0 deletions packages/persistent-merkle-tree/src/hasher/hashtree.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import {hashInto} from "@chainsafe/hashtree";
import {Hasher, HashObject} from "./types";
import {HashComputation, Node} from "../node";
import {byteArrayIntoHashObject} from "@chainsafe/as-sha256/lib/hashObject";
import {doDigestNLevel, doMerkleizeInto} from "./util";

/**
 * Reusable scratch buffers sized for hashtree's widest SIMD path.
 *
 * The best SIMD implementation digests 512 bits = 64 bytes per lane at once;
 * when fewer inputs are provided, hashtree loops internally. So we batch
 * PARALLEL_FACTOR = 16 inputs of 64 bytes (two 32-byte child hashes) each.
 * NOTE(review): lane count of 16 presumably matches hashtree's AVX-512 build — confirm.
 */
const PARALLEL_FACTOR = 16;
const MAX_INPUT_SIZE = PARALLEL_FACTOR * 64;
// uint8Input/uint32Input are two views over the SAME buffer: writers below fill
// uint32 words, while hashInto consumes the byte view.
const uint8Input = new Uint8Array(MAX_INPUT_SIZE);
const uint32Input = new Uint32Array(uint8Input.buffer);
const uint8Output = new Uint8Array(PARALLEL_FACTOR * 32);
// having this will cause more memory to extract uint32
// const uint32Output = new Uint32Array(uint8Output.buffer);
// convenient reusable Uint8Array views (first 64/32 bytes of the shared buffers) for hash64
const hash64Input = uint8Input.subarray(0, 64);
const hash64Output = uint8Output.subarray(0, 32);

export const hasher: Hasher = {
  name: "hashtree",
  /** Hash two 32-byte inputs; returns a fresh copy (hash64Output is a shared scratch view). */
  digest64(obj1: Uint8Array, obj2: Uint8Array): Uint8Array {
    if (obj1.length !== 32 || obj2.length !== 32) {
      throw new Error("Invalid input length");
    }
    hash64Input.set(obj1, 0);
    hash64Input.set(obj2, 32);
    hashInto(hash64Input, hash64Output);
    // slice() so callers don't retain the reused scratch buffer
    return hash64Output.slice();
  },
  /** Hash two HashObjects into `parent` without allocating a result object. */
  digest64HashObjects(left: HashObject, right: HashObject, parent: HashObject): void {
    // writes the 16 uint32 words into uint32Input, which aliases hash64Input's bytes
    hashObjectsToUint32Array(left, right, uint32Input);
    hashInto(hash64Input, hash64Output);
    byteArrayIntoHashObject(hash64Output, 0, parent);
  },
  merkleizeInto(data: Uint8Array, padFor: number, output: Uint8Array, offset: number): void {
    return doMerkleizeInto(data, padFor, output, offset, hashInto);
  },
  digestNLevel(data: Uint8Array, nLevel: number): Uint8Array {
    return doDigestNLevel(data, nLevel, hashInto);
  },
  /**
   * Execute hash computations level by level, bottom-up (children before parents),
   * batching PARALLEL_FACTOR inputs per hashInto call to avoid per-hash allocation.
   */
  executeHashComputations(hashComputations: HashComputation[][]): void {
    for (let level = hashComputations.length - 1; level >= 0; level--) {
      const hcArr = hashComputations[level];
      if (!hcArr) {
        // should not happen
        throw Error(`no hash computations for level ${level}`);
      }

      if (hcArr.length === 0) {
        // nothing to hash
        continue;
      }

      // size input array to 2 HashObject per computation * 32 bytes per object
      // const input: Uint8Array = Uint8Array.from(new Array(hcArr.length * 2 * 32));
      // dest nodes of the batch currently staged in uint32Input, in order
      let destNodes: Node[] = [];

      // hash every 16 inputs at once to avoid memory allocation
      for (const [i, {src0, src1, dest}] of hcArr.entries()) {
        const indexInBatch = i % PARALLEL_FACTOR;
        // each computation occupies 16 uint32 words (two 8-word HashObjects)
        const offset = indexInBatch * 16;

        hashObjectToUint32Array(src0, uint32Input, offset);
        hashObjectToUint32Array(src1, uint32Input, offset + 8);
        destNodes.push(dest);
        if (indexInBatch === PARALLEL_FACTOR - 1) {
          // batch full: hash all 16 staged inputs and fan results out to dest nodes
          hashInto(uint8Input, uint8Output);
          for (const [j, destNode] of destNodes.entries()) {
            byteArrayIntoHashObject(uint8Output, j * 32, destNode);
          }
          destNodes = [];
        }
      }

      const remaining = hcArr.length % PARALLEL_FACTOR;
      // we prepared data in input, now hash the remaining
      if (remaining > 0) {
        const remainingInput = uint8Input.subarray(0, remaining * 64);
        const remainingOutput = uint8Output.subarray(0, remaining * 32);
        hashInto(remainingInput, remainingOutput);
        // destNodes was prepared above
        for (const [i, destNode] of destNodes.entries()) {
          byteArrayIntoHashObject(remainingOutput, i * 32, destNode);
        }
      }
    }
  },
};

/** Copy the 8 uint32 words of one HashObject into `arr` starting at `offset`. */
function hashObjectToUint32Array(obj: HashObject, arr: Uint32Array, offset: number): void {
  let i = offset;
  arr[i++] = obj.h0;
  arr[i++] = obj.h1;
  arr[i++] = obj.h2;
  arr[i++] = obj.h3;
  arr[i++] = obj.h4;
  arr[i++] = obj.h5;
  arr[i++] = obj.h6;
  arr[i++] = obj.h7;
}

// Stage two HashObjects in the first 16 words of `arr` (obj1 at 0, obj2 at 8).
// Writing uint32 words directly avoids the extra memory of uint32ArrayToHashObject.
function hashObjectsToUint32Array(obj1: HashObject, obj2: HashObject, arr: Uint32Array): void {
  const write = (o: HashObject, base: number): void => {
    arr[base] = o.h0;
    arr[base + 1] = o.h1;
    arr[base + 2] = o.h2;
    arr[base + 3] = o.h3;
    arr[base + 4] = o.h4;
    arr[base + 5] = o.h5;
    arr[base + 6] = o.h6;
    arr[base + 7] = o.h7;
  };
  write(obj1, 0);
  write(obj2, 8);
}
20 changes: 18 additions & 2 deletions packages/persistent-merkle-tree/src/hasher/index.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import {Hasher} from "./types";
import {hasher as nobleHasher} from "./noble";
import type {HashComputation} from "../node";

export {HashObject} from "@chainsafe/as-sha256/lib/hashObject";
export * from "./types";
export * from "./util";

/**
 * Hasher used across the SSZ codebase; by default this does not support batch hash.
 * Swap it at runtime via `setHasher` (e.g. the as-sha256 or hashtree implementations).
 */
export let hasher: Hasher = nobleHasher;

Expand All @@ -18,3 +18,19 @@ export let hasher: Hasher = nobleHasher;
/**
 * Replace the module-wide hasher. Affects every helper in this module and any
 * code that reads the mutable `hasher` export afterwards.
 */
export function setHasher(newHasher: Hasher): void {
  hasher = newHasher;
}

/** Hash two 32-byte inputs with the currently configured hasher. */
export function digest64(a: Uint8Array, b: Uint8Array): Uint8Array {
  return hasher.digest64(a, b);
}

/**
 * Digest `data` through `nLevel` rounds of pairwise hashing using the currently
 * configured hasher (see `Hasher.digestNLevel` for the exact contract).
 */
export function digestNLevel(data: Uint8Array, nLevel: number): Uint8Array {
  return hasher.digestNLevel(data, nLevel);
}

/**
 * Merkleize `data` and write the resulting root into `output` at `offset`,
 * delegating to the currently configured hasher. `padFor` is presumably the
 * leaf count to virtually pad to — confirm against `Hasher.merkleizeInto`.
 */
export function merkleizeInto(data: Uint8Array, padFor: number, output: Uint8Array, offset: number): void {
  hasher.merkleizeInto(data, padFor, output, offset);
}

/**
 * Run per-level hash computations (outer index = tree level) with the currently
 * configured hasher; batching behavior depends on the active implementation.
 */
export function executeHashComputations(hashComputations: HashComputation[][]): void {
  hasher.executeHashComputations(hashComputations);
}
44 changes: 42 additions & 2 deletions packages/persistent-merkle-tree/src/hasher/noble.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,50 @@
import {sha256} from "@noble/hashes/sha256";
import {digest64HashObjects, byteArrayIntoHashObject} from "@chainsafe/as-sha256";
import type {Hasher} from "./types";
import {hashObjectToUint8Array, uint8ArrayToHashObject} from "./util";
import {doDigestNLevel, doMerkleizeInto, hashObjectToUint8Array} from "./util";

/** sha256(a || b) via incremental hashing — no concatenation of the inputs. */
const digest64 = (a: Uint8Array, b: Uint8Array): Uint8Array => {
  const ctx = sha256.create();
  ctx.update(a);
  ctx.update(b);
  return ctx.digest();
};
/**
 * Hash each consecutive 64-byte chunk of `input` (two 32-byte halves) into the
 * corresponding 32-byte slot of `output`.
 */
const hashInto = (input: Uint8Array, output: Uint8Array): void => {
  if (input.length % 64 !== 0) {
    throw new Error(`Invalid input length ${input.length}`);
  }
  if (input.length !== output.length * 2) {
    throw new Error(`Invalid output length ${output.length}`);
  }

  let inOffset = 0;
  let outOffset = 0;
  while (inOffset < input.length) {
    const left = input.subarray(inOffset, inOffset + 32);
    const right = input.subarray(inOffset + 32, inOffset + 64);
    output.set(digest64(left, right), outOffset);
    inOffset += 64;
    outOffset += 32;
  }
};

export const hasher: Hasher = {
  name: "noble",
  digest64,
  // NOTE(fix): a stale 2-arg `digest64HashObjects` entry was removed here — an object
  // literal cannot declare the same property twice. The Hasher interface takes the
  // 3-arg "into parent" form; the 2-arg allocating variant used below comes from the
  // @chainsafe/as-sha256 import, not from this object.
  digest64HashObjects: (left, right, parent) => {
    byteArrayIntoHashObject(digest64(hashObjectToUint8Array(left), hashObjectToUint8Array(right)), 0, parent);
  },
  merkleizeInto(data: Uint8Array, padFor: number, output: Uint8Array, offset: number): void {
    return doMerkleizeInto(data, padFor, output, offset, hashInto);
  },
  digestNLevel(data: Uint8Array, nLevel: number): Uint8Array {
    return doDigestNLevel(data, nLevel, hashInto);
  },
  /**
   * noble has no batch primitive: walk levels bottom-up (children before parents)
   * and hash each pair one at a time.
   */
  executeHashComputations: (hashComputations) => {
    for (let level = hashComputations.length - 1; level >= 0; level--) {
      const hcArr = hashComputations[level];
      if (!hcArr) {
        // should not happen
        throw Error(`no hash computations for level ${level}`);
      }

      for (const hc of hcArr) {
        hc.dest.applyHash(digest64HashObjects(hc.src0, hc.src1));
      }
    }
  },
};
Loading

0 comments on commit 1578883

Please sign in to comment.