Skip to content

Commit

Permalink
Optimize new plumbing (#4112)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Yohann Paris <github@yohannparis.com>
  • Loading branch information
3 people authored Jul 15, 2024
1 parent 11088a7 commit c44ebd7
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,7 @@ export const OptimizeCiemssOperation: Operation = {
{
type: 'policyInterventionId',
label: 'Interventions',
acceptMultiple: false,
isOptional: true
acceptMultiple: false
}
],
outputs: [{ type: 'simulationId' }],
Expand Down Expand Up @@ -168,11 +167,11 @@ export async function getOptimizedInterventions(optimizeRunId: string) {
const simulation = await getSimulation(optimizeRunId);
const simulationIntervetions =
simulation?.executionPayload.fixed_static_parameter_interventions ?? [];
const policyInterventions = simulation?.executionPayload?.policy_interventions;
const interventionType = policyInterventions.selection ?? '';
const paramNames: string[] = policyInterventions.param_names ?? [];
const paramValue: number[] = policyInterventions.param_values ?? [];
const startTime: number[] = policyInterventions.start_time ?? [];
const optimizeInterventions = simulation?.executionPayload?.optimize_interventions;
const interventionType = optimizeInterventions.intervention_type ?? '';
const paramNames: string[] = optimizeInterventions.param_names ?? [];
const paramValue: number[] = optimizeInterventions.param_values ?? [];
const startTime: number[] = optimizeInterventions.start_time ?? [];

const policyResult = await getRunResult(optimizeRunId, 'policy.json');

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,10 @@ import {
ModelParameter,
OptimizeRequestCiemss,
CsvAsset,
PolicyInterventions,
OptimizeInterventions,
OptimizeQoi,
InterventionPolicy
InterventionPolicy,
Intervention
} from '@/types/Types';
import { logger } from '@/utils/logger';
import { chartActionsProxy, drilldownChartSize, nodeMetadata } from '@/components/workflow/util';
Expand Down Expand Up @@ -404,6 +405,14 @@ const isSaveDisabled = computed<boolean>(() =>
isSaveDatasetDisabled(props.node.state.postForecastRunId, useProjects().activeProject.value?.id)
);
const activePolicyGroups = computed(() =>
props.node.state.interventionPolicyGroups.filter((ele) => ele.isActive === true)
);
const inactivePolicyGroups = computed(() =>
props.node.state.interventionPolicyGroups.filter((ele) => ele.isActive === false)
);
const menuItems = computed(() => [
{
label: 'Save as a new model configuration',
Expand Down Expand Up @@ -450,11 +459,13 @@ const outputs = computed(() => {
const isRunDisabled = computed(() => {
if (
!props.node.state.constraintGroups?.at(0)?.targetVariable ||
props.node.state.interventionPolicyGroups.length === 0
props.node.state.interventionPolicyGroups.length === 0 ||
activePolicyGroups.value.length <= 0
)
return true;
return false;
});
const selectedOutputId = ref<string>();
const outputViewSelection = ref(OutputView.Charts);
Expand Down Expand Up @@ -573,23 +584,36 @@ const runOptimize = async () => {
const startTime: number[] = [];
const listInitialGuessInterventions: number[] = [];
const listBoundsInterventions: number[][] = [];
props.node.state.interventionPolicyGroups.forEach((ele) => {
const initialGuess: number[] = [];
const objectiveFunctionOption: string[] = [];
activePolicyGroups.value.forEach((ele) => {
paramNames.push(ele.intervention.appliedTo);
paramValues.push(ele.intervention.staticInterventions[0].value);
startTime.push(ele.startTime);
initialGuess.push(ele.initialGuessValue);
objectiveFunctionOption.push(ele.objectiveFunctionOption);
listInitialGuessInterventions.push(ele.initialGuessValue);
listBoundsInterventions.push([ele.lowerBoundValue]);
listBoundsInterventions.push([ele.upperBoundValue]);
});
const interventionType = props.node.state.interventionPolicyGroups[0].optimizationType;
const optimizeInterventions: PolicyInterventions = {
// These are interventions to be optimized over.
const optimizeInterventions: OptimizeInterventions = {
interventionType,
paramNames,
startTime,
paramValues
paramValues,
initialGuess,
objectiveFunctionOption
};
// These are interventions to be considered but not optimized over.
const fixedStaticParameterInterventions: Intervention[] = _.cloneDeep(
inactivePolicyGroups.value.map((ele) => ele.intervention)
);
// TODO: https://github.com/DARPA-ASKEM/terarium/issues/3909
// The method should be a list but pyciemss + pyciemss service is not yet ready for this.
const qoi: OptimizeQoi = {
Expand All @@ -605,7 +629,8 @@ const runOptimize = async () => {
start: 0,
end: knobs.value.endTime
},
policyInterventions: optimizeInterventions,
optimizeInterventions,
fixedStaticParameterInterventions,
qoi,
riskBound: props.node.state.constraintGroups[0].threshold, // TODO: https://github.com/DARPA-ASKEM/terarium/issues/3909
initialGuessInterventions: listInitialGuessInterventions,
Expand All @@ -620,9 +645,11 @@ const runOptimize = async () => {
}
};
// InferredParameters is to link a calibration run to this optimize call.
if (inferredParameters.value) {
optimizePayload.extra.inferredParameters = inferredParameters.value[0];
}
const optResult = await makeOptimizeJobCiemss(optimizePayload, nodeMetadata(props.node));
const state = _.cloneDeep(props.node.state);
state.inProgressOptimizeId = optResult.simulationId;
Expand Down
18 changes: 10 additions & 8 deletions packages/client/hmi-client/src/types/Types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -727,8 +727,8 @@ export interface EnsembleSimulationCiemssRequest {
export interface OptimizeRequestCiemss {
modelConfigId: string;
timespan: TimeSpan;
policyInterventions?: PolicyInterventions;
fixedStaticParameterInterventions?: string;
optimizeInterventions?: OptimizeInterventions;
fixedStaticParameterInterventions?: Intervention[];
stepSize?: number;
qoi: OptimizeQoi;
riskBound: number;
Expand Down Expand Up @@ -803,16 +803,18 @@ export interface OptimizeExtra {
solverMethod?: string;
}

export interface OptimizeQoi {
contexts: string[];
method: string;
}

export interface PolicyInterventions {
export interface OptimizeInterventions {
interventionType: string;
paramNames: string[];
paramValues?: number[];
startTime?: number[];
objectiveFunctionOption?: string[];
initialGuess?: number[];
}

export interface OptimizeQoi {
contexts: string[];
method: string;
}

export interface TimeSpan {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
import lombok.experimental.Accessors;
import software.uncharted.terarium.hmiserver.annotations.TSModel;
import software.uncharted.terarium.hmiserver.annotations.TSOptional;
import software.uncharted.terarium.hmiserver.models.simulationservice.interventions.Intervention;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.OptimizeExtra;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.OptimizeInterventions;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.OptimizeQoi;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.PolicyInterventions;
import software.uncharted.terarium.hmiserver.models.simulationservice.parts.TimeSpan;

@Data
Expand All @@ -25,10 +26,11 @@ public class OptimizeRequestCiemss implements Serializable {

@TSOptional
// https://github.com/DARPA-ASKEM/pyciemss-service/blob/main/service/models/operations/optimize.py#L80
private PolicyInterventions policyInterventions;
private OptimizeInterventions optimizeInterventions;

@TSOptional
private UUID fixedStaticParameterInterventions;
@JsonAlias("fixed_static_parameter_interventions")
private List<Intervention> fixedStaticParameterInterventions;

@JsonAlias("step_size")
@TSOptional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@Accessors(chain = true)
@TSModel
// Interventions applied by the user within the optimization box.
public class PolicyInterventions {
public class OptimizeInterventions {
// This denotes whether the intervention is on a start date, or a parameter value.
// https://github.com/DARPA-ASKEM/pyciemss-service/blob/main/service/models/operations/optimize.py#L99
private String interventionType;
Expand All @@ -27,6 +27,14 @@ public class PolicyInterventions {
@JsonAlias("start_time")
private List<Integer> startTime;

@TSOptional
@JsonAlias("objective_function_option")
private List<String> objectiveFunctionOption;

@TSOptional
@JsonAlias("initial_guess")
private List<Double> initialGuess;

@Override
public String toString() {
return " { Parameter Names: " + this.paramNames + " start time: " + startTime.toString() + " } ";
Expand Down

0 comments on commit c44ebd7

Please sign in to comment.