Skip to content

Commit

Permalink
add response shims
Browse files Browse the repository at this point in the history
  • Loading branch information
fulpm committed Sep 27, 2024
1 parent d19ff7c commit 24161b6
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 15 deletions.
1 change: 1 addition & 0 deletions codegen/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export interface LabelDefinition {
parameters: Record<string, Schema | Ref> | null;
returns?: SchemaType | Schema | Ref | null;
description?: string | null;
attribute?: boolean;
long_running?: boolean;
read_only?: boolean;
idempotent?: boolean;
Expand Down
13 changes: 9 additions & 4 deletions codegen/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export class Generator {
schemaDir: string;
clientDir: string;
ignoredLabels: string[];
responseShims: string[];

/**
* Construct a new Generator.
Expand All @@ -37,6 +38,9 @@ export class Generator {
"get_api",
];

// Temporary shims until return values are defined
this.responseShims = ["get_cases", "react", "train"];

// Setup template engine
const loader = new nunjucks.FileSystemLoader(path.join(__dirname, "templates"));
this.env = new nunjucks.Environment(loader, { throwOnUndefined: true });
Expand All @@ -56,13 +60,14 @@ export class Generator {
*/
private renderClient() {
const targetLabels: Record<string, LabelDefinition> = {};
for (const [label, value] of Object.entries(this.doc.labels)) {
if (!this.ignoredLabels.includes(label)) {
targetLabels[label] = value;
for (const [label, definition] of Object.entries(this.doc.labels)) {
if (!this.ignoredLabels.includes(label) || definition.attribute) {
targetLabels[label] = definition;
}
}
this.renderFile(this.clientDir, "trainee.ts", "client/trainee.njk", {
labels: targetLabels,
shims: this.responseShims,
});
}

Expand All @@ -87,7 +92,7 @@ export class Generator {

// Render label schemas
for (const [label, definition] of Object.entries(this.doc.labels)) {
if (this.ignoredLabels.includes(label)) continue;
if (this.ignoredLabels.includes(label) || definition.attribute) continue;
// Add schemas for label parameters and/or return value if it has any
if (definition.parameters != null || definition.returns != null) {
const title = toPascalCase(label);
Expand Down
6 changes: 4 additions & 2 deletions codegen/templates/client/trainee.njk
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
import type { Session, Trainee } from "@/types";
import type * as schemas from "@/types/schemas";
import type * as shims from "@/types/shims";
import { AbstractHowsoClient } from "./base";

export interface LabelResponse<R = unknown> {
Expand Down Expand Up @@ -52,15 +53,16 @@ export abstract class TraineeClient extends AbstractHowsoClient {
{% for label, def in labels | dictsort %}
{% set requestName = "schemas." + label | pascalCase + "Request" %}
{% set responseName = "schemas." + label | pascalCase + "Response" %}
{% set shimName = "shims." + label | pascalCase + "Response" %}
/**
* {{ def.description | capitalize | blockComment | safe | indent(2) }}
* @param traineeId The Trainee identifier.
* @param request The operation parameters.
* @returns The response of the operation, including any warnings.
*/
public async {{ label | camelCase }}(traineeId: string{% if def.parameters %}, request: {{ requestName }}{% endif %}): Promise<LabelResponse<{% if def.returns %}{{ responseName }}{% else %}any{% endif %}>> {
public async {{ label | camelCase }}(traineeId: string{% if def.parameters %}, request: {{ requestName }}{% endif %}): Promise<LabelResponse<{% if def.returns %}{{ responseName }}{% elif label in shims %}{{ shimName }}{% else %}any{% endif %}>> {
const trainee = await this.autoResolveTrainee(traineeId);
const response = await this.execute<{% if def.returns %}{{ responseName }}{% else %}any{% endif %}>(trainee.id, "{{ label }}", {% if def.parameters %}request{% else %}{}{% endif %});
const response = await this.execute<{% if def.returns %}{{ responseName }}{% elif label in shims %}{{ shimName }}{% else %}any{% endif %}>(trainee.id, "{{ label }}", {% if def.parameters %}request{% else %}{}{% endif %});
{%- if not def.read_only %}
this.autoPersistTrainee(trainee.id);
{%- endif %}
Expand Down
16 changes: 10 additions & 6 deletions src/client/trainee.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
import type { Session, Trainee } from "@/types";
import type * as schemas from "@/types/schemas";
import type * as shims from "@/types/shims";
import { AbstractHowsoClient } from "./base";

export interface LabelResponse<R = unknown> {
Expand Down Expand Up @@ -261,9 +262,12 @@ export abstract class TraineeClient extends AbstractHowsoClient {
* @param request The operation parameters.
* @returns The response of the operation, including any warnings.
*/
public async getCases(traineeId: string, request: schemas.GetCasesRequest): Promise<LabelResponse<any>> {
public async getCases(
traineeId: string,
request: schemas.GetCasesRequest,
): Promise<LabelResponse<shims.GetCasesResponse>> {
const trainee = await this.autoResolveTrainee(traineeId);
const response = await this.execute<any>(trainee.id, "get_cases", request);
const response = await this.execute<shims.GetCasesResponse>(trainee.id, "get_cases", request);
return { payload: response.payload, warnings: response.warnings };
}

Expand Down Expand Up @@ -596,9 +600,9 @@ export abstract class TraineeClient extends AbstractHowsoClient {
* @param request The operation parameters.
* @returns The response of the operation, including any warnings.
*/
public async react(traineeId: string, request: schemas.ReactRequest): Promise<LabelResponse<any>> {
public async react(traineeId: string, request: schemas.ReactRequest): Promise<LabelResponse<shims.ReactResponse>> {
const trainee = await this.autoResolveTrainee(traineeId);
const response = await this.execute<any>(trainee.id, "react", request);
const response = await this.execute<shims.ReactResponse>(trainee.id, "react", request);
this.autoPersistTrainee(trainee.id);
return { payload: response.payload, warnings: response.warnings };
}
Expand Down Expand Up @@ -986,9 +990,9 @@ export abstract class TraineeClient extends AbstractHowsoClient {
* @param request The operation parameters.
* @returns The response of the operation, including any warnings.
*/
public async train(traineeId: string, request: schemas.TrainRequest): Promise<LabelResponse<any>> {
public async train(traineeId: string, request: schemas.TrainRequest): Promise<LabelResponse<shims.TrainResponse>> {
const trainee = await this.autoResolveTrainee(traineeId);
const response = await this.execute<any>(trainee.id, "train", request);
const response = await this.execute<shims.TrainResponse>(trainee.id, "train", request);
this.autoPersistTrainee(trainee.id);
return { payload: response.payload, warnings: response.warnings };
}
Expand Down
15 changes: 12 additions & 3 deletions src/client/worker/client.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { FeatureAttributesIndex, Session, Trainee } from "@/types";
import type { FeatureAttributesIndex, Session, Trainee, TrainResponse } from "@/types";
import type * as schemas from "@/types/schemas";
import {
AmalgamError,
Expand Down Expand Up @@ -540,31 +540,40 @@ export class HowsoWorkerClient extends TraineeClient {
* Train data into the Trainee using batched requests to the Engine.
* @param traineeId The Trainee identifier.
* @param request The train parameters.
* @returns The train result.
*/
public async batchTrain(traineeId: string, request: schemas.TrainRequest): Promise<void> {
public async batchTrain(traineeId: string, request: schemas.TrainRequest): Promise<TrainResponse> {
const trainee = await this.autoResolveTrainee(traineeId);
const { cases = [], ...rest } = request;

// WASM builds are currently sensitive to large request sizes and may throw memory errors,
// so we cap it to a smaller size for now
const batchOptions: BatchOptions = { startSize: 50, limits: [1, 50] };

let num_trained = 0;
let status = null;
const ablated_indices: number[] = [];

// Batch scale the requests
await batcher(
async function* (this: HowsoWorkerClient, size: number) {
let offset = 0;
while (offset < cases.length) {
await this.train(trainee.id, {
const { payload: response } = await this.train(trainee.id, {
...rest,
cases: cases.slice(offset, offset + size),
});
offset += size;
if (response.status) status = response.status;
if (response.num_trained) num_trained += response.num_trained;
if (response.ablated_indices) ablated_indices.push(...response.ablated_indices);
size = yield;
}
}.bind(this),
batchOptions,
);

await this.autoPersistTrainee(trainee.id);
return { num_trained, status, ablated_indices };
}
}
4 changes: 4 additions & 0 deletions src/types/shims/GetCases.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export type GetCasesResponse = {
features?: string[];
cases?: any[][];
};
58 changes: 58 additions & 0 deletions src/types/shims/React.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
export type ReactResponse = {
boundary_cases?: any[][];
categorical_action_probabilities?: Array<{ [key: string]: any }>;
derivation_parameters?: Array<{
k?: number;
p?: number;
distance_transform?: number;
feature_weights?: { [key: string]: number };
feature_deviations?: { [key: string]: number };
nominal_class_counts?: { [key: string]: number };
use_irw?: boolean;
}>;
feature_residuals_full?: Array<{ [key: string]: any }>;
feature_residuals_robust?: Array<{ [key: string]: any }>;
prediction_stats?: Array<{ [key: string]: any }>;
outlying_feature_values?: Array<{
[key: string]: {
input_case_value?: number;
local_max?: number;
};
}>;
influential_cases?: any[][];
most_similar_cases?: any[][];
observational_errors?: Array<{ [key: string]: number }>;
feature_mda_full?: Array<{ [key: string]: number }>;
feature_mda_robust?: Array<{ [key: string]: number }>;
feature_mda_ex_post_full?: Array<{ [key: string]: number }>;
feature_mda_ex_post_robust?: Array<{ [key: string]: number }>;
directional_feature_contributions_full?: Array<{ [key: string]: number }>;
directional_feature_contributions_robust?: Array<{ [key: string]: number }>;
feature_contributions_full?: Array<{ [key: string]: number }>;
feature_contributions_robust?: Array<{ [key: string]: number }>;
case_directional_feature_contributions_full?: Array<{ [key: string]: number }>;
case_directional_feature_contributions_robust?: Array<{ [key: string]: number }>;
case_feature_contributions_full?: Array<{ [key: string]: number }>;
case_feature_contributions_robust?: Array<{ [key: string]: number }>;
case_mda_full?: Array<Array<{ [key: string]: any }>>;
case_mda_robust?: Array<Array<{ [key: string]: any }>>;
case_contributions_full?: Array<Array<{ [key: string]: any }>>;
case_contributions_robust?: Array<Array<{ [key: string]: any }>>;
case_feature_residuals_full?: Array<{ [key: string]: number }>;
case_feature_residuals_robust?: Array<{ [key: string]: number }>;
case_feature_residual_convictions_full?: Array<{ [key: string]: number }>;
case_feature_residual_convictions_robust?: Array<{ [key: string]: number }>;
hypothetical_values?: Array<{ [key: string]: any }>;
distance_ratio?: Array<number>;
distance_ratio_parts?: Array<{
local_distance_contribution?: number | null;
nearest_distance?: number | null;
}>;
distance_contribution?: Array<number>;
similarity_conviction?: Array<number>;
most_similar_case_indices?: Array<Array<{ [key: string]: any }>>;
generate_attempts?: Array<number>;
series_generate_attempts?: Array<Array<number>>;
action_features?: Array<string> | null;
action_values?: Array<Array<any>> | null;
};
5 changes: 5 additions & 0 deletions src/types/shims/Train.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
export type TrainResponse = {
ablated_indices?: number[];
num_trained: number;
status?: "analyze" | "analyzed" | null;
};
3 changes: 3 additions & 0 deletions src/types/shims/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { FeatureAttributes } from "../schemas";

export * from "./FeatureOriginalType";
export * from "./GetCases";
export * from "./React";
export * from "./Train";

export type FeatureAttributesIndex = { [key: string]: FeatureAttributes };

0 comments on commit 24161b6

Please sign in to comment.