Cleans tensorParallelDegree with MultiDevice
This is a refactor to simplify the handling of the tensor parallel degree. Previously,
it was read independently in three or more places in the code, and the behavior
determining the TP degree was hard to follow. This change moves the reading to a
single place and represents the result as a MultiDevice. It also makes the behavior
fully visible: a worker group now shows up with the tensor parallel devices that its
workers will use.
zachgk committed Dec 11, 2023
1 parent 6c412d6 commit f0f1ab3
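
The pattern the diff applies throughout, distilled into a standalone sketch. The helper name tensorParallelDegree is illustrative and not part of the commit; Device, MultiDevice, and getDevices() are used exactly as they appear in the diff below.

import ai.djl.Device;
import ai.djl.Device.MultiDevice;
import ai.djl.Model;

public final class TpDegreeSketch {

    // Illustrative helper, not code from the commit: the tensor parallel
    // degree is now implied by the model's device. A MultiDevice carries
    // every device a worker group will use, so its size is the TP degree;
    // a plain Device means no tensor parallelism.
    static int tensorParallelDegree(Model model) {
        Device device = model.getNDManager().getDevice();
        if (device instanceof MultiDevice) {
            return ((MultiDevice) device).getDevices().size();
        }
        return 0;
    }
}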
Showing 6 changed files with 256 additions and 139 deletions.
@@ -13,6 +13,7 @@
package ai.djl.python.engine;

import ai.djl.Device;
+import ai.djl.Device.MultiDevice;
import ai.djl.Model;
import ai.djl.engine.EngineException;
import ai.djl.inference.streaming.ChunkedBytesSupplier;
@@ -123,7 +124,11 @@ CompletableFuture<Output> send(Input input) throws InterruptedException {
static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int port) {
Device device = model.getNDManager().getDevice();
int deviceId = device.getDeviceId();
-int tensorParallelDegree = pyEnv.getTensorParallelDegree();
+int tensorParallelDegree = 0;
+if (model.getNDManager().getDevice() instanceof MultiDevice) {
+    tensorParallelDegree =
+            ((MultiDevice) model.getNDManager().getDevice()).getDevices().size();
+}
if (pyEnv.isMpiMode()) {
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices);
42 changes: 4 additions & 38 deletions engines/python/src/main/java/ai/djl/python/engine/PyEnv.java
@@ -12,9 +12,9 @@
*/
package ai.djl.python.engine;

+import ai.djl.Device.MultiDevice;
import ai.djl.Model;
import ai.djl.engine.EngineException;
-import ai.djl.util.NeuronUtils;
import ai.djl.util.Platform;
import ai.djl.util.Utils;
import ai.djl.util.cuda.CudaUtils;
@@ -52,7 +52,6 @@ public class PyEnv {
private String handler;
private int predictTimeout;
private int modelLoadingTimeout;
-private int tensorParallelDegree;
private Map<String, String> envs;
private Map<String, String> initParameters;
private boolean initialized;
@@ -302,41 +301,7 @@ public void setPythonExecutable(String pythonExecutable) {
this.pythonExecutable = pythonExecutable;
}

-/**
- * Returns the tensor parallel degree.
- *
- * @return the tensor parallel degree
- */
-public int getTensorParallelDegree() {
-    if (tensorParallelDegree == 0) {
-        String value = Utils.getenv("TENSOR_PARALLEL_DEGREE");
-        if ("max".equals(value)) {
-            tensorParallelDegree = getDefaultTensorParallelDegree();
-        } else if (value != null) {
-            tensorParallelDegree = Integer.parseInt(value);
-        }
-    }
-    return tensorParallelDegree;
-}
-
-static int getDefaultTensorParallelDegree() {
-    int gpus = CudaUtils.getGpuCount();
-    if (gpus > 0) {
-        return gpus;
-    }
-    return NeuronUtils.getNeuronCores();
-}
-
-/**
- * Sets the tensor parallel degree.
- *
- * @param tensorParallelDegree the tensor parallel degree
- */
-public void setTensorParallelDegree(int tensorParallelDegree) {
-    this.tensorParallelDegree = tensorParallelDegree;
-}
-
-int getMpiWorkers() {
+int getMpiWorkers(MultiDevice multiDevice) {
int gpuCount = CudaUtils.getGpuCount();
String visibleDevices = Utils.getenv("CUDA_VISIBLE_DEVICES");
if (gpuCount > 0 && visibleDevices != null) {
@@ -346,7 +311,8 @@ int getMpiWorkers() {
}
gpuCount = visibleCount;
}
-return gpuCount / getTensorParallelDegree();
+int tensorParallelDegree = multiDevice.getDevices().size();
+return gpuCount / tensorParallelDegree;
}

/**
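To make the new arithmetic concrete (numbers hypothetical, not from the commit): getMpiWorkers now divides the visible GPU count by the size of the MultiDevice rather than by a separately tracked degree.

// Hypothetical example: a host exposing 8 GPUs with a MultiDevice
// spanning 4 of them yields two MPI worker groups.
int gpuCount = 8;                                  // visible GPUs
int tensorParallelDegree = 4;                      // multiDevice.getDevices().size()
int mpiWorkers = gpuCount / tensorParallelDegree;  // 8 / 4 = 2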
56 changes: 26 additions & 30 deletions engines/python/src/main/java/ai/djl/python/engine/PyModel.java
@@ -14,14 +14,14 @@

import ai.djl.BaseModel;
import ai.djl.Device;
+import ai.djl.Device.MultiDevice;
import ai.djl.Model;
import ai.djl.engine.EngineException;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.translate.Translator;
import ai.djl.util.Utils;
-import ai.djl.util.cuda.CudaUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -127,13 +127,6 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
case "parallel_loading":
parallelLoading = Boolean.parseBoolean(value);
break;
case "tensor_parallel_degree":
if ("max".equals(value)) {
pyEnv.setTensorParallelDegree(PyEnv.getDefaultTensorParallelDegree());
} else {
pyEnv.setTensorParallelDegree(Integer.parseInt(value));
}
break;
case "handler":
pyEnv.setHandler(value);
break;
@@ -164,8 +157,13 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
entryPoint = modelFile.toFile().getName();
} else if ("DeepSpeed".equals(engineName)) {
entryPoint = "djl_python.deepspeed";
} else if ("nc".equals(manager.getDevice().getDeviceType())
&& pyEnv.getTensorParallelDegree() > 0) {
} else if (manager.getDevice() instanceof MultiDevice
&& "nc"
.equals(
((MultiDevice) manager.getDevice())
.getDevices()
.get(0)
.getDeviceType())) {
entryPoint = "djl_python.transformers_neuronx";
} else if (isTrtLlmBackend) {
entryPoint = "djl_python.tensorrt_llm";
@@ -200,17 +198,13 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
}

if (pyEnv.isMpiMode()) {
-int partitions = pyEnv.getTensorParallelDegree();
-if (partitions == 0) {
-    partitions = CudaUtils.getGpuCount();
-    pyEnv.setTensorParallelDegree(partitions);
-    setProperty("tensor_parallel_degree", String.valueOf(partitions));
-    logger.info(
-            "No tensor parallel degree specified. Defaulting to all available GPUs.");
-}
-logger.info("Loading model in MPI mode with TP: {}.", partitions);
+MultiDevice multiDevice = (MultiDevice) manager.getDevice();
+int partitions = multiDevice.getDevices().size();
+setProperty("tensor_parallel_degree", String.valueOf(partitions));
+logger.info(
+        "Loading model in MPI mode with TP: {}. Devices {}", partitions, multiDevice);

-int mpiWorkers = pyEnv.getMpiWorkers();
+int mpiWorkers = pyEnv.getMpiWorkers(multiDevice);
if (mpiWorkers <= 0) {
throw new EngineException(
"GPU devices are not enough to run " + partitions + " partitions.");
@@ -233,13 +227,14 @@ public void load(Path modelPath, String prefix, Map<String, ?> options) throws I
+ " but the value is set to "
+ getProperty("gpu.maxWorkers"));
}
-mpiWorkers = Integer.parseInt(getProperty("gpu.maxWorkers"));

properties.forEach((k, v) -> pyEnv.addParameter(k, v));

-createAllPyProcesses(mpiWorkers, partitions);
+createAllPyProcesses(multiDevice.getDevices(), partitions);
} else {
-int tensorParallelDegree = pyEnv.getTensorParallelDegree();
+int tensorParallelDegree = 0;
+if (manager.getDevice() instanceof MultiDevice) {
+    tensorParallelDegree = ((MultiDevice) manager.getDevice()).getDevices().size();
+}
if (tensorParallelDegree > 0) {
if (getProperty("maxWorkers") == null && getProperty("gpu.maxWorkers") == null) {
setProperty("gpu.minWorkers", "1");
@@ -300,21 +295,22 @@ private Path findModelFile(String prefix) {
return modelFile;
}

-private void createAllPyProcesses(int mpiWorkers, int tp) {
+private void createAllPyProcesses(List<Device> devices, int tp) {
+int mpiWorkers = devices.size();
long begin = System.currentTimeMillis();
ExecutorService pool = null;
List<Future<?>> futures = new ArrayList<>();
if (parallelLoading) {
pool = Executors.newFixedThreadPool(mpiWorkers);
}
logger.info("Start {} mpiWorkers ...", mpiWorkers);
-int deviceId = manager.getDevice().getDeviceId();
-for (int i = 0; i < mpiWorkers; ++i) {
-    logger.debug("Pre-creating python worker: {} ", i);
-    PyProcess worker = new PyProcess(this, pyEnv, deviceId + i * tp);
+for (Device device : devices) {
+    int deviceId = device.getDeviceId();
+    logger.debug("Pre-creating python worker: {} ", deviceId);
+    PyProcess worker = new PyProcess(this, pyEnv, deviceId * tp);
workerQueue.offer(worker);
if (pool != null) {
logger.debug("Submitting to pool: {}", i);
logger.debug("Submitting to pool: {}", deviceId);
futures.add(pool.submit(worker::startPythonProcess));
} else {
worker.startPythonProcess();
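One readable distillation of the entry-point change above, as a sketch with an assumed helper name (the commit inlines this logic): a model is routed to djl_python.transformers_neuronx when its device is a MultiDevice whose members are NeuronCores (device type "nc"), replacing the old check against pyEnv.getTensorParallelDegree().

// Assumed helper for illustration only, not code from the commit.
static boolean isNeuronMultiDevice(Device device) {
    return device instanceof MultiDevice
            && "nc".equals(((MultiDevice) device).getDevices().get(0).getDeviceType());
}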
10 changes: 6 additions & 4 deletions engines/python/src/main/java/ai/djl/python/engine/PyProcess.java
@@ -12,6 +12,8 @@
*/
package ai.djl.python.engine;

+import ai.djl.Device;
+import ai.djl.Device.MultiDevice;
import ai.djl.Model;
import ai.djl.engine.EngineException;
import ai.djl.metric.Metric;
@@ -59,12 +61,12 @@ class PyProcess {
this.workerId = workerId;
int port = counter.getAndIncrement();
if (pyEnv.isMpiMode()) {
-int tensorParallelDegree = pyEnv.getTensorParallelDegree();
-connections = new ArrayList<>(tensorParallelDegree);
-for (int i = 0; i < tensorParallelDegree; ++i) {
+List<Device> devices = ((MultiDevice) model.getNDManager().getDevice()).getDevices();
+connections = new ArrayList<>(devices.size());
+for (int i = 0; i < devices.size(); ++i) {
connections.add(new Connection(pyEnv, port, i));
}
-counter.set(port + tensorParallelDegree);
+counter.set(port + devices.size());
} else {
connections = Collections.singletonList(new Connection(pyEnv, port, -1));
}
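A hypothetical walk-through of the MPI-mode bookkeeping above (the numbers are invented for illustration): each tensor parallel rank gets a Connection on the same base port, and the shared counter then skips past the block this process claimed.

// Counter currently at 19000; MultiDevice with four devices.
int port = 19000;  // from counter.getAndIncrement()
int tpDegree = 4;  // devices.size()
for (int rank = 0; rank < tpDegree; ++rank) {
    System.out.println("Connection on base port " + port + ", rank " + rank);
}
// counter.set(port + tpDegree): the next PyProcess starts at 19004.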
