Skip to content

Commit

Permalink
Formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk committed Oct 26, 2023
1 parent e0e8206 commit 093fef8
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 24 deletions.
3 changes: 2 additions & 1 deletion wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,8 @@ public String[] getLoadOnDevices() {
return getLoadOnDevices(loadOnDevices, engineName, prop, gpuCount, neurons);
}

static String[] getLoadOnDevices(String loadOnDevices, String engineName, Properties prop, int gpuCount, int neurons) {
static String[] getLoadOnDevices(
String loadOnDevices, String engineName, Properties prop, int gpuCount, int neurons) {
if ("*".equals(loadOnDevices)) {

boolean mpiMode;
Expand Down
89 changes: 66 additions & 23 deletions wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,55 +167,98 @@ public void testLoadOnDevices() {
tpMaxProperties.put("option.tensor_parallel_degree", "max");

// All devices (no tensor parallel)
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 2, 0), new String[]{"0", "1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 0, 2), new String[]{"nc0", "nc1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 2, 0),
new String[] {"0", "1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 0, 2),
new String[] {"nc0", "nc1"});

// All devices (mpi mode, default tensor parallel)
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0), new String[]{"0+1"}); // GPU MPI defaults to TP=max
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 0, 2), new String[]{"nc0", "nc1"}); // Neuron MPI defaults to TP=1
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0),
new String[] {"0+1"}); // GPU MPI defaults to TP=max
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 0, 2),
new String[] {"nc0", "nc1"}); // Neuron MPI defaults to TP=1

// All devices (mpi mode, explicit tensor parallel)
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0), new String[]{"0+1", "2+3"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 4), new String[]{"nc0+nc1", "nc2+nc3"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0),
new String[] {"0+1", "2+3"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 4),
new String[] {"nc0+nc1", "nc2+nc3"});

// All devices (mpi mode, 0 tensor parallel)
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 2, 0), new String[]{"0", "1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 0, 2), new String[]{"nc0", "nc1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 2, 0),
new String[] {"0", "1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 0, 2),
new String[] {"nc0", "nc1"});

// All devices (mpi mode, max tensor parallel)
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 2, 0), new String[]{"0+1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 0, 2), new String[]{"nc0+nc1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 2, 0),
new String[] {"0+1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 0, 2),
new String[] {"nc0+nc1"});

// All devices (insufficient TP assignment)
Assert.assertThrows(() -> ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 1, 0));
Assert.assertThrows(() -> ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 1));

// All devices (uneven TP assignment)
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 3, 0), new String[]{"0+1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 3), new String[]{"nc0+nc1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 3, 0), new String[] {"0+1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 3),
new String[] {"nc0+nc1"});

// All devices (only CPU available)
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "Python", emptyProperties, 0, 0), new String[]{"-1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 0), new String[]{"-1"}); // Ignored MPI and TPI and falls back to CPU
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "Python", emptyProperties, 0, 0),
new String[] {"-1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 0),
new String[] {"-1"}); // Ignored MPI and TPI and falls back to CPU

// Ways to enable mpi mode
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "Python", mpiProperties, 2, 0), new String[]{"0+1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0), new String[]{"0+1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "DeepSpeed", emptyProperties, 2, 0), new String[]{"0+1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "FasterTransformer", emptyProperties, 2, 0), new String[]{"0+1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "Python", mpiProperties, 2, 0),
new String[] {"0+1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0),
new String[] {"0+1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "DeepSpeed", emptyProperties, 2, 0),
new String[] {"0+1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "FasterTransformer", emptyProperties, 2, 0),
new String[] {"0+1"});

// Ways to set tensor parallel
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0), new String[]{"0+1", "2+3"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0),
new String[] {"0+1", "2+3"});
System.setProperty("TENSOR_PARALLEL_DEGREE", "2");
Assert.assertEquals(ModelInfo.getLoadOnDevices("*", "MPI", mpiProperties, 4, 0), new String[]{"0+1", "2+3"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("*", "MPI", mpiProperties, 4, 0),
new String[] {"0+1", "2+3"});
System.clearProperty("TENSOR_PARALLEL_DEGREE");

// Providing devices
Assert.assertEquals(ModelInfo.getLoadOnDevices("1", "", emptyProperties, 10, 10), new String[]{"1"});
Assert.assertEquals(ModelInfo.getLoadOnDevices("1;2", "", emptyProperties, 10, 10), new String[]{"1", "2"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("1", "", emptyProperties, 10, 10), new String[] {"1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("1;2", "", emptyProperties, 10, 10),
new String[] {"1", "2"});

// Empty case
Assert.assertEquals(ModelInfo.getLoadOnDevices("", "", emptyProperties, 10, 10), new String[]{"-1"});
Assert.assertEquals(
ModelInfo.getLoadOnDevices("", "", emptyProperties, 10, 10), new String[] {"-1"});
}

@Test
Expand Down

0 comments on commit 093fef8

Please sign in to comment.