Skip to content

Commit

Permalink
Ts/3630 interventions for calibrate (#3677)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom-Szendrey <Tom-Szendrey@users.noreply.github.com>
  • Loading branch information
Tom-Szendrey and Tom-Szendrey authored May 28, 2024
1 parent f3d0b14 commit d1b27c6
Show file tree
Hide file tree
Showing 12 changed files with 96 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ watch(
console.log('dill URL is', dillURL);
const forecastResponse = await makeForecastJobCiemss({
projectId: '',
modelConfigId: modelConfigId.value as string,
timespan: {
start: 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,16 @@ export async function getOptimizedInterventions(optimizeRunId: string) {
// Get the interventionPolicyGroups from the simulation object.
// This will prevent any inconsistencies being passed via knobs or state when matching with result file.
const simulation = await getSimulation(optimizeRunId);
const interventions = simulation?.executionPayload?.interventions;
const interventionType = interventions.selection ?? '';
const paramNames: string[] = interventions.param_names ?? [];
const paramValue: number[] = interventions.param_values ?? [];
const startTime: number[] = interventions.start_time ?? [];
const simulationIntervetions: SimulationIntervention[] =
simulation?.executionPayload.fixed_static_parameter_interventions ?? [];
const policyInterventions = simulation?.executionPayload?.policy_interventions;
const interventionType = policyInterventions.selection ?? '';
const paramNames: string[] = policyInterventions.param_names ?? [];
const paramValue: number[] = policyInterventions.param_values ?? [];
const startTime: number[] = policyInterventions.start_time ?? [];

const policyResult = await getRunResult(optimizeRunId, 'policy.json');
const simulationIntervetions: SimulationIntervention[] = [];

if (interventionType === InterventionTypes.paramValue && startTime.length !== 0) {
// intervention type == parameter value
for (let i = 0; i < paramNames.length; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ const pollResult = async (runId: string) => {
const startForecast = async (simulationIntervetions) => {
const simulationPayload: SimulationRequest = {
projectId: '',
modelConfigId: modelConfigId.value as string,
timespan: {
start: 0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@
<div class="constraint-row">
<div class="label-and-input">
<label>Target-variable</label>
<Dropdown
<MultiSelect
class="p-inputtext-sm"
:options="modelStateAndObsOptions"
v-model="knobs.targetVariables"
Expand Down Expand Up @@ -296,6 +296,7 @@ import InputText from 'primevue/inputtext';
import InputNumber from 'primevue/inputnumber';
import SelectButton from 'primevue/selectbutton';
import Dialog from 'primevue/dialog';
import MultiSelect from 'primevue/multiselect';
import TeraOptimizeChart from '@/components/workflow/tera-optimize-chart.vue';
import TeraSimulateChart from '@/components/workflow/tera-simulate-chart.vue';
import TeraDatasetDatatable from '@/components/dataset/tera-dataset-datatable.vue';
Expand Down Expand Up @@ -323,7 +324,7 @@ import {
ModelParameter,
OptimizeRequestCiemss,
CsvAsset,
OptimizedIntervention
PolicyInterventions
} from '@/types/Types';
import { logger } from '@/utils/logger';
import { chartActionsProxy, drilldownChartSize } from '@/components/workflow/util';
Expand Down Expand Up @@ -545,8 +546,8 @@ const runOptimize = async () => {
listBoundsInterventions.push([ele.upperBound]);
});
const optimizeInterventions: OptimizedIntervention = {
selection: knobs.value.interventionType,
const optimizeInterventions: PolicyInterventions = {
interventionType: knobs.value.interventionType,
paramNames,
startTime,
paramValues
Expand All @@ -560,7 +561,7 @@ const runOptimize = async () => {
start: 0,
end: knobs.value.endTime
},
interventions: optimizeInterventions,
policyInterventions: optimizeInterventions,
qoi: {
contexts: knobs.value.targetVariables,
method: knobs.value.qoiMethod
Expand All @@ -581,7 +582,6 @@ const runOptimize = async () => {
if (inferredParameters.value) {
optimizePayload.extra.inferredParameters = inferredParameters.value[0];
}
const optResult = await makeOptimizeJobCiemss(optimizePayload);
const state = _.cloneDeep(props.node.state);
state.inProgressOptimizeId = optResult.simulationId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ const makeForecastRequest = async () => {
const state = props.node.state;
const payload: SimulationRequest = {
projectId: '',
modelConfigId,
timespan: {
start: state.currentTimespan.start,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,6 @@ const makeForecastRequest = async (): Promise<string> => {
const state = props.node.state;
const payload: SimulationRequest = {
projectId: useProjects().activeProject.value?.id as string,
modelConfigId: configId,
timespan: {
start: state.currentTimespan.start,
Expand Down
9 changes: 5 additions & 4 deletions packages/client/hmi-client/src/types/Types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ export interface CalibrationRequestCiemss {
modelConfigId: string;
extra: any;
timespan?: TimeSpan;
interventions?: Intervention[];
dataset: DatasetLocation;
engine: string;
}
Expand Down Expand Up @@ -687,7 +688,8 @@ export interface EnsembleSimulationCiemssRequest {
export interface OptimizeRequestCiemss {
modelConfigId: string;
timespan: TimeSpan;
interventions?: OptimizedIntervention;
policyInterventions?: PolicyInterventions;
fixedStaticParameterInterventions?: Intervention[];
stepSize?: number;
qoi: OptimizeQoi;
riskBound: number;
Expand All @@ -712,7 +714,6 @@ export interface SimulationRequest {
timespan: TimeSpan;
extra: any;
engine: string;
projectId: string;
interventions?: Intervention[];
}

Expand Down Expand Up @@ -749,8 +750,8 @@ export interface OptimizeQoi {
method: string;
}

export interface OptimizedIntervention {
selection: string;
export interface PolicyInterventions {
interventionType: string;
paramNames: string[];
paramValues?: number[];
startTime?: number[];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ public ResponseEntity<Simulation> makeForecastRun(
sim.setStatus(ProgressState.QUEUED);

// FIXME: These fiels are arguable unnecessary
final Optional<Project> project = projectService.getProject(request.getProjectId());
final Optional<Project> project = projectService.getProject(projectId);
if (project.isPresent()) {
sim.setProjectId(project.get().getId());
sim.setUserId(project.get().getUserId());
Expand Down Expand Up @@ -155,7 +155,7 @@ public ResponseEntity<Simulation> makeForecastRunCiemss(
request.setInterventions(allInterventions);
}
} catch (IOException e) {
String error = "Unable to find model configuration";
String error = "Server error has occured while fetching the model configuration";
log.error(error, e);
throw new ResponseStatusException(org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR, error);
}
Expand All @@ -174,7 +174,7 @@ public ResponseEntity<Simulation> makeForecastRunCiemss(
sim.setExecutionPayload(objectMapper.convertValue(request, JsonNode.class));
sim.setStatus(ProgressState.QUEUED);

final Optional<Project> project = projectService.getProject(request.getProjectId());
final Optional<Project> project = projectService.getProject(projectId);
if (project.isPresent()) {
sim.setProjectId(project.get().getId());
sim.setUserId(project.get().getUserId());
Expand Down Expand Up @@ -205,15 +205,66 @@ public ResponseEntity<JobResponse> makeCalibrateJob(@RequestBody final Calibrati

@PostMapping("ciemss/calibrate")
@Secured(Roles.USER)
public ResponseEntity<JobResponse> makeCalibrateJobCiemss(@RequestBody final CalibrationRequestCiemss request) {
public ResponseEntity<JobResponse> makeCalibrateJobCiemss(
@RequestBody final CalibrationRequestCiemss request, @RequestParam("project-id") final UUID projectId) {
Schema.Permission permission =
projectService.checkPermissionCanWrite(currentUserService.get().getId(), projectId);
// Get model config's interventions and append them to requests:
try {
final Optional<ModelConfiguration> modelConfiguration =
modelConfigService.getAsset(request.getModelConfigId(), permission);
if (modelConfiguration.isEmpty()) {
return ResponseEntity.notFound().build();
}
final List<Intervention> modelInterventions =
modelConfiguration.get().getInterventions();
if (modelInterventions != null) {
List<Intervention> allInterventions = request.getInterventions();
if (allInterventions == null) {
allInterventions = new ArrayList<Intervention>();
}
allInterventions.addAll(modelInterventions);
request.setInterventions(allInterventions);
}
} catch (IOException e) {
String error = "Server error has occured while fetching the model configuration";
log.error(error, e);
throw new ResponseStatusException(org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR, error);
}
return ResponseEntity.ok(simulationCiemssServiceProxy
.makeCalibrateJob(convertObjectToSnakeCaseJsonNode(request))
.getBody());
}

@PostMapping("ciemss/optimize")
@Secured(Roles.USER)
public ResponseEntity<JobResponse> makeOptimizeJobCiemss(@RequestBody final OptimizeRequestCiemss request) {
public ResponseEntity<JobResponse> makeOptimizeJobCiemss(
@RequestBody final OptimizeRequestCiemss request, @RequestParam("project-id") final UUID projectId) {
Schema.Permission permission =
projectService.checkPermissionCanWrite(currentUserService.get().getId(), projectId);

// Get model config's interventions and append them to requests:
try {
final Optional<ModelConfiguration> modelConfiguration =
modelConfigService.getAsset(request.getModelConfigId(), permission);
if (modelConfiguration.isEmpty()) {
return ResponseEntity.notFound().build();
}
final List<Intervention> modelInterventions =
modelConfiguration.get().getInterventions();
if (modelInterventions != null) {
List<Intervention> allInterventions = request.getFixedStaticParameterInterventions();
if (allInterventions == null) {
allInterventions = new ArrayList<Intervention>();
}
allInterventions.addAll(modelInterventions);
request.setFixedStaticParameterInterventions(allInterventions);
}
} catch (IOException e) {
String error = "Server error has occured while fetching the model configuration";
log.error(error, e);
throw new ResponseStatusException(org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR, error);
}
return ResponseEntity.ok(simulationCiemssServiceProxy
.makeOptimizeJob(convertObjectToSnakeCaseJsonNode(request))
.getBody());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@

import com.fasterxml.jackson.annotation.JsonAlias;
import java.io.Serializable;
import java.util.List;
import java.util.UUID;
import lombok.Data;
import lombok.experimental.Accessors;
import software.uncharted.terarium.hmiserver.annotations.TSModel;
import software.uncharted.terarium.hmiserver.annotations.TSOptional;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.DatasetLocation;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.Intervention;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.TimeSpan;

@Data
Expand All @@ -15,13 +18,16 @@
// Used to kick off a calibration job in simulation-service
public class CalibrationRequestCiemss implements Serializable {
@JsonAlias("model_config_id")
private String modelConfigId;
private UUID modelConfigId;

private Object extra;

@TSOptional
private TimeSpan timespan;

@TSOptional
private List<Intervention> interventions;

private DatasetLocation dataset;
private String engine;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import com.fasterxml.jackson.annotation.JsonAlias;
import java.io.Serializable;
import java.util.List;
import java.util.UUID;
import lombok.Data;
import lombok.experimental.Accessors;
import software.uncharted.terarium.hmiserver.annotations.TSModel;
import software.uncharted.terarium.hmiserver.annotations.TSOptional;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.Intervention;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.OptimizeExtra;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.OptimizeQoi;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.OptimizedIntervention;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.PolicyInterventions;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.TimeSpan;

@Data
Expand All @@ -18,14 +20,17 @@
// Used to kick off a Optimize job in simulation-service
public class OptimizeRequestCiemss implements Serializable {
@JsonAlias("model_config_id")
private String modelConfigId;
private UUID modelConfigId;

private TimeSpan timespan;

@TSOptional
// FIXME: make pluraal more consistent here:
// https://github.com/DARPA-ASKEM/pyciemss-service/blob/main/service/models/operations/optimize.py#L80
private OptimizedIntervention interventions;
private PolicyInterventions policyInterventions;

@TSOptional
// The interventions provided via the model config which are not being optimized on
private List<Intervention> fixedStaticParameterInterventions;

@JsonAlias("step_size")
@TSOptional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ public class SimulationRequest implements Serializable {

private String engine;

private UUID projectId;

@TSOptional
private List<Intervention> interventions;

Expand All @@ -44,7 +42,6 @@ public SimulationRequest clone() {
: null);
clone.setExtra(this.extra.deepCopy());
clone.setEngine(this.engine);
clone.setProjectId(this.projectId);
clone.setInterventions(new ArrayList<>());
for (final Intervention intervention : this.interventions) {
clone.getInterventions()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
@Data
@Accessors(chain = true)
@TSModel
// Used to specify any interventions provided by the AMR and given to the simulation-service.
public class OptimizedIntervention {
private String selection;
// Interventions applied by the user within the optimization box.
public class PolicyInterventions {
// This denotes whether the intervention is on a start date, or a parameter value.
// https://github.com/DARPA-ASKEM/pyciemss-service/blob/main/service/models/operations/optimize.py#L99
private String interventionType;

@JsonAlias("param_names")
private List<String> paramNames;
Expand Down

0 comments on commit d1b27c6

Please sign in to comment.