Skip to content

Commit

Permalink
Support resolving variants for deep inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
marcus-pousette committed Jun 16, 2022
1 parent ff46fc7 commit e944512
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 5 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@dao-xyz/borsh",
"version": "2.0.6",
"version": "2.0.7",
"readme": "README.md",
"homepage": "https://github.com/dao-xyz/borsh-ts#README",
"description": "Binary Object Representation Serializer for Hashing",
Expand Down
48 changes: 48 additions & 0 deletions src/__tests__/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,54 @@ describe("enum", () => {
expect((deserialied.enum as Enum1).b).toEqual(4);
});

test("extended enum", () => {
class SuperSuper {}

class Super extends SuperSuper {
constructor() {
super();
}
}

@variant(0)
class Enum0 extends Super {
@field({ type: "u8" })
public a: number;

constructor(a: number) {
super();
this.a = a;
}
}

@variant(1)
class Enum1 extends Super {
@field({ type: "u8" })
public b: number;

constructor(b: number) {
super();
this.b = b;
}
}

const instance = new Enum1(4);
// validate([Enum0, Enum1, Super, SuperSuper]);
expect(getSchema(Enum0)).toBeDefined();
expect(getSchema(Enum1)).toBeDefined();
const serialized = serialize(instance);
expect(serialized).toEqual(Buffer.from([1, 4]));

const deserialied = deserialize(
Buffer.from(serialized),
SuperSuper,
false,
BinaryReader
);
expect(deserialied).toBeInstanceOf(Enum1);
expect((deserialied as Enum1).b).toEqual(4);
});

test("wrapped enum", () => {
class Super {}

Expand Down
39 changes: 35 additions & 4 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ function deserializeStruct(clazz: any, reader: BinaryReader) {
// class this can be deserialized to)

// We know that we should serialize into the variant that accounts to the first byte of the read
for (const [_key, actualClazz] of getDependencies(clazz)) {
for (const [_key, actualClazz] of getDependenciesRecursively(clazz)) {
const variantIndex = getVariantIndex(actualClazz);
if (variantIndex !== undefined) {
if (typeof variantIndex === "number") {
Expand Down Expand Up @@ -286,9 +286,22 @@ const getOrCreateStructMeta = (clazz: any): StructKind => {
const setDependency = (ctor: Function, dependency: Function) => {
let dependencies = getDependencies(ctor);
let key = JSON.stringify(getVariantIndex(dependency));
if (dependencies.has(key)) {
if (key != undefined && dependencies.has(key)) {
throw new BorshError(`Conflicting variants: Dependency ${dependencies.get(key).name} and ${dependency.name} share same variant index(es)`)
}
if (key == undefined) {
/**
* Class is not a variant but a "bridging class" i.e
* class A {}
* class B extends A {}
*
* @variant(0)
* class C extends B {}
*
* class B has no variant even though A is a dependency on it, so it gets the key "A/B" instead
*/
key = ctor.name + "/" + dependency.name;
}
dependencies.set(key, dependency);
ctor.prototype._borsh_dependency = dependencies;
}
Expand All @@ -298,7 +311,7 @@ const hasDependencies = (ctor: Function, schema: Map<any, StructKind>): boolean
return false
}

for (const [_key, dependency] of getDependencies(ctor)) {
for (const [_key, dependency] of getDependenciesRecursively(ctor)) {
if (!schema.has(dependency)) {
return false;
}
Expand All @@ -310,6 +323,24 @@ const getDependencies = (ctor: Function): Map<string, Function> => {
return ctor.prototype._borsh_dependency ? ctor.prototype._borsh_dependency : new Map();
}

/**
* Flat map class inheritance tree into hashmap where key represents variant key
* @param ctor
* @param mem
* @returns a map of dependencies
*/
const getDependenciesRecursively = (ctor: Function, mem: Map<string, Function> = new Map()): Map<string, Function> => {
let dep = getDependencies(ctor);
for (const [key, f] of dep) {
if (mem.has(key)) {
continue;
}
mem.set(key, f);
getDependenciesRecursively(f, mem);
}
return mem
}



const setSchema = (ctor: Function, schema: StructKind) => {
Expand Down Expand Up @@ -443,7 +474,7 @@ const validateIterator = (clazzes: any[], allowUndefined: boolean, visited: Set<
});
}
// Class dependencies (inheritance)
getDependencies(clazz).forEach((dependency) => {
getDependenciesRecursively(clazz).forEach((dependency) => {
if (clazzes.find(c => c == dependency) == undefined) {
dependencies.add(dependency);
}
Expand Down

0 comments on commit e944512

Please sign in to comment.