Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve type.hashTreeRoot() using batch #409

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/as-sha256/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import {newInstance} from "./wasm";
import {HashObject, byteArrayIntoHashObject, byteArrayToHashObject, hashObjectToByteArray} from "./hashObject";
import SHA256 from "./sha256";
export {HashObject, byteArrayToHashObject, hashObjectToByteArray, byteArrayIntoHashObject, SHA256};
export {allocUnsafe};

const ctx = newInstance();
const wasmInputValue = ctx.input.value;
Expand Down
5 changes: 5 additions & 0 deletions packages/ssz/src/type/abstract.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@ export abstract class Type<V> {
*/
abstract hashTreeRoot(value: V): Uint8Array;

/**
* Same to hashTreeRoot() but here we write result to output.
*/
abstract hashTreeRootInto(value: V, output: Uint8Array, offset: number): void;

// JSON support

/** Parse JSON representation of a type to value */
Expand Down
26 changes: 17 additions & 9 deletions packages/ssz/src/type/arrayComposite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,21 +211,29 @@ export function tree_deserializeFromBytesArrayComposite<ElementType extends Comp
}
}

/**
* @param length In List length = value.length, Vector length = fixed value
*/
export function value_getRootsArrayComposite<ElementType extends CompositeType<unknown, unknown, unknown>>(
export function value_getChunkBytesArrayComposite<ElementType extends CompositeType<unknown, unknown, unknown>>(
elementType: ElementType,
length: number,
value: ValueOf<ElementType>[]
): Uint8Array[] {
const roots = new Array<Uint8Array>(length);
value: ValueOf<ElementType>[],
chunkBytesBuffer: Uint8Array
): Uint8Array {
const isOddChunk = length % 2 === 1;
const chunkBytesLen = isOddChunk ? length * 32 + 32 : length * 32;
if (chunkBytesLen > chunkBytesBuffer.length) {
throw new Error(`chunkBytesBuffer is too small: ${chunkBytesBuffer.length} < ${chunkBytesLen}`);
}
const chunkBytes = chunkBytesBuffer.subarray(0, chunkBytesLen);

for (let i = 0; i < length; i++) {
roots[i] = elementType.hashTreeRoot(value[i]);
elementType.hashTreeRootInto(value[i], chunkBytes, i * 32);
}

if (isOddChunk) {
// similar to append zeroHash(0)
chunkBytes.subarray(length * 32, chunkBytesLen).fill(0);
}

return roots;
return chunkBytes;
}

function readOffsetsArrayComposite(
Expand Down
13 changes: 10 additions & 3 deletions packages/ssz/src/type/basic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,18 @@ export abstract class BasicType<V> extends Type<V> {
}

hashTreeRoot(value: V): Uint8Array {
// TODO: Optimize
const uint8Array = new Uint8Array(32);
// cannot use allocUnsafe() here because hashTreeRootInto() may not fill the whole 32 bytes
const root = new Uint8Array(32);
this.hashTreeRootInto(value, root, 0);
return root;
}

hashTreeRootInto(value: V, output: Uint8Array, offset: number): void {
const uint8Array = output.subarray(offset, offset + 32);
// output could have preallocated data, some types may not fill the whole 32 bytes
uint8Array.fill(0);
const dataView = new DataView(uint8Array.buffer, uint8Array.byteOffset, uint8Array.byteLength);
this.value_serializeToBytes({uint8Array, dataView}, 0, value);
return uint8Array;
}

clone(value: V): V {
Expand Down
13 changes: 10 additions & 3 deletions packages/ssz/src/type/bitArray.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import {concatGindices, Gindex, Node, toGindex, Tree, HashComputationLevel} from "@chainsafe/persistent-merkle-tree";
import {fromHexString, toHexString, byteArrayEquals} from "../util/byteArray";
import {splitIntoRootChunks} from "../util/merkleize";
import {CompositeType, LENGTH_GINDEX} from "./composite";
import {BitArray} from "../value/bitArray";
import {BitArrayTreeView} from "../view/bitArray";
import {BitArrayTreeViewDU} from "../viewDU/bitArray";
import {getChunkBytes} from "./byteArray";

/* eslint-disable @typescript-eslint/member-ordering */

Expand Down Expand Up @@ -40,8 +40,15 @@ export abstract class BitArrayType extends CompositeType<BitArray, BitArrayTreeV

// Merkleization

protected getRoots(value: BitArray): Uint8Array[] {
return splitIntoRootChunks(value.uint8Array);
protected getChunkBytes(value: BitArray): Uint8Array {
// reallocate this.merkleBytes if needed
if (value.uint8Array.length > this.chunkBytesBuffer.length) {
const chunkCount = Math.ceil(value.bitLen / 8 / 32);
const chunkBytes = chunkCount * 32;
// pad 1 chunk if maxChunkCount is not even
this.chunkBytesBuffer = chunkCount % 2 === 1 ? new Uint8Array(chunkBytes + 32) : new Uint8Array(chunkBytes);
}
return getChunkBytes(value.uint8Array, this.chunkBytesBuffer);
}

// Proofs
Expand Down
30 changes: 27 additions & 3 deletions packages/ssz/src/type/bitList.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
import {getNodesAtDepth, Node, packedNodeRootsToBytes, packedRootsBytesToNode} from "@chainsafe/persistent-merkle-tree";
import {mixInLength, maxChunksToDepth} from "../util/merkleize";
import {allocUnsafe} from "@chainsafe/as-sha256";
import {
getNodesAtDepth,
merkleizeInto,
Node,
packedNodeRootsToBytes,
packedRootsBytesToNode,
} from "@chainsafe/persistent-merkle-tree";
import {maxChunksToDepth} from "../util/merkleize";
import {Require} from "../util/types";
import {namedClass} from "../util/named";
import {ByteViews} from "./composite";
Expand Down Expand Up @@ -29,6 +36,12 @@ export class BitListType extends BitArrayType {
readonly maxSize: number;
readonly maxChunkCount: number;
readonly isList = true;
readonly mixInLengthChunkBytes = new Uint8Array(64);
readonly mixInLengthBuffer = Buffer.from(
this.mixInLengthChunkBytes.buffer,
this.mixInLengthChunkBytes.byteOffset,
this.mixInLengthChunkBytes.byteLength
);

constructor(readonly limitBits: number, opts?: BitListOptions) {
super();
Expand Down Expand Up @@ -101,7 +114,18 @@ export class BitListType extends BitArrayType {
// Merkleization: inherited from BitArrayType

hashTreeRoot(value: BitArray): Uint8Array {
return mixInLength(super.hashTreeRoot(value), value.bitLen);
const root = allocUnsafe(32);
this.hashTreeRootInto(value, root, 0);
return root;
}

hashTreeRootInto(value: BitArray, output: Uint8Array, offset: number): void {
super.hashTreeRootInto(value, this.mixInLengthChunkBytes, 0);
// mixInLength
this.mixInLengthBuffer.writeUIntLE(value.bitLen, 32, 6);
// one for hashTreeRoot(value), one for length
const chunkCount = 2;
merkleizeInto(this.mixInLengthChunkBytes, chunkCount, output, offset);
}

// Proofs: inherited from BitArrayType
Expand Down
31 changes: 28 additions & 3 deletions packages/ssz/src/type/byteArray.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import {
getHashComputations,
} from "@chainsafe/persistent-merkle-tree";
import {fromHexString, toHexString, byteArrayEquals} from "../util/byteArray";
import {splitIntoRootChunks} from "../util/merkleize";
import {ByteViews} from "./abstract";
import {CompositeType, LENGTH_GINDEX} from "./composite";

Expand Down Expand Up @@ -82,10 +81,23 @@ export abstract class ByteArrayType extends CompositeType<ByteArray, ByteArray,
return Uint8Array.prototype.slice.call(data.uint8Array, start, end);
}

value_toTree(value: ByteArray): Node {
// this saves 1 allocation of Uint8Array
const dataView = new DataView(value.buffer, value.byteOffset, value.byteLength);
return this.tree_deserializeFromBytes({uint8Array: value, dataView}, 0, value.length);
}

// Merkleization

protected getRoots(value: ByteArray): Uint8Array[] {
return splitIntoRootChunks(value);
protected getChunkBytes(value: ByteArray): Uint8Array {
// reallocate this.merkleBytes if needed
if (value.length > this.chunkBytesBuffer.length) {
const chunkCount = Math.ceil(value.length / 32);
const chunkBytes = chunkCount * 32;
// pad 1 chunk if maxChunkCount is not even
this.chunkBytesBuffer = chunkCount % 2 === 1 ? new Uint8Array(chunkBytes + 32) : new Uint8Array(chunkBytes);
}
return getChunkBytes(value, this.chunkBytesBuffer);
}

// Proofs
Expand Down Expand Up @@ -149,3 +161,16 @@ export abstract class ByteArrayType extends CompositeType<ByteArray, ByteArray,

protected abstract assertValidSize(size: number): void;
}

export function getChunkBytes(data: Uint8Array, merkleBytesBuffer: Uint8Array): Uint8Array {
if (data.length > merkleBytesBuffer.length) {
throw new Error(`data length ${data.length} exceeds merkleBytesBuffer length ${merkleBytesBuffer.length}`);
}

merkleBytesBuffer.set(data);
const valueLen = data.length;
const chunkByteLen = Math.ceil(valueLen / 64) * 64;
// all padding bytes must be zero, this is similar to set zeroHash(0)
merkleBytesBuffer.subarray(valueLen, chunkByteLen).fill(0);
return merkleBytesBuffer.subarray(0, chunkByteLen);
}
31 changes: 27 additions & 4 deletions packages/ssz/src/type/byteList.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import {getNodesAtDepth, Node, packedNodeRootsToBytes, packedRootsBytesToNode} from "@chainsafe/persistent-merkle-tree";
import {mixInLength, maxChunksToDepth} from "../util/merkleize";
import {allocUnsafe} from "@chainsafe/as-sha256";
import {
getNodesAtDepth,
Node,
packedNodeRootsToBytes,
packedRootsBytesToNode,
merkleizeInto,
} from "@chainsafe/persistent-merkle-tree";
import {maxChunksToDepth} from "../util/merkleize";
import {Require} from "../util/types";
import {namedClass} from "../util/named";
import {addLengthNode, getChunksNodeFromRootNode, getLengthFromRootNode} from "./arrayBasic";
import {ByteViews} from "./composite";
import {ByteArrayType, ByteArray} from "./byteArray";

/* eslint-disable @typescript-eslint/member-ordering */

export interface ByteListOptions {
Expand Down Expand Up @@ -34,6 +40,12 @@ export class ByteListType extends ByteArrayType {
readonly maxSize: number;
readonly maxChunkCount: number;
readonly isList = true;
readonly mixInLengthChunkBytes = new Uint8Array(64);
readonly mixInLengthBuffer = Buffer.from(
this.mixInLengthChunkBytes.buffer,
this.mixInLengthChunkBytes.byteOffset,
this.mixInLengthChunkBytes.byteLength
);

constructor(readonly limitBytes: number, opts?: ByteListOptions) {
super();
Expand Down Expand Up @@ -89,7 +101,18 @@ export class ByteListType extends ByteArrayType {
// Merkleization: inherited from ByteArrayType

hashTreeRoot(value: ByteArray): Uint8Array {
return mixInLength(super.hashTreeRoot(value), value.length);
const root = allocUnsafe(32);
this.hashTreeRootInto(value, root, 0);
return root;
}

hashTreeRootInto(value: Uint8Array, output: Uint8Array, offset: number): void {
super.hashTreeRootInto(value, this.mixInLengthChunkBytes, 0);
// mixInLength
this.mixInLengthBuffer.writeUIntLE(value.length, 32, 6);
// one for hashTreeRoot(value), one for length
const chunkCount = 2;
merkleizeInto(this.mixInLengthChunkBytes, chunkCount, output, offset);
}

// Proofs: inherited from BitArrayType
Expand Down
35 changes: 30 additions & 5 deletions packages/ssz/src/type/composite.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import {allocUnsafe} from "@chainsafe/as-sha256";
import {
concatGindices,
createProof,
Expand All @@ -7,10 +8,11 @@ import {
Proof,
ProofType,
Tree,
merkleizeInto,
HashComputationLevel,
} from "@chainsafe/persistent-merkle-tree";
import {byteArrayEquals} from "../util/byteArray";
import {merkleize, symbolCachedPermanentRoot, ValueWithCachedPermanentRoot} from "../util/merkleize";
import {cacheRoot, symbolCachedPermanentRoot, ValueWithCachedPermanentRoot} from "../util/merkleize";
import {treePostProcessFromProofNode} from "../util/proof/treePostProcessFromProofNode";
import {Type, ByteViews, JsonPath, JsonPathProp} from "./abstract";
export {ByteViews};
Expand Down Expand Up @@ -59,6 +61,7 @@ export abstract class CompositeType<V, TV, TVDU> extends Type<V> {
* Required for ContainerNodeStruct to ensure no dangerous types are constructed.
*/
abstract readonly isViewMutable: boolean;
protected chunkBytesBuffer = new Uint8Array(0);

constructor(
/**
Expand Down Expand Up @@ -216,13 +219,30 @@ export abstract class CompositeType<V, TV, TVDU> extends Type<V> {
}
}

const root = merkleize(this.getRoots(value), this.maxChunkCount);
const root = allocUnsafe(32);
const safeCache = true;
this.hashTreeRootInto(value, root, 0, safeCache);

// hashTreeRootInto will cache the root if cachePermanentRootStruct is true

return root;
}

hashTreeRootInto(value: V, output: Uint8Array, offset: number, safeCache = false): void {
// Return cached mutable root if any
if (this.cachePermanentRootStruct) {
(value as ValueWithCachedPermanentRoot)[symbolCachedPermanentRoot] = root;
const cachedRoot = (value as ValueWithCachedPermanentRoot)[symbolCachedPermanentRoot];
if (cachedRoot) {
output.set(cachedRoot, offset);
return;
}
}

return root;
const merkleBytes = this.getChunkBytes(value);
merkleizeInto(merkleBytes, this.maxChunkCount, output, offset);
if (this.cachePermanentRootStruct) {
cacheRoot(value as ValueWithCachedPermanentRoot, output, offset, safeCache);
}
}

// For debugging and testing this feature
Expand All @@ -236,7 +256,12 @@ export abstract class CompositeType<V, TV, TVDU> extends Type<V> {
// and feed those numbers directly to the hasher input with a DataView
// - The return of the hasher should be customizable too, to reduce conversions from Uint8Array
// to hashObject and back.
protected abstract getRoots(value: V): Uint8Array[];

/**
* Get merkle bytes of each value, the returned Uint8Array should be multiple of 64 bytes.
* If chunk count is not even, need to append zeroHash(0)
*/
protected abstract getChunkBytes(value: V): Uint8Array;

// Proofs API

Expand Down
13 changes: 7 additions & 6 deletions packages/ssz/src/type/container.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ export class ContainerType<Fields extends Record<string, Type<unknown>>> extends
// Refactor this constructor to allow customization without pollutin the options
this.TreeView = opts?.getContainerTreeViewClass?.(this) ?? getContainerTreeViewClass(this);
this.TreeViewDU = opts?.getContainerTreeViewDUClass?.(this) ?? getContainerTreeViewDUClass(this);
const fieldBytes = this.fieldsEntries.length * 32;
const chunkBytes = Math.ceil(fieldBytes / 64) * 64;
this.chunkBytesBuffer = new Uint8Array(chunkBytes);
}

static named<Fields extends Record<string, Type<unknown>>>(
Expand Down Expand Up @@ -272,15 +275,13 @@ export class ContainerType<Fields extends Record<string, Type<unknown>>> extends

// Merkleization

protected getRoots(struct: ValueOfFields<Fields>): Uint8Array[] {
const roots = new Array<Uint8Array>(this.fieldsEntries.length);

protected getChunkBytes(struct: ValueOfFields<Fields>): Uint8Array {
for (let i = 0; i < this.fieldsEntries.length; i++) {
const {fieldName, fieldType} = this.fieldsEntries[i];
roots[i] = fieldType.hashTreeRoot(struct[fieldName]);
fieldType.hashTreeRootInto(struct[fieldName], this.chunkBytesBuffer, i * 32);
}

return roots;
// remaining bytes are zeroed as we never write them
return this.chunkBytesBuffer;
}

// Proofs
Expand Down
1 change: 0 additions & 1 deletion packages/ssz/src/type/containerNodeStruct.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ export class ContainerNodeStructType<Fields extends Record<string, Type<unknown>
return new BranchNodeStruct(this.valueToTree.bind(this), value);
}

// TODO: Optimize conversion
private valueToTree(value: ValueOfFields<Fields>): Node {
const uint8Array = new Uint8Array(this.value_serializedSize(value));
const dataView = new DataView(uint8Array.buffer, uint8Array.byteOffset, uint8Array.byteLength);
Expand Down
Loading
Loading