diff --git a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java index c6f7a35e9..06106192b 100644 --- a/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -816,7 +816,8 @@ public String[] getLoadOnDevices() { return getLoadOnDevices(loadOnDevices, engineName, prop, gpuCount, neurons); } - static String[] getLoadOnDevices(String loadOnDevices, String engineName, Properties prop, int gpuCount, int neurons) { + static String[] getLoadOnDevices( + String loadOnDevices, String engineName, Properties prop, int gpuCount, int neurons) { if ("*".equals(loadOnDevices)) { boolean mpiMode; diff --git a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java index d820040f9..23ad0d5bd 100644 --- a/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java +++ b/wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java @@ -167,55 +167,98 @@ public void testLoadOnDevices() { tpMaxProperties.put("option.tensor_parallel_degree", "max"); // All devices (no tensor parallel) - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 2, 0), new String[]{"0", "1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 0, 2), new String[]{"nc0", "nc1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 2, 0), + new String[] {"0", "1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 0, 2), + new String[] {"nc0", "nc1"}); // All devices (mpi mode, default tensor parallel) - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0), new String[]{"0+1"}); // GPU MPI defaults to TP=max - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 0, 2), new String[]{"nc0", "nc1"}); // Neuron MPI defaults to TP=1 + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0), + new String[] {"0+1"}); // GPU MPI defaults to TP=max + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 0, 2), + new String[] {"nc0", "nc1"}); // Neuron MPI defaults to TP=1 // All devices (mpi mode, explicit tensor parallel) - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0), new String[]{"0+1", "2+3"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 4), new String[]{"nc0+nc1", "nc2+nc3"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0), + new String[] {"0+1", "2+3"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 4), + new String[] {"nc0+nc1", "nc2+nc3"}); // All devices (mpi mode, 0 tensor parallel) - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 2, 0), new String[]{"0", "1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 0, 2), new String[]{"nc0", "nc1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 2, 0), + new String[] {"0", "1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 0, 2), + new String[] {"nc0", "nc1"}); // All devices (mpi mode, max tensor parallel) - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 2, 0), new String[]{"0+1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 0, 2), new String[]{"nc0+nc1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 2, 0), + new String[] {"0+1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 0, 2), + new String[] {"nc0+nc1"}); // All devices (insufficient TP assignment) Assert.assertThrows(() -> ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 1, 0)); Assert.assertThrows(() -> ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 1)); // All devices (uneven TP assignment) - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 3, 0), new String[]{"0+1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 3), new String[]{"nc0+nc1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 3, 0), new String[] {"0+1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 3), + new String[] {"nc0+nc1"}); // All devices (only CPU available) - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "Python", emptyProperties, 0, 0), new String[]{"-1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 0), new String[]{"-1"}); // Ignored MPI and TPI and falls back to CPU + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "Python", emptyProperties, 0, 0), + new String[] {"-1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 0), + new String[] {"-1"}); // Ignored MPI and TPI and falls back to CPU // Ways to enable mpi mode - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "Python", mpiProperties, 2, 0), new String[]{"0+1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0), new String[]{"0+1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "DeepSpeed", emptyProperties, 2, 0), new String[]{"0+1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "FasterTransformer", emptyProperties, 2, 0), new String[]{"0+1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "Python", mpiProperties, 2, 0), + new String[] {"0+1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0), + new String[] {"0+1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "DeepSpeed", emptyProperties, 2, 0), + new String[] {"0+1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "FasterTransformer", emptyProperties, 2, 0), + new String[] {"0+1"}); // Ways to set tensor parallel - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0), new String[]{"0+1", "2+3"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0), + new String[] {"0+1", "2+3"}); System.setProperty("TENSOR_PARALLEL_DEGREE", "2"); - Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", mpiProperties, 4, 0), new String[]{"0+1", "2+3"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("*", "MPI", mpiProperties, 4, 0), + new String[] {"0+1", "2+3"}); System.clearProperty("TENSOR_PARALLEL_DEGREE"); // Providing devices - Assert.assertEquals(ModelInfo.getLoadOnDevices("1", "", emptyProperties, 10, 10), new String[]{"1"}); - Assert.assertEquals(ModelInfo.getLoadOnDevices("1;2", "", emptyProperties, 10, 10), new String[]{"1", "2"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("1", "", emptyProperties, 10, 10), new String[] {"1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("1;2", "", emptyProperties, 10, 10), + new String[] {"1", "2"}); // Empty case - Assert.assertEquals(ModelInfo.getLoadOnDevices("", "", emptyProperties, 10, 10), new String[]{"-1"}); + Assert.assertEquals( + ModelInfo.getLoadOnDevices("", "", emptyProperties, 10, 10), new String[] {"-1"}); } @Test