diff --git a/README.md b/README.md index dc3601e8..a065252a 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![NPM version](https://img.shields.io/npm/v/@dao-xyz/borsh.svg?style=flat-square)](https://npmjs.com/@dao-xyz/borsh) [![Size on NPM](https://img.shields.io/bundlephobia/minzip/@dao-xyz/borsh.svg?style=flat-square)](https://npmjs.com/@dao-xyz/borsh) -**Borsh TS** is *unofficial* implementation of the [Borsh] binary serialization format for TypeScript projects. The motivation behind this library is to provide more convinient methods using **field and class decorators.** +**Borsh TS** is a Typescript implementation of the [Borsh] binary serialization format for TypeScript projects. The motivation behind this library is to provide more convinient methods using **field and class decorators.** Borsh stands for _Binary Object Representation Serializer for Hashing_. It is meant to be used in security-critical projects as it prioritizes consistency, safety, speed, and comes with a strict specification. @@ -123,6 +123,18 @@ class TestStruct { } ``` +Variants can be 'number', 'number[]' (represents nested Rust Enums) or 'string' (not part of the Borsh specification). i.e. + +```typescript +@variant(0) +class ClazzA +... +@variant([0,1]) +class ClazzB +... +@variant("clazz c") +class ClazzC +``` **Nested Schema generation for structs** @@ -230,7 +242,27 @@ validate([TestStruct]) ``` ## Inheritance -Schema generation with class inheritance is not supported (yet) +Schema generation is supported if deserialization is deterministic +e.g. +```typescript +class A { + @field({type: 'number'}) + a: number +} + +@variant(0) +class B1 extends A{ + @field({type: 'number'}) + b1: number +} + +@variant(1) +class B2 extends A{ + @field({type: 'number'}) + b2: number +} + +``` ## Type Mappings diff --git a/package.json b/package.json index cc7308e4..ce79f611 100644 --- a/package.json +++ b/package.json @@ -1,9 +1,9 @@ { "name": "@dao-xyz/borsh", - "version": "2.0.7", + "version": "2.1.0", "readme": "README.md", "homepage": "https://github.com/dao-xyz/borsh-ts#README", - "description": "Binary Object Representation Serializer for Hashing", + "description": "Binary Object Representation Serializer for Hashing simplified with decorators", "author": "dao.xyz", "license": "Apache-2.0", "type": "module", diff --git a/src/__tests__/index.test.ts b/src/__tests__/index.test.ts index f04129e7..8a2f1dfb 100644 --- a/src/__tests__/index.test.ts +++ b/src/__tests__/index.test.ts @@ -30,7 +30,7 @@ describe("struct", () => { } } } - validate([TestStruct]); + validate(TestStruct); const expectedResult: StructKind = new StructKind({ fields: [ { @@ -66,7 +66,7 @@ describe("struct", () => { public a: InnerStruct; } - validate([TestStruct]); + validate(TestStruct); expect(getSchema(TestStruct)).toEqual( new StructKind({ fields: [{ key: "a", type: InnerStruct }], @@ -91,7 +91,8 @@ describe("struct", () => { public c: number; } - let schema = validate([TestStruct]).get(TestStruct); + validate(TestStruct); + let schema = getSchema(TestStruct); expect(schema.fields.length).toEqual(2); expect(schema.fields[0].key).toEqual("a"); expect(schema.fields[1].key).toEqual("c"); @@ -114,7 +115,7 @@ describe("bool", () => { } } } - validate([TestStruct]); + validate(TestStruct); const expectedResult: StructKind = new StructKind({ fields: [ { @@ -153,7 +154,7 @@ describe("arrays", () => { } } - validate([TestStruct]); + validate(TestStruct); const buf = serialize(new TestStruct({ a: [1, 2, 3] })); expect(buf).toEqual(Buffer.from([3, 0, 0, 0, 1, 2, 3])); const deserialized = deserialize(Buffer.from(buf), TestStruct); @@ -172,7 +173,7 @@ describe("arrays", () => { } } - validate([TestStruct]); + validate(TestStruct); const buf = serialize(new TestStruct({ a: [1, 2, 3] })); expect(buf).toEqual(Buffer.from([1, 2, 3])); const deserialized = deserialize(Buffer.from(buf), TestStruct); @@ -191,7 +192,7 @@ describe("arrays", () => { } } - validate([TestStruct]); + validate(TestStruct); expect(() => serialize(new TestStruct({ a: [1, 2] }))).toThrowError(); }); @@ -206,7 +207,7 @@ describe("arrays", () => { } } } - validate([TestStruct]); + validate(TestStruct); expect(() => deserialize(Buffer.from([1, 2]), TestStruct)).toThrowError(); }); @@ -233,7 +234,7 @@ describe("arrays", () => { } } - validate([TestStruct]); + validate(TestStruct); const arr = [ new Element({ a: 1 }), new Element({ a: 2 }), @@ -258,7 +259,7 @@ describe("enum", () => { } } const instance = new TestEnum(3); - validate([TestEnum]); + validate(TestEnum); const buf = serialize(instance); expect(buf).toEqual(Buffer.from([1, 3])); const deserialized = deserialize(Buffer.from(buf), TestEnum); @@ -269,7 +270,7 @@ describe("enum", () => { @variant(1) class TestEnum {} const instance = new TestEnum(); - validate([TestEnum]); + validate(TestEnum); const buf = serialize(instance); expect(buf).toEqual(Buffer.from([1])); }); @@ -291,7 +292,7 @@ describe("enum", () => { this.variant = variant; } } - validate([TestStruct]); + validate(TestStruct); expect(getSchema(TestStruct)).toBeDefined(); expect(getSchema(ImplementationByVariant)).toBeDefined(); }); @@ -330,7 +331,7 @@ describe("enum", () => { } } const instance = new TestStruct(new Enum1(4)); - validate([Enum0, Enum1, TestStruct]); + validate(Super); expect(getSchema(Enum0)).toBeDefined(); expect(getSchema(Enum1)).toBeDefined(); expect(getSchema(TestStruct)).toBeDefined(); @@ -347,7 +348,7 @@ describe("enum", () => { expect((deserialied.enum as Enum1).b).toEqual(4); }); - test("extended enum", () => { + test("extended enum top variants", () => { class SuperSuper {} class Super extends SuperSuper { @@ -395,6 +396,101 @@ describe("enum", () => { expect((deserialied as Enum1).b).toEqual(4); }); + test("extended enum inheritance variants", () => { + @variant(1) + class SuperSuper {} + + @variant(2) + class Super extends SuperSuper { + constructor() { + super(); + } + } + + @variant([3, 100]) + class Enum0 extends Super { + @field({ type: "u8" }) + public a: number; + + constructor(a: number) { + super(); + this.a = a; + } + } + + @variant([3, 4]) + class Enum1 extends Super { + @field({ type: "u8" }) + public b: number; + + constructor(b: number) { + super(); + this.b = b; + } + } + + const instance = new Enum1(5); + // validate([Enum0, Enum1, Super, SuperSuper]); + expect(getSchema(Enum0)).toBeDefined(); + expect(getSchema(Enum1)).toBeDefined(); + const serialized = serialize(instance); + expect(serialized).toEqual(Buffer.from([1, 2, 3, 4, 5])); + + const deserialied = deserialize( + Buffer.from(serialized), + SuperSuper, + false, + BinaryReader + ); + expect(deserialied).toBeInstanceOf(Enum1); + expect((deserialied as Enum1).b).toEqual(5); + }); + + test("inheritance without variant", () => { + class Super {} + class A extends Super { + @field({ type: "u8" }) + public a: number; + } + class B extends A { + @field({ type: "u8" }) + public b: number; + + constructor(opts?: { a: number; b: number }) { + super(); + if (opts) { + Object.assign(this, opts); + } + } + } + @variant(0) + class C1 extends B { + constructor(opts?: { a: number; b: number }) { + super(); + if (opts) { + Object.assign(this, opts); + } + } + } + @variant(1) + class C2 extends B {} + + validate(Super); + + const serialized = serialize(new C1({ a: 1, b: 2 })); + expect(serialized).toEqual(Buffer.from([1, 2, 0])); + + const deserialied = deserialize( + Buffer.from(serialized), + Super, + false, + BinaryReader + ); + expect(deserialied).toBeInstanceOf(C1); + expect((deserialied as C1).a).toEqual(1); + expect((deserialied as C1).b).toEqual(2); + }); + test("wrapped enum", () => { class Super {} @@ -418,7 +514,7 @@ describe("enum", () => { } } const instance = new TestStruct(new Enum2(3)); - validate([Enum2, TestStruct]); + validate(Super); expect(getSchema(Enum2)).toBeDefined(); expect(getSchema(TestStruct)).toBeDefined(); const serialized = serialize(instance); @@ -467,7 +563,7 @@ describe("enum", () => { } } const instance = new TestStruct(new Enum1(5)); - validate([Enum0, Enum1, TestStruct]); + validate(Super); expect(getSchema(Enum1)).toBeDefined(); expect(getSchema(TestStruct)).toBeDefined(); const serialized = serialize(instance); @@ -481,6 +577,43 @@ describe("enum", () => { expect(deserialied.enum).toBeInstanceOf(Enum1); expect((deserialied.enum as Enum0).a).toEqual(5); }); + + test("enum string variant", () => { + class Ape { + @field({ type: "String" }) + name: string; + + constructor(name?: string) { + this.name = name; + } + } + + @variant("🦍") + class Gorilla extends Ape {} + + @variant("🦧") + class Orangutan extends Ape {} + + class HighCouncil { + @field({ type: vec(Ape) }) + members: Ape[]; + constructor(members?: Ape[]) { + if (members) { + this.members = members; + } + } + } + + let bytes = serialize( + new HighCouncil([new Gorilla("Go"), new Orangutan("Ora")]) + ); + let deserialized = deserialize(Buffer.from(bytes), HighCouncil); + expect(deserialized).toBeInstanceOf(HighCouncil); + expect(deserialized.members[0]).toBeInstanceOf(Gorilla); + expect(deserialized.members[0].name).toEqual("Go"); + expect(deserialized.members[1]).toBeInstanceOf(Orangutan); + expect(deserialized.members[1].name).toEqual("Ora"); + }); }); describe("option", () => { @@ -492,7 +625,7 @@ describe("option", () => { this.a = a; } } - validate([TestStruct]); + validate(TestStruct); const expectedResult: StructKind = new StructKind({ fields: [ { @@ -528,7 +661,7 @@ describe("option", () => { this.a = a; } } - validate([TestStruct]); + validate(TestStruct); const expectedResult: StructKind = new StructKind({ fields: [ { @@ -578,7 +711,7 @@ describe("override", () => { } } - validate([TestStruct]); + validate(TestStruct); const serialized = serialize(new TestStruct({ a: 2, b: 3 })); const deserialied = deserialize( Buffer.from(serialized), @@ -606,7 +739,7 @@ describe("order", () => { this.b = b; } } - validate([TestStruct]); + validate(TestStruct); const expectedResult: StructKind = new StructKind({ fields: [ { @@ -639,7 +772,7 @@ describe("order", () => { public a: number; } const thrower = (): void => { - validate([TestStruct]); + validate(TestStruct); }; // Error is thrown since 1 field with index 1 is undefined behaviour @@ -655,7 +788,7 @@ describe("order", () => { public b: number; } const thrower = (): void => { - validate([TestStruct]); + validate(TestStruct); }; // Error is thrown since missing field with index 1 @@ -671,7 +804,9 @@ describe("order", () => { @field({ type: "u8" }) public b: number; } - const schema: StructKind = validate([TestStruct]).get(TestStruct); + validate(TestStruct); + const schema = getSchema(TestStruct); + const expectedResult: StructKind = new StructKind({ fields: [ { @@ -700,7 +835,7 @@ describe("Validation", () => { } const bytes = Uint8Array.from([1, 0]); // has an extra 0 - validate([TestStruct]); + validate(TestStruct); expect(() => deserialize(Buffer.from(bytes), TestStruct, false) ).toThrowError(BorshError); @@ -729,6 +864,102 @@ describe("Validation", () => { expect(() => classDef()).toThrowError(BorshError); }); + test("variant type conflict", () => { + class Super { + constructor() {} + } + @variant([0, 0]) // Same as B + class A extends Super { + constructor() { + super(); + } + } + + @variant(0) // Same as A + class B extends Super { + constructor() { + super(); + } + } + expect(() => validate(Super)).toThrowError(BorshError); + }); + + test("variant type conflict inheritance", () => { + class SuperSuper {} + + class Super extends SuperSuper {} + + @variant([0, 0]) // Same as B + class A extends Super { + constructor() { + super(); + } + } + + @variant(0) // Same as A + class B extends SuperSuper { + constructor() { + super(); + } + } + expect(() => validate(SuperSuper)).toThrowError(BorshError); + }); + + test("variant type conflict array length", () => { + class Super {} + + @variant([0, 0]) // Same as B + class A extends Super { + constructor() { + super(); + } + } + + @variant([0]) // Same as A + class B extends Super { + constructor() { + super(); + } + } + expect(() => validate(Super)).toThrowError(BorshError); + }); + + test("error for non optimized code", () => { + class TestStruct { + constructor() {} + } + + class A extends TestStruct { + @field({ type: "String" }) + string: string; + } + + class B extends TestStruct { + @field({ type: "String" }) + string: string; + } + expect(() => validate(TestStruct)).toThrowError(BorshError); + }); + + test("error for non optimized code on deserialization", () => { + class TestStruct { + constructor() {} + } + + class A extends TestStruct { + @field({ type: "String" }) + string: string = "A"; + } + + class B extends TestStruct { + @field({ type: "String" }) + string: string = "B"; + } + expect(() => + deserialize(Buffer.from(serialize(new A())), TestStruct) + ).toThrowError(BorshError); + }); + test("variant conflict, indices", () => { const classDef = () => { class TestStruct { @@ -767,62 +998,8 @@ describe("Validation", () => { this.missing = missing; } } - expect(() => validate([TestStruct])).toThrowError(BorshError); - validate([TestStruct], true); // Should be ok since we allow undefined - }); - - test("missing variant", () => { - class Super {} - - @variant(0) - class Enum0 extends Super { - constructor() { - super(); - } - } - class TestStruct { - @field({ type: Super }) - public missing: Super; - - constructor(missing?: Super) { - this.missing = missing; - } - } - validate([TestStruct]); - expect(getSchema(Enum0)).toBeDefined(); - expect(getSchema(TestStruct)).toBeDefined(); - expect(getSchema(Super)).toBeUndefined(); - }); - - test("missing variant one off", () => { - class Super {} - @variant(0) - class Enum0 extends Super { - constructor() { - super(); - } - } - - @variant(1) - class Enum1 extends Super { - constructor() { - super(); - } - } - class TestStruct { - @field({ type: Super }) - public missing: Super; - - constructor(missing?: Super) { - this.missing = missing; - } - } - - validate([TestStruct]); - expect(getSchema(Enum0)).toBeDefined(); - expect(getSchema(Enum1)).toBeDefined(); - expect(getSchema(TestStruct)).toBeDefined(); - expect(getSchema(Super)).toBeUndefined(); + expect(() => validate(TestStruct)).toThrowError(BorshError); + validate(TestStruct, true); // Should be ok since we allow undefined }); test("valid dependency", () => { @@ -842,6 +1019,6 @@ describe("Validation", () => { this.missing = missing; } } - validate([TestStruct]); + validate(TestStruct); }); }); diff --git a/src/index.ts b/src/index.ts index 4e0f20b9..bb1d16f1 100644 --- a/src/index.ts +++ b/src/index.ts @@ -8,6 +8,7 @@ import { SimpleField, CustomField, extendingClasses, + Constructor, } from "./types"; import { BorshError } from "./error"; import { BinaryWriter, BinaryReader } from "./binary"; @@ -87,18 +88,39 @@ export function serializeStruct( return; } - const structSchema = getSchema(obj.constructor) - if (!structSchema) { + + // Serialize content as struct, we do not invoke serializeStruct since it will cause circular calls to this method + const structSchemas = getSchemasBottomUp(obj.constructor); + + // If Schema has fields, "structSchema" will be non empty and "fields" will exist + if (structSchemas.length == 0) { throw new BorshError(`Class ${obj.constructor.name} is missing in schema`); } - if (structSchema instanceof StructKind) { - structSchema.fields.map((field) => { - serializeField(field.key, obj[field.key], field.type, writer); - }); - } else { - throw new BorshError(`Unexpected schema for ${obj.constructor.name}`); - } + structSchemas.forEach((v) => { + if (v.schema instanceof StructKind) { + const index = v.schema.variant; + if (index != undefined) { + if (typeof index === "number") { + writer.writeU8(index); + } else if (Array.isArray(index)) { + index.forEach((i) => { + writer.writeU8(i); + }); + } + else { // is string + writer.writeString(index); + } + } + + v.schema.fields.map((field) => { + serializeField(field.key, obj[field.key], field.type, writer); + }); + } else { + throw new BorshError(`Unexpected schema for ${obj.constructor.name}`); + } + }) + } /// Serialize given object using schema of the form: @@ -164,73 +186,119 @@ function deserializeStruct(clazz: any, reader: BinaryReader) { return clazz.borshDeserialize(reader); } - let structSchema = getSchema(clazz);//schema.get(clazz); - let idx = undefined; + const result: { [key: string]: any } = {}; + + // assume clazz is super class + if (getVariantIndex(clazz) !== undefined) { + // It is an (stupid) enum, but we deserialize into its variant directly + // This means we should omit the variant index + let index = getVariantIndex(clazz); + if (typeof index === "number") { + reader.readU8(); + } else if (Array.isArray(index)) { + for (const _ of index) { + reader.readU8(); + } + } + else { // string + reader.readString(); + } + } + - if (!structSchema) { - // We find the deserialization schema from one of the subclasses + // Polymorphic serialization, i.e. reversed prototype iteration using descriminators + let once = false; + let currClazz = clazz; + while ((getSchema(currClazz) || getDependencies(currClazz).size > 0)) { - // it must be an enum - idx = [reader.readU8()]; + let structSchema = getSchema(currClazz); - // Try polymorphic deserialziation (i.e. get all subclasses and find best - // class this can be deserialized to) + once = true; + let variantsIndex: number[] = undefined; + let variantString: string = undefined; + let nextClazz = undefined; + let dependencies = getNonTrivialDependencies(currClazz); + if (structSchema) { + for (const field of structSchema.fields) { + result[field.key] = deserializeField( + field.key, + field.type, + reader + ); + } + } // We know that we should serialize into the variant that accounts to the first byte of the read - for (const [_key, actualClazz] of getDependenciesRecursively(clazz)) { + for (const [_key, actualClazz] of dependencies) { const variantIndex = getVariantIndex(actualClazz); if (variantIndex !== undefined) { if (typeof variantIndex === "number") { - if (variantIndex == idx[0]) { - clazz = actualClazz; - structSchema = getSchema(clazz); + + if (!variantsIndex) { + variantsIndex = [reader.readU8()]; + } + if (variantIndex == variantsIndex[0]) { + nextClazz = actualClazz; break; } - } // variant is array, check all values - else { - while (idx.length < variantIndex.length) { - idx.push(reader.readU8()); + } + else if (Array.isArray(variantIndex)) { // variant is array, check all values + + if (!variantsIndex) { + variantsIndex = []; + while (variantsIndex.length < variantIndex.length) { + variantsIndex.push(reader.readU8()); + } } + // Compare variants if ( - idx.length === variantIndex.length && - idx.every((value, index) => value === variantIndex[index]) + variantsIndex.length === variantIndex.length && + (variantsIndex as number[]).every((value, index) => value === variantIndex[index]) ) { - clazz = actualClazz; - structSchema = getSchema(clazz); + nextClazz = actualClazz; break; } } + else { // is string + if (variantString == undefined) { + variantString = reader.readString(); + } + // Compare variants is just string compare + if ( + variantString === variantIndex + ) { + nextClazz = actualClazz; + break; + } + } } } - if (!structSchema) - throw new BorshError(`Class ${clazz.name} is missing in schema`); - } else if (getVariantIndex(clazz) !== undefined) { - // It is an enum, but we deserialize into its variant directly - // This means we should omit the variant index - let index = getVariantIndex(clazz); - if (typeof index === "number") { - reader.readU8(); - } else { - for (const _ of index) { - reader.readU8(); + if (nextClazz == undefined) { + // do a recursive call and copy result, + // this is not computationally performant since we are going to traverse multiple path + // and possible do deserialziation on bad paths + if (dependencies.size == 1) // still deterministic + nextClazz = dependencies.values().next().value; + else if (dependencies.size > 1) { + const classes = [...dependencies.values()].map((f) => f.name).join(', ') + throw new BorshError(`Multiple ambigious deserialization paths from ${currClazz.name} found: ${classes}. This is not allowed, and would not be performant if allowed`) } } - } - if (structSchema instanceof StructKind) { - const result: { [key: string]: any } = {}; - for (const field of getSchema(clazz).fields) { - result[field.key] = deserializeField( - field.key, - field.type, - reader - ); + if (nextClazz == undefined) { + break; } - return Object.assign(new clazz(), result); + currClazz = nextClazz; + /* if (!structSchema) + throw new BorshError(`Class ${clazz.name} is missing in schema`); */ + } + if (!once) { + throw new BorshError(`Unexpected schema ${clazz.constructor.name}`); } - throw new BorshError(`Unexpected schema ${clazz.constructor.name}`); + return Object.assign(new currClazz(), result); + } /** @@ -282,11 +350,20 @@ const getOrCreateStructMeta = (clazz: any): StructKind => { schema } */ } +const setDependencyToProtoType = (ctor: Function) => { + let proto = Object.getPrototypeOf(ctor); + if (proto.prototype?.constructor != undefined) + setDependency(proto, ctor); +} const setDependency = (ctor: Function, dependency: Function) => { let dependencies = getDependencies(ctor); let key = JSON.stringify(getVariantIndex(dependency)); if (key != undefined && dependencies.has(key)) { + if (dependencies.get(key) == dependency) { + // already added; + return; + } throw new BorshError(`Conflicting variants: Dependency ${dependencies.get(key).name} and ${dependency.name} share same variant index(es)`) } if (key == undefined) { @@ -303,7 +380,7 @@ const setDependency = (ctor: Function, dependency: Function) => { key = ctor.name + "/" + dependency.name; } dependencies.set(key, dependency); - ctor.prototype._borsh_dependency = dependencies; + setDependencies(ctor, dependencies); } const hasDependencies = (ctor: Function, schema: Map): boolean => { @@ -319,10 +396,40 @@ const hasDependencies = (ctor: Function, schema: Map): boolean return true; } +const getDependencyKey = (ctor: Function) => "_borsh_dependency_" + ctor.name + const getDependencies = (ctor: Function): Map => { - return ctor.prototype._borsh_dependency ? ctor.prototype._borsh_dependency : new Map(); + let existing = ctor.prototype.constructor[getDependencyKey(ctor)] + if (existing) + return existing; + return new Map(); } +const getNonTrivialDependencies = (ctor: Function): Map => { + let ret = new Map(); + let existing = ctor.prototype.constructor[getDependencyKey(ctor)] as Map; + if (existing) + existing.forEach((v, k) => { + let schema = getSchema(v); + if (schema.fields.length > 0 || schema.variant != undefined) { // non trivial + ret.set(k, v); + } + else { // check recursively + let req = getNonTrivialDependencies(v); + req.forEach((rv, rk) => { + ret.set(rk, rv); + }) + } + + }); + return ret; +} + +const setDependencies = (ctor: Function, dependencies: Map): Map => { + return ctor.prototype.constructor[getDependencyKey(ctor)] = dependencies +} + + /** * Flat map class inheritance tree into hashmap where key represents variant key * @param ctor @@ -344,51 +451,46 @@ const getDependenciesRecursively = (ctor: Function, mem: Map = const setSchema = (ctor: Function, schema: StructKind) => { - ctor.prototype._borsh_schema = schema; + + ctor.prototype.constructor["_borsh_schema_" + ctor.name] = schema } export const getSchema = (ctor: Function): StructKind => { - return ctor.prototype._borsh_schema + if (ctor.prototype == undefined) { + const t = 123; + } + return ctor.prototype.constructor["_borsh_schema_" + ctor.name]; } +export const getSchemasBottomUp = (ctor: Function): { clazz: Function, schema: StructKind }[] => { + let schemas: { clazz: Function, schema: StructKind }[] = []; + while (ctor.prototype != undefined) { + let schema = getSchema(ctor); + if (schema) + schemas.push({ + clazz: ctor, + schema + }); + ctor = Object.getPrototypeOf(ctor); + } + return schemas.reverse(); + +} + + + /** * * @param kind 'struct' or 'variant. 'variant' equivalnt to Rust Enum * @returns Schema decorator function for classes */ -export const variant = (index: number | number[]) => { +export const variant = (index: number | number[] | string) => { return (ctor: Function) => { - getOrCreateStructMeta(ctor); + let schema = getOrCreateStructMeta(ctor); // Create a custom serialization, for enum by prepend instruction index - ctor.prototype.borshSerialize = function ( - writer: BinaryWriter - ) { - if (typeof index === "number") { - writer.writeU8(index); - } else { - index.forEach((i) => { - writer.writeU8(i); - }); - } + schema.variant = index; - // Serialize content as struct, we do not invoke serializeStruct since it will cause circular calls to this method - const structSchema: StructKind = getSchema(ctor); - - // If Schema has fields, "structSchema" will be non empty and "fields" will exist - if (structSchema?.fields) - for (const field of structSchema.fields) { - serializeField( - field.key, - this[field.key], - field.type, - writer - ); - } - }; - ctor.prototype._borsh_variant_index = function () { - return index; // creates a function that returns the variant index on the class - }; // Define Schema for this class, even though it might miss fields since this is a variant const clazzes = extendingClasses(ctor); let prev = ctor; @@ -401,10 +503,8 @@ export const variant = (index: number | number[]) => { }; }; -export const getVariantIndex = (clazz: any): number | number[] | undefined => { - if (clazz.prototype._borsh_variant_index) - return clazz.prototype._borsh_variant_index(); - return undefined; +export const getVariantIndex = (clazz: any): number | number[] | string | undefined => { + return getOrCreateStructMeta(clazz).variant; }; /** @@ -413,7 +513,7 @@ export const getVariantIndex = (clazz: any): number | number[] | undefined => { */ export function field(properties: SimpleField | CustomField) { return (target: {} | any, name?: PropertyKey): any => { - + setDependencyToProtoType(target.constructor); const schema = getOrCreateStructMeta(target.constructor); const key = name.toString(); @@ -454,74 +554,70 @@ export function field(properties: SimpleField | CustomField) { * @param validate, run validation? * @returns Schema map */ -export const validate = (clazzes: any[], allowUndefined = false) => { +export const validate = (clazzes: Constructor | Constructor[], allowUndefined = false) => { return validateIterator(clazzes, allowUndefined, new Set()); }; -const validateIterator = (clazzes: any[], allowUndefined: boolean, visited: Set) => { +const validateIterator = (clazzes: Constructor | Constructor[], allowUndefined: boolean, visited: Set) => { + clazzes = Array.isArray(clazzes) ? clazzes : [clazzes]; let schemas = new Map(); - let dependencies = new Set(); - clazzes.forEach((clazz) => { - visited.add(clazz.name); - const schema = getSchema(clazz); - if (schema) { - - schemas.set(clazz, schema); - - // By field - schema.getDependencies().forEach((depenency) => { - dependencies.add(depenency); - }); + clazzes.forEach((clazz, ix) => { + while (Object.getPrototypeOf(clazz).prototype != undefined) { + clazz = Object.getPrototypeOf(clazz); } - // Class dependencies (inheritance) - getDependenciesRecursively(clazz).forEach((dependency) => { - if (clazzes.find(c => c == dependency) == undefined) { - dependencies.add(dependency); + let dependencies = getDependenciesRecursively(clazz); + dependencies.set('_', clazz); + dependencies.forEach((v, k) => { + const schema = getSchema(v); + if (!schema) { + return; } - }) + schemas.set(v, schema); + visited.add(v.name); - }); - let filteredDependencies: Function[] = []; - dependencies.forEach((dependency) => { - if (visited.has(dependency.name)) { - return; - } - filteredDependencies.push(dependency); - visited.add(dependency.name); - }) + }); + let lastVariant: number | number[] | string = undefined; + let lastKey: string = undefined; + getNonTrivialDependencies(clazz).forEach((dependency, key) => { + if (!lastVariant) + lastVariant = getVariantIndex(dependency); + else if (!validateVariantAreCompatible(lastVariant, getVariantIndex(dependency))) { + throw new BorshError(`Class ${dependency.name} is extended by classes with variants of different types. Expecting only one of number, number[] or string`) + } - // Generate schemas for nested types - filteredDependencies.forEach((dependency) => { - if (!schemas.has(dependency)) { - const dependencySchema = validateIterator([dependency], allowUndefined, visited); - dependencySchema.forEach((value, key) => { - schemas.set(key, value); - }); - } - }); - schemas.forEach((structSchema, clazz) => { - if (!structSchema.fields && !hasDependencies(clazz, schemas)) { - throw new BorshError("Missing schema for class " + clazz.name); - } - structSchema.fields.forEach((field) => { - if (!field) { - throw new BorshError( - "Field is missing definition, most likely due to field indexing with missing indices" - ); + if (lastKey != undefined && lastVariant == undefined) { + throw new BorshError(`Classes inherit ${clazz} and are introducing new field without introducing variants. This leads to unoptimized deserialization`) } - if (allowUndefined) { - return; + lastKey = key; + }) + + schemas.forEach((structSchema, clazz) => { + if (!structSchema.fields && !hasDependencies(clazz, schemas)) { + throw new BorshError("Missing schema for class " + clazz.name); } - if (field.type instanceof Function) { - if (!schemas.has(field.type) && !hasDependencies(field.type, schemas)) { - throw new BorshError("Unknown field type: " + field.type.name); + structSchema.fields.forEach((field) => { + if (!field) { + throw new BorshError( + "Field is missing definition, most likely due to field indexing with missing indices" + ); } - } - }); - }) - return schemas; + if (allowUndefined) { + return; + } + if (field.type instanceof Function) { + if (!getSchema(field.type) && !hasDependencies(field.type, schemas)) { + throw new BorshError("Unknown field type: " + field.type.name); + } + + // Validate field + validateIterator(field.type, allowUndefined, visited); + } + }); + }) + }); + } @@ -530,3 +626,15 @@ const resize = (arr: Array, newSize: number, defaultValue: any) => { while (newSize > arr.length) arr.push(defaultValue); arr.length = newSize; }; + +const validateVariantAreCompatible = (a: number | number[] | string, b: number | number[] | string) => { + if (typeof a != typeof b) { + return false; + } + if (Array.isArray(a) && Array.isArray(b)) { + if (a.length != b.length) { + return false; + } + } + return true; +} \ No newline at end of file diff --git a/src/types.ts b/src/types.ts index 8a952e12..e6fa3fb0 100644 --- a/src/types.ts +++ b/src/types.ts @@ -89,10 +89,12 @@ export interface Field { } export class StructKind { + variant?: number | number[] | string fields: Field[]; - constructor(properties?: { fields: Field[] }) { + constructor(properties?: { variant?: number | number[] | string, fields: Field[] }) { if (properties) { this.fields = properties.fields; + this.variant = properties.variant; } else { this.fields = []; }