Skip to content

Commit

Permalink
Adds tests for all load defaults cases
Browse files Browse the repository at this point in the history
  • Loading branch information
zachgk committed Oct 26, 2023
1 parent e8c86ae commit e0e8206
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 6 deletions.
17 changes: 11 additions & 6 deletions wlm/src/main/java/ai/djl/serving/wlm/ModelInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,13 @@ void checkAvailableMemory(Device device) throws IOException {
/**
 * {@inheritDoc}
 *
 * <p>Looks up the engine named by {@code engineName}, reads its GPU count and the neuron core
 * count reported by {@code NeuronUtils}, and delegates to the static
 * {@code getLoadOnDevices(String, String, Properties, int, int)} overload, which takes those
 * counts as plain arguments.
 */
@Override
public String[] getLoadOnDevices() {
Engine eng = Engine.getEngine(engineName);
// Device counts are resolved here and passed down as plain ints.
int gpuCount = eng.getGpuCount();
int neurons = NeuronUtils.getNeuronCores();
return getLoadOnDevices(loadOnDevices, engineName, prop, gpuCount, neurons);
}

static String[] getLoadOnDevices(String loadOnDevices, String engineName, Properties prop, int gpuCount, int neurons) {
if ("*".equals(loadOnDevices)) {
int gpuCount = eng.getGpuCount();

boolean mpiMode;
if (prop.containsKey("option.mpi_mode")) {
Expand All @@ -825,11 +830,12 @@ public String[] getLoadOnDevices() {
mpiMode = false;
}

String tpDegreeProp = Utils.getEnvOrSystemProperty("TENSOR_PARALLEL_DEGREE");
String tpDegreeString;
if (prop.containsKey("option.tensor_parallel_degree")) {
tpDegreeString = prop.getProperty("option.tensor_parallel_degree");
} else if (Utils.getenv().containsKey("TENSOR_PARALLEL_DEGREE")) {
tpDegreeString = Utils.getenv("TENSOR_PARALLEL_DEGREE");
} else if (tpDegreeProp != null) {
tpDegreeString = tpDegreeProp;
} else if (mpiMode && gpuCount > 0) {
tpDegreeString = "max";
logger.info(
Expand All @@ -843,7 +849,7 @@ public String[] getLoadOnDevices() {
if (gpuCount > 0) {
tpDegree = gpuCount;
} else {
tpDegree = NeuronUtils.getNeuronCores();
tpDegree = neurons;
}
} else {
tpDegree = Integer.parseInt(tpDegreeString);
Expand Down Expand Up @@ -877,8 +883,7 @@ public String[] getLoadOnDevices() {
ret[i] = String.valueOf(i * gpuPerWorker);
}
return ret;
} else if (NeuronUtils.hasNeuron()) {
int neurons = NeuronUtils.getNeuronCores();
} else if (neurons > 0) {
int ncPerWorker;
if (tpDegree > 0) {
// Assume user understand TP only works on inf2
Expand Down
64 changes: 64 additions & 0 deletions wlm/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,70 @@ public void testOutOfMemory() throws IOException, ModelException {
}
}

/**
 * Exercises the static {@code ModelInfo.getLoadOnDevices} overload across the combinations of
 * {@code loadOnDevices} value, engine name, model properties, GPU count and neuron core count.
 *
 * <p>GPU and neuron core counts are passed in explicitly, so each case states the hardware it
 * simulates. The one case that relies on the {@code TENSOR_PARALLEL_DEGREE} system property
 * restores the previous value in a {@code finally} block so a failing assertion cannot leak the
 * setting into other tests running in the same JVM.
 *
 * <p>NOTE(review): the "default tensor parallel" cases presumably also require that
 * {@code TENSOR_PARALLEL_DEGREE} is not set in the process environment - confirm for CI.
 */
@Test
public void testLoadOnDevices() {
    Properties emptyProperties = new Properties();
    Properties mpiProperties = new Properties();
    mpiProperties.put("option.mpi_mode", "true");
    Properties tp0Properties = new Properties();
    tp0Properties.put("option.tensor_parallel_degree", "0");
    Properties tp2Properties = new Properties();
    tp2Properties.put("option.tensor_parallel_degree", "2");
    Properties tpMaxProperties = new Properties();
    tpMaxProperties.put("option.tensor_parallel_degree", "max");

    // All devices (no tensor parallel)
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 2, 0),
            new String[] {"0", "1"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "PyTorch", emptyProperties, 0, 2),
            new String[] {"nc0", "nc1"});

    // All devices (mpi mode, default tensor parallel)
    // GPU MPI defaults to TP=max
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0),
            new String[] {"0+1"});
    // Neuron MPI defaults to TP=1
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 0, 2),
            new String[] {"nc0", "nc1"});

    // All devices (mpi mode, explicit tensor parallel)
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0),
            new String[] {"0+1", "2+3"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 4),
            new String[] {"nc0+nc1", "nc2+nc3"});

    // All devices (mpi mode, 0 tensor parallel)
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 2, 0),
            new String[] {"0", "1"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tp0Properties, 0, 2),
            new String[] {"nc0", "nc1"});

    // All devices (mpi mode, max tensor parallel)
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 2, 0),
            new String[] {"0+1"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tpMaxProperties, 0, 2),
            new String[] {"nc0+nc1"});

    // All devices (insufficient TP assignment)
    Assert.assertThrows(() -> ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 1, 0));
    Assert.assertThrows(() -> ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 1));

    // All devices (uneven TP assignment)
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 3, 0),
            new String[] {"0+1"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 3),
            new String[] {"nc0+nc1"});

    // All devices (only CPU available)
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "Python", emptyProperties, 0, 0),
            new String[] {"-1"});
    // Ignores MPI and TP and falls back to CPU
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 0, 0),
            new String[] {"-1"});

    // Ways to enable mpi mode
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "Python", mpiProperties, 2, 0),
            new String[] {"0+1"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", emptyProperties, 2, 0),
            new String[] {"0+1"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "DeepSpeed", emptyProperties, 2, 0),
            new String[] {"0+1"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "FasterTransformer", emptyProperties, 2, 0),
            new String[] {"0+1"});

    // Ways to set tensor parallel
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("*", "MPI", tp2Properties, 4, 0),
            new String[] {"0+1", "2+3"});
    // The system property is JVM-global state: remember the prior value and always restore
    // it, even when the assertion fails, so other tests are not affected.
    String tpKey = "TENSOR_PARALLEL_DEGREE";
    String previousTpDegree = System.getProperty(tpKey);
    System.setProperty(tpKey, "2");
    try {
        Assert.assertEquals(
                ModelInfo.getLoadOnDevices("*", "MPI", mpiProperties, 4, 0),
                new String[] {"0+1", "2+3"});
    } finally {
        if (previousTpDegree == null) {
            System.clearProperty(tpKey);
        } else {
            System.setProperty(tpKey, previousTpDegree);
        }
    }

    // Providing devices
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("1", "", emptyProperties, 10, 10),
            new String[] {"1"});
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("1;2", "", emptyProperties, 10, 10),
            new String[] {"1", "2"});

    // Empty case
    Assert.assertEquals(
            ModelInfo.getLoadOnDevices("", "", emptyProperties, 10, 10),
            new String[] {"-1"});
}

@Test
public void testInitModel() throws IOException, ModelException {
Path modelStore = Paths.get("build/models");
Expand Down

0 comments on commit e0e8206

Please sign in to comment.