From b7fca110bdd6a2913c117e6382678d5b850aefd8 Mon Sep 17 00:00:00 2001 From: nkpng2k Date: Wed, 5 Jan 2022 12:17:05 -0800 Subject: [PATCH 1/4] set supported capabilities dynamically --- .../deploy/common/transform/MojoScorer.java | 4 +++ .../rest/controller/ModelsApiController.java | 32 ++++++++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java index ad4c5419..30fe9456 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java @@ -302,6 +302,10 @@ public Model getModelInfo() { return modelInfoConverter.apply(pipeline); } + public ShapleyLoadOption getEnabledShapleyTypes() { + return enabledShapleyTypes; + } + /** * Method to load mojo pipelines for shapley scoring based on configuration * diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java index 591d7d19..8dafb183 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java @@ -9,9 +9,10 @@ import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse; import ai.h2o.mojos.deploy.common.transform.MojoScorer; import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder; +import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption; import com.google.common.base.Strings; import java.io.IOException; -import java.util.Arrays; +import java.util.ArrayList; import java.util.List; import org.slf4j.Logger; @@ -24,13 +25,13 @@ @Controller public class ModelsApiController implements ModelApi { - private static final List SUPPORTED_CAPABILITIES - = Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED); private static final Logger log = LoggerFactory.getLogger(ModelsApiController.class); private final MojoScorer scorer; private final SampleRequestBuilder sampleRequestBuilder; + private final List supportedCapabilities; + /** * Simple Api controller. Inherits from {@link ModelApi}, which controls global, expected request * mappings for the rest service. @@ -44,6 +45,7 @@ public class ModelsApiController implements ModelApi { public ModelsApiController(MojoScorer scorer, SampleRequestBuilder sampleRequestBuilder) { this.scorer = scorer; this.sampleRequestBuilder = sampleRequestBuilder; + this.supportedCapabilities = setSupportedCapabilities(); } @Override @@ -58,7 +60,7 @@ public ResponseEntity getModelId() { @Override public ResponseEntity> getCapabilities() { - return ResponseEntity.ok(SUPPORTED_CAPABILITIES); + return ResponseEntity.ok(supportedCapabilities); } @Override @@ -116,4 +118,26 @@ public ResponseEntity getContribution( public ResponseEntity getSampleRequest() { return ResponseEntity.ok(sampleRequestBuilder.build(scorer.getPipeline().getInputMeta())); } + + private List setSupportedCapabilities() { + List capabilityTypes = new ArrayList<>(); + capabilityTypes.add(CapabilityType.SCORE); + ShapleyLoadOption enabledShapleyTypes = scorer.getEnabledShapleyTypes(); + switch (enabledShapleyTypes) { + case ALL: + capabilityTypes.add(CapabilityType.CONTRIBUTION_ORIGINAL); + capabilityTypes.add(CapabilityType.CONTRIBUTION_TRANSFORMED); + break; + case ORIGINAL: + capabilityTypes.add(CapabilityType.CONTRIBUTION_ORIGINAL); + break; + case TRANSFORMED: + capabilityTypes.add(CapabilityType.CONTRIBUTION_TRANSFORMED); + break; + case NONE: + default: + break; + } + return capabilityTypes; + } } From 0e6202a2ddd34bdf79c9ea692c8b4b16bb978946 Mon Sep 17 00:00:00 2001 From: nkpng2k Date: Mon, 10 Jan 2022 16:39:58 -0800 Subject: [PATCH 2/4] review: small improvements --- .../rest/controller/ModelsApiController.java | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java index 8dafb183..2111bf6e 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java @@ -45,7 +45,9 @@ public class ModelsApiController implements ModelApi { public ModelsApiController(MojoScorer scorer, SampleRequestBuilder sampleRequestBuilder) { this.scorer = scorer; this.sampleRequestBuilder = sampleRequestBuilder; - this.supportedCapabilities = setSupportedCapabilities(); + this.supportedCapabilities = assembleSupportedCapabilities( + scorer.getEnabledShapleyTypes() + ); } @Override @@ -119,25 +121,24 @@ public ResponseEntity getSampleRequest() { return ResponseEntity.ok(sampleRequestBuilder.build(scorer.getPipeline().getInputMeta())); } - private List setSupportedCapabilities() { + private List assembleSupportedCapabilities( + ShapleyLoadOption enabledShapleyTypes) { List capabilityTypes = new ArrayList<>(); capabilityTypes.add(CapabilityType.SCORE); - ShapleyLoadOption enabledShapleyTypes = scorer.getEnabledShapleyTypes(); switch (enabledShapleyTypes) { case ALL: capabilityTypes.add(CapabilityType.CONTRIBUTION_ORIGINAL); capabilityTypes.add(CapabilityType.CONTRIBUTION_TRANSFORMED); - break; + return capabilityTypes; case ORIGINAL: capabilityTypes.add(CapabilityType.CONTRIBUTION_ORIGINAL); - break; + return capabilityTypes; case TRANSFORMED: capabilityTypes.add(CapabilityType.CONTRIBUTION_TRANSFORMED); - break; + return capabilityTypes; case NONE: default: - break; + return capabilityTypes; } - return capabilityTypes; } } From 67bd739d3faf6097764ad35be7786ccbcbd098f6 Mon Sep 17 00:00:00 2001 From: nkpng2k Date: Mon, 10 Jan 2022 17:36:26 -0800 Subject: [PATCH 3/4] adds tests --- local-rest-scorer/build.gradle | 2 + .../controller/ModelsApiControllerTest.java | 174 ++++++++++++++++++ .../test/resources/multinomial-pipeline.mojo | 0 3 files changed, 176 insertions(+) create mode 100644 local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java create mode 100644 local-rest-scorer/src/test/resources/multinomial-pipeline.mojo diff --git a/local-rest-scorer/build.gradle b/local-rest-scorer/build.gradle index bc7b0bc6..3e9d56c7 100644 --- a/local-rest-scorer/build.gradle +++ b/local-rest-scorer/build.gradle @@ -20,6 +20,8 @@ dependencies { testImplementation group: 'org.springframework.boot', name: 'spring-boot-starter-test' testImplementation group: 'com.google.truth.extensions', name: 'truth-java8-extension' + testImplementation group: 'org.mockito', name: 'mockito-inline', version: '3.4.0' + testImplementation group: 'org.mockito', name : 'mockito-core', version: '3.4.0' testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api' testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-params' testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine' diff --git a/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java new file mode 100644 index 00000000..748a48c5 --- /dev/null +++ b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java @@ -0,0 +1,174 @@ +package ai.h2o.mojos.deploy.local.rest.controller; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.h2o.mojos.deploy.common.rest.model.CapabilityType; +import ai.h2o.mojos.deploy.common.transform.MojoScorer; +import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder; +import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption; +import ai.h2o.mojos.runtime.MojoPipeline; +import ai.h2o.mojos.runtime.api.BasePipelineListener; +import ai.h2o.mojos.runtime.api.MojoPipelineService; +import ai.h2o.mojos.runtime.frame.MojoColumn; +import ai.h2o.mojos.runtime.frame.MojoFrame; +import ai.h2o.mojos.runtime.frame.MojoFrameBuilder; +import ai.h2o.mojos.runtime.frame.MojoFrameMeta; + +import java.io.File; +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.http.ResponseEntity; + +@ExtendWith(MockitoExtension.class) +class ModelsApiControllerTest { + private static final String MOJO_PIPELINE_PATH = "src/test/resources/multinomial-pipeline.mojo"; + private static final String TEST_UUID = "TEST_UUID"; + + @Mock private SampleRequestBuilder sampleRequestBuilder; + + @BeforeAll + static void setup() { + System.setProperty("mojo.path", "src/test/resources/multinomial-pipeline.mojo"); + mockDummyPipeline(); + } + + private static void mockDummyPipeline() { + MojoPipeline dummyPipeline = + new DummyPipeline(TEST_UUID, MojoFrameMeta.getEmpty(), MojoFrameMeta.getEmpty()); + MockedStatic theMock = Mockito.mockStatic(MojoPipelineService.class); + theMock.when(() -> MojoPipelineService + .loadPipeline(new File(MOJO_PIPELINE_PATH))).thenReturn(dummyPipeline); + } + + @Test + void verifyCapabilities_DefaultShapley_ReturnsExpected() { + // Given + List expectedCapabilities = Arrays.asList(CapabilityType.SCORE); + + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.NONE); + + ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); + + // When + ResponseEntity> response = controller.getCapabilities(); + + // Then + assertEquals(expectedCapabilities, response.getBody()); + } + + @Test + void verifyCapabilities_AllShapleyEnabled_ReturnsExpected() { + // Given + List expectedCapabilities = Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_ORIGINAL, + CapabilityType.CONTRIBUTION_TRANSFORMED); + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ALL); + + ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); + + // When + ResponseEntity> response = controller.getCapabilities(); + + // Then + assertEquals(expectedCapabilities, response.getBody()); + } + + @Test + void verifyCapabilities_OriginalShapleyEnabled_ReturnsExpected() { + // Given + List expectedCapabilities = Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_ORIGINAL); + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ORIGINAL); + + ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); + + // When + ResponseEntity> response = controller.getCapabilities(); + + // Then + assertEquals(expectedCapabilities, response.getBody()); + } + + @Test + void verifyCapabilities_TransformedShapleyEnabled_ReturnsExpected() { + // Given + List expectedCapabilities = Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_TRANSFORMED); + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED); + + ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); + + // When + ResponseEntity> response = controller.getCapabilities(); + + // Then + assertEquals(expectedCapabilities, response.getBody()); + } + + /** Dummy pipeline {@link MojoPipeline} just to mock the static methods used inside scoring. */ + private static class DummyPipeline extends MojoPipeline { + private final MojoFrameMeta inputMeta; + private final MojoFrameMeta outputMeta; + + private DummyPipeline(String uuid, MojoFrameMeta inputMeta, MojoFrameMeta outputMeta) { + super(uuid, null, null); + this.inputMeta = inputMeta; + this.outputMeta = outputMeta; + } + + @Override + public MojoFrameMeta getInputMeta() { + return inputMeta; + } + + @Override + public MojoFrameMeta getOutputMeta() { + return outputMeta; + } + + @Override + protected MojoFrameBuilder getFrameBuilder(MojoColumn.Kind kind) { + return new MojoFrameBuilder(outputMeta); + } + + @Override + protected MojoFrameMeta getMeta(MojoColumn.Kind kind) { + return outputMeta; + } + + @Override + public MojoFrame transform(MojoFrame inputFrame, MojoFrame outputFrame) { + return outputFrame; + } + + @Override + public void setShapPredictContrib(boolean enable) { + } + + @Override + public void setShapPredictContribOriginal(boolean enable) { + } + + @Override + public void setListener(BasePipelineListener listener) { + } + } +} diff --git a/local-rest-scorer/src/test/resources/multinomial-pipeline.mojo b/local-rest-scorer/src/test/resources/multinomial-pipeline.mojo new file mode 100644 index 00000000..e69de29b From e404eb8fa89f8880ebafaf97a8352e7bc98602ee Mon Sep 17 00:00:00 2001 From: nkpng2k Date: Tue, 11 Jan 2022 13:30:54 -0800 Subject: [PATCH 4/4] review: improvements to tests, better mocking, improve maintainability --- gradle.properties | 4 +- local-rest-scorer/build.gradle | 4 +- .../rest/controller/ModelsApiController.java | 21 +++--- .../controller/ModelsApiControllerTest.java | 72 +++---------------- .../test/resources/multinomial-pipeline.mojo | 0 5 files changed, 21 insertions(+), 80 deletions(-) delete mode 100644 local-rest-scorer/src/test/resources/multinomial-pipeline.mojo diff --git a/gradle.properties b/gradle.properties index 317dbe7f..ca2ecb9b 100644 --- a/gradle.properties +++ b/gradle.properties @@ -13,9 +13,9 @@ awsLambdaEventsVersion = 2.2.3 awsSdkS3Version = 1.11.445 javaxAnnotationVersion = 1.3.2 gsonVersion = 2.8.5 -jupiterVersion = 5.3.1 +jupiterVersion = 5.4.0 jupiterSystemStubsVersion = 1.2.0 -mockitoVersion = 3.0.0 +mockitoVersion = 3.4.0 springFoxVersion = 3.0.0 swaggerCodegenVersion = 3.0.0 swaggerCoreVersion = 2.0.5 diff --git a/local-rest-scorer/build.gradle b/local-rest-scorer/build.gradle index 3e9d56c7..58a5fab6 100644 --- a/local-rest-scorer/build.gradle +++ b/local-rest-scorer/build.gradle @@ -20,8 +20,8 @@ dependencies { testImplementation group: 'org.springframework.boot', name: 'spring-boot-starter-test' testImplementation group: 'com.google.truth.extensions', name: 'truth-java8-extension' - testImplementation group: 'org.mockito', name: 'mockito-inline', version: '3.4.0' - testImplementation group: 'org.mockito', name : 'mockito-core', version: '3.4.0' + testImplementation group: 'org.mockito', name: 'mockito-inline' + testImplementation group: 'org.mockito', name : 'mockito-core' testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api' testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-params' testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine' diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java index 2111bf6e..00e4c811 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java @@ -12,7 +12,7 @@ import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption; import com.google.common.base.Strings; import java.io.IOException; -import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.slf4j.Logger; @@ -121,24 +121,21 @@ public ResponseEntity getSampleRequest() { return ResponseEntity.ok(sampleRequestBuilder.build(scorer.getPipeline().getInputMeta())); } - private List assembleSupportedCapabilities( + private static List assembleSupportedCapabilities( ShapleyLoadOption enabledShapleyTypes) { - List capabilityTypes = new ArrayList<>(); - capabilityTypes.add(CapabilityType.SCORE); switch (enabledShapleyTypes) { case ALL: - capabilityTypes.add(CapabilityType.CONTRIBUTION_ORIGINAL); - capabilityTypes.add(CapabilityType.CONTRIBUTION_TRANSFORMED); - return capabilityTypes; + return Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_ORIGINAL, + CapabilityType.CONTRIBUTION_TRANSFORMED); case ORIGINAL: - capabilityTypes.add(CapabilityType.CONTRIBUTION_ORIGINAL); - return capabilityTypes; + return Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_ORIGINAL); case TRANSFORMED: - capabilityTypes.add(CapabilityType.CONTRIBUTION_TRANSFORMED); - return capabilityTypes; + return Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED); case NONE: default: - return capabilityTypes; + return Arrays.asList(CapabilityType.SCORE); } } } diff --git a/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java index 748a48c5..31af174e 100644 --- a/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java +++ b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java @@ -9,14 +9,10 @@ import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder; import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption; import ai.h2o.mojos.runtime.MojoPipeline; -import ai.h2o.mojos.runtime.api.BasePipelineListener; import ai.h2o.mojos.runtime.api.MojoPipelineService; -import ai.h2o.mojos.runtime.frame.MojoColumn; -import ai.h2o.mojos.runtime.frame.MojoFrame; -import ai.h2o.mojos.runtime.frame.MojoFrameBuilder; -import ai.h2o.mojos.runtime.frame.MojoFrameMeta; import java.io.File; +import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -32,23 +28,20 @@ @ExtendWith(MockitoExtension.class) class ModelsApiControllerTest { - private static final String MOJO_PIPELINE_PATH = "src/test/resources/multinomial-pipeline.mojo"; - private static final String TEST_UUID = "TEST_UUID"; - @Mock private SampleRequestBuilder sampleRequestBuilder; @BeforeAll - static void setup() { - System.setProperty("mojo.path", "src/test/resources/multinomial-pipeline.mojo"); - mockDummyPipeline(); + static void setup() throws IOException { + File tmpModel = File.createTempFile("pipeline", ".mojo"); + System.setProperty("mojo.path", tmpModel.getAbsolutePath()); + mockMojoPipeline(tmpModel); } - private static void mockDummyPipeline() { - MojoPipeline dummyPipeline = - new DummyPipeline(TEST_UUID, MojoFrameMeta.getEmpty(), MojoFrameMeta.getEmpty()); + private static void mockMojoPipeline(File tmpModel) { + MojoPipeline mojoPipeline = Mockito.mock(MojoPipeline.class); MockedStatic theMock = Mockito.mockStatic(MojoPipelineService.class); theMock.when(() -> MojoPipelineService - .loadPipeline(new File(MOJO_PIPELINE_PATH))).thenReturn(dummyPipeline); + .loadPipeline(new File(tmpModel.getAbsolutePath()))).thenReturn(mojoPipeline); } @Test @@ -122,53 +115,4 @@ void verifyCapabilities_TransformedShapleyEnabled_ReturnsExpected() { // Then assertEquals(expectedCapabilities, response.getBody()); } - - /** Dummy pipeline {@link MojoPipeline} just to mock the static methods used inside scoring. */ - private static class DummyPipeline extends MojoPipeline { - private final MojoFrameMeta inputMeta; - private final MojoFrameMeta outputMeta; - - private DummyPipeline(String uuid, MojoFrameMeta inputMeta, MojoFrameMeta outputMeta) { - super(uuid, null, null); - this.inputMeta = inputMeta; - this.outputMeta = outputMeta; - } - - @Override - public MojoFrameMeta getInputMeta() { - return inputMeta; - } - - @Override - public MojoFrameMeta getOutputMeta() { - return outputMeta; - } - - @Override - protected MojoFrameBuilder getFrameBuilder(MojoColumn.Kind kind) { - return new MojoFrameBuilder(outputMeta); - } - - @Override - protected MojoFrameMeta getMeta(MojoColumn.Kind kind) { - return outputMeta; - } - - @Override - public MojoFrame transform(MojoFrame inputFrame, MojoFrame outputFrame) { - return outputFrame; - } - - @Override - public void setShapPredictContrib(boolean enable) { - } - - @Override - public void setShapPredictContribOriginal(boolean enable) { - } - - @Override - public void setListener(BasePipelineListener listener) { - } - } } diff --git a/local-rest-scorer/src/test/resources/multinomial-pipeline.mojo b/local-rest-scorer/src/test/resources/multinomial-pipeline.mojo deleted file mode 100644 index e69de29b..00000000