Skip to content

Commit

Permalink
migrate trainee operations to Trainee class
Browse files Browse the repository at this point in the history
  • Loading branch information
fulpm committed Oct 7, 2024
1 parent 63cda70 commit eb02a40
Show file tree
Hide file tree
Showing 8 changed files with 467 additions and 1,406 deletions.
1 change: 1 addition & 0 deletions codegen/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export interface LabelDefinition {
parameters: Record<string, Schema | Ref> | null;
returns?: SchemaType | Schema | Ref | null;
description?: string | null;
use_active_session?: boolean;
attribute?: boolean;
long_running?: boolean;
read_only?: boolean;
Expand Down
17 changes: 13 additions & 4 deletions codegen/generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,19 @@ export class Generator {
react_into_features: "null",
};

// TODO - Remove once #21784 is merged
for (const label of [
"train",
"impute",
"clear_imputed_data",
"edit_cases",
"move_cases",
"add_feature",
"remove_feature",
]) {
this.doc.labels[label].use_active_session = true;
}

// Setup template engine
const loader = new nunjucks.FileSystemLoader(path.join(__dirname, "templates"));
this.env = new nunjucks.Environment(loader, { throwOnUndefined: true });
Expand All @@ -73,10 +86,6 @@ export class Generator {
targetLabels[label] = definition;
}
}
this.renderFile(this.clientDir, "AbstractBaseClient.ts", "client/AbstractBaseClient.njk", {
labels: targetLabels,
shims: this.responseShims,
});
this.renderFile(this.engineDir, "Trainee.ts", "engine/Trainee.njk", {
labels: targetLabels,
shims: this.responseShims,
Expand Down
187 changes: 0 additions & 187 deletions codegen/templates/client/AbstractBaseClient.njk

This file was deleted.

31 changes: 27 additions & 4 deletions codegen/templates/engine/Trainee.njk
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
*/
import { AbstractBaseClient } from "../client/AbstractBaseClient";
import { batcher, BatchOptions } from "../client/utilities";
import type { BaseTrainee, ClientBatchResponse } from "../types";
import type { BaseTrainee, ClientBatchResponse, ClientResponse } from "../types";
import type * as schemas from "../types/schemas";
import type * as shims from "../types/shims";

/**
* The interface for interacting with a Trainee. Should not be instantiated directly. Instead create or request a
Expand Down Expand Up @@ -93,7 +94,7 @@ export class Trainee implements BaseTrainee {
async function* (this: Trainee, size: number) {
let offset = 0;
while (offset < cases.length) {
const response = await this.client.train(this.id, {
const response = await this.train({
...rest,
cases: cases.slice(offset, offset + size),
});
Expand All @@ -115,6 +116,20 @@ export class Trainee implements BaseTrainee {
return { payload: { num_trained, status, ablated_indices }, warnings };
}

/**
* Include the active session in a request if not defined.
* @param request The Trainee request object.
* @returns The Trainee request object with a session.
*/
protected async includeSession<T extends Record<string, any>>(request: T): Promise<T> {
if (!request.session) {
// Include the active session
const session = await this.client.getActiveSession();
return { ...request, session: session.id };
}
return request;
}

{% for label, def in labels | dictsort %}
{%- set requestName = "schemas." + label | pascalCase + "Request" %}
{%- if def.returns | isRef %}
Expand All @@ -128,8 +143,16 @@ export class Trainee implements BaseTrainee {
* @param request The operation parameters.
* @returns The response of the operation, including any warnings.
*/
public async {{ label | camelCase }}({% if def.parameters | length %}request: {{ requestName }}{% endif %}) {
return this.client.{{ label | camelCase }}(this.id, {% if def.parameters | length %}request{% endif %});
public async {{ label | camelCase }}({% if def.parameters | length %}request: {{ requestName }}{% endif %}): Promise<ClientResponse<{% if def.returns %}{{ responseName }}{% elif label in shims %}{{ shims[label] }}{% else %}any{% endif %}>> {
await this.client.autoResolveTrainee(this.id);
{%- if def.use_active_session and def.parameters %}
request = await this.includeSession(request);
{%- endif %}
const response = await this.client.execute<{% if def.returns %}{{ responseName }}{% elif label in shims %}{{ shims[label] }}{% else %}any{% endif %}>(this.id, "{{ label }}", {% if def.parameters | length %}request{% else %}{}{% endif %});
{%- if not def.read_only %}
this.client.autoPersistTrainee(this.id);
{%- endif %}
return { payload: response.payload, warnings: response.warnings };
}
{%- if not loop.last %}
{% endif %}
Expand Down
Loading

0 comments on commit eb02a40

Please sign in to comment.