diff --git a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java index ad4c5419..30fe9456 100644 --- a/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java +++ b/common/transform/src/main/java/ai/h2o/mojos/deploy/common/transform/MojoScorer.java @@ -302,6 +302,10 @@ public Model getModelInfo() { return modelInfoConverter.apply(pipeline); } + public ShapleyLoadOption getEnabledShapleyTypes() { + return enabledShapleyTypes; + } + /** * Method to load mojo pipelines for shapley scoring based on configuration * diff --git a/gradle.properties b/gradle.properties index 317dbe7f..ca2ecb9b 100644 --- a/gradle.properties +++ b/gradle.properties @@ -13,9 +13,9 @@ awsLambdaEventsVersion = 2.2.3 awsSdkS3Version = 1.11.445 javaxAnnotationVersion = 1.3.2 gsonVersion = 2.8.5 -jupiterVersion = 5.3.1 +jupiterVersion = 5.4.0 jupiterSystemStubsVersion = 1.2.0 -mockitoVersion = 3.0.0 +mockitoVersion = 3.4.0 springFoxVersion = 3.0.0 swaggerCodegenVersion = 3.0.0 swaggerCoreVersion = 2.0.5 diff --git a/local-rest-scorer/build.gradle b/local-rest-scorer/build.gradle index bc7b0bc6..58a5fab6 100644 --- a/local-rest-scorer/build.gradle +++ b/local-rest-scorer/build.gradle @@ -20,6 +20,8 @@ dependencies { testImplementation group: 'org.springframework.boot', name: 'spring-boot-starter-test' testImplementation group: 'com.google.truth.extensions', name: 'truth-java8-extension' + testImplementation group: 'org.mockito', name: 'mockito-inline' + testImplementation group: 'org.mockito', name : 'mockito-core' testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-api' testImplementation group: 'org.junit.jupiter', name: 'junit-jupiter-params' testRuntimeOnly group: 'org.junit.jupiter', name: 'junit-jupiter-engine' diff --git a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java index 591d7d19..00e4c811 100644 --- a/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java +++ b/local-rest-scorer/src/main/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiController.java @@ -9,6 +9,7 @@ import ai.h2o.mojos.deploy.common.rest.model.ScoreResponse; import ai.h2o.mojos.deploy.common.transform.MojoScorer; import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder; +import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption; import com.google.common.base.Strings; import java.io.IOException; import java.util.Arrays; @@ -24,13 +25,13 @@ @Controller public class ModelsApiController implements ModelApi { - private static final List SUPPORTED_CAPABILITIES - = Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED); private static final Logger log = LoggerFactory.getLogger(ModelsApiController.class); private final MojoScorer scorer; private final SampleRequestBuilder sampleRequestBuilder; + private final List supportedCapabilities; + /** * Simple Api controller. Inherits from {@link ModelApi}, which controls global, expected request * mappings for the rest service. @@ -44,6 +45,9 @@ public class ModelsApiController implements ModelApi { public ModelsApiController(MojoScorer scorer, SampleRequestBuilder sampleRequestBuilder) { this.scorer = scorer; this.sampleRequestBuilder = sampleRequestBuilder; + this.supportedCapabilities = assembleSupportedCapabilities( + scorer.getEnabledShapleyTypes() + ); } @Override @@ -58,7 +62,7 @@ public ResponseEntity getModelId() { @Override public ResponseEntity> getCapabilities() { - return ResponseEntity.ok(SUPPORTED_CAPABILITIES); + return ResponseEntity.ok(supportedCapabilities); } @Override @@ -116,4 +120,22 @@ public ResponseEntity getContribution( public ResponseEntity getSampleRequest() { return ResponseEntity.ok(sampleRequestBuilder.build(scorer.getPipeline().getInputMeta())); } + + private static List assembleSupportedCapabilities( + ShapleyLoadOption enabledShapleyTypes) { + switch (enabledShapleyTypes) { + case ALL: + return Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_ORIGINAL, + CapabilityType.CONTRIBUTION_TRANSFORMED); + case ORIGINAL: + return Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_ORIGINAL); + case TRANSFORMED: + return Arrays.asList(CapabilityType.SCORE, CapabilityType.CONTRIBUTION_TRANSFORMED); + case NONE: + default: + return Arrays.asList(CapabilityType.SCORE); + } + } } diff --git a/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java new file mode 100644 index 00000000..31af174e --- /dev/null +++ b/local-rest-scorer/src/test/java/ai/h2o/mojos/deploy/local/rest/controller/ModelsApiControllerTest.java @@ -0,0 +1,118 @@ +package ai.h2o.mojos.deploy.local.rest.controller; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import ai.h2o.mojos.deploy.common.rest.model.CapabilityType; +import ai.h2o.mojos.deploy.common.transform.MojoScorer; +import ai.h2o.mojos.deploy.common.transform.SampleRequestBuilder; +import ai.h2o.mojos.deploy.common.transform.ShapleyLoadOption; +import ai.h2o.mojos.runtime.MojoPipeline; +import ai.h2o.mojos.runtime.api.MojoPipelineService; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.http.ResponseEntity; + +@ExtendWith(MockitoExtension.class) +class ModelsApiControllerTest { + @Mock private SampleRequestBuilder sampleRequestBuilder; + + @BeforeAll + static void setup() throws IOException { + File tmpModel = File.createTempFile("pipeline", ".mojo"); + System.setProperty("mojo.path", tmpModel.getAbsolutePath()); + mockMojoPipeline(tmpModel); + } + + private static void mockMojoPipeline(File tmpModel) { + MojoPipeline mojoPipeline = Mockito.mock(MojoPipeline.class); + MockedStatic theMock = Mockito.mockStatic(MojoPipelineService.class); + theMock.when(() -> MojoPipelineService + .loadPipeline(new File(tmpModel.getAbsolutePath()))).thenReturn(mojoPipeline); + } + + @Test + void verifyCapabilities_DefaultShapley_ReturnsExpected() { + // Given + List expectedCapabilities = Arrays.asList(CapabilityType.SCORE); + + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.NONE); + + ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); + + // When + ResponseEntity> response = controller.getCapabilities(); + + // Then + assertEquals(expectedCapabilities, response.getBody()); + } + + @Test + void verifyCapabilities_AllShapleyEnabled_ReturnsExpected() { + // Given + List expectedCapabilities = Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_ORIGINAL, + CapabilityType.CONTRIBUTION_TRANSFORMED); + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ALL); + + ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); + + // When + ResponseEntity> response = controller.getCapabilities(); + + // Then + assertEquals(expectedCapabilities, response.getBody()); + } + + @Test + void verifyCapabilities_OriginalShapleyEnabled_ReturnsExpected() { + // Given + List expectedCapabilities = Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_ORIGINAL); + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.ORIGINAL); + + ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); + + // When + ResponseEntity> response = controller.getCapabilities(); + + // Then + assertEquals(expectedCapabilities, response.getBody()); + } + + @Test + void verifyCapabilities_TransformedShapleyEnabled_ReturnsExpected() { + // Given + List expectedCapabilities = Arrays.asList( + CapabilityType.SCORE, + CapabilityType.CONTRIBUTION_TRANSFORMED); + MojoScorer scorer = mock(MojoScorer.class); + when(scorer.getEnabledShapleyTypes()).thenReturn(ShapleyLoadOption.TRANSFORMED); + + ModelsApiController controller = new ModelsApiController(scorer, sampleRequestBuilder); + + // When + ResponseEntity> response = controller.getCapabilities(); + + // Then + assertEquals(expectedCapabilities, response.getBody()); + } +}