Add option to specify device when in single GPU setup
gordicaleksa committed Jul 1, 2024
1 parent 61625d1 commit 34504d5
Showing 4 changed files with 14 additions and 7 deletions.
6 changes: 3 additions & 3 deletions llmc/zero.cuh
@@ -407,7 +407,7 @@ int multi_gpu_get_local_device_idx(int process_rank, int num_processes) {

#endif

-MultiGpuConfig multi_gpu_config_init(int num_processes, int process_rank, int gpus_per_node, char* server_ip, char* fs_path, char* init_method) {
+MultiGpuConfig multi_gpu_config_init(int device, int num_processes, int process_rank, int gpus_per_node, char* server_ip, char* fs_path, char* init_method) {
#ifdef MULTI_GPU
MultiGpuConfig result;
ncclUniqueId nccl_id;
@@ -455,11 +455,11 @@ MultiGpuConfig multi_gpu_config_init(int num_processes, int process_rank, int gp
return result;
#else
printf("Multi-GPU support is disabled. Using a single GPU.\n");
-cudaCheck(cudaSetDevice(0));
+cudaCheck(cudaSetDevice(device));
MultiGpuConfig result;
result.process_rank = 0;
result.num_processes = 1;
-result.local_device_idx = 0;
+result.local_device_idx = device;
return result;
#endif
}
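
In the non-MULTI_GPU build, the requested device now flows straight into cudaSetDevice. Below is a minimal standalone sketch of that selection path, with a bounds check added for illustration; the check is an assumption, as the commit itself passes device through unvalidated inside cudaCheck.

#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

// Sketch only: pick the requested GPU, failing loudly on a bad index.
static void select_device(int device) {
    int count = 0;
    if (cudaGetDeviceCount(&count) != cudaSuccess || count == 0) {
        fprintf(stderr, "no CUDA devices visible\n");
        exit(EXIT_FAILURE);
    }
    if (device < 0 || device >= count) {
        fprintf(stderr, "device %d requested, but only %d device(s) present\n", device, count);
        exit(EXIT_FAILURE);
    }
    if (cudaSetDevice(device) != cudaSuccess) { // the commit wraps this call in cudaCheck
        fprintf(stderr, "cudaSetDevice(%d) failed\n", device);
        exit(EXIT_FAILURE);
    }
}

int main(void) {
    select_device(0); // 0 is the default used by profile_gpt2.cu and test_gpt2.cu below
    return 0;
}
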
3 changes: 2 additions & 1 deletion profile_gpt2.cu
@@ -34,7 +34,8 @@ int main(int argc, char *argv[]) {
int gpus_per_node = -1; // doesn't matter when using MPI
char server_ip[256] = ""; // doesn't matter when using MPI
char fs_path[256] = ""; // doesn't matter when using MPI
-multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);
+int device = 0;
+multi_gpu_config = multi_gpu_config_init(device, num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);
common_start(true, true);

// build the GPT-2 model from a checkpoint
3 changes: 2 additions & 1 deletion test_gpt2.cu
@@ -95,7 +95,8 @@ int main(int argc, char *argv[]) {
int gpus_per_node = -1; // doesn't matter when using MPI
char server_ip[256] = ""; // doesn't matter when using MPI
char fs_path[256] = ""; // doesn't matter when using MPI
-multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);
+int device = 0;
+multi_gpu_config = multi_gpu_config_init(device, num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);
common_start(false, true);

// set the right paths
9 changes: 7 additions & 2 deletions train_gpt2.cu
@@ -1389,6 +1389,8 @@ void error_usage() {
// memory management
fprintf(stderr, " -z <int> zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\n");
fprintf(stderr, " -r <int> recompute: less memory but less speed. (default = 1), 0|1|2 = none,gelu,gelu+ln\n");
+// single-node settings
+fprintf(stderr, " -dn <int> which GPU to use (default = 0)\n");
// multi-node settings
fprintf(stderr, " -pn <int> num_processes (default = 1)\n");
fprintf(stderr, " -pr <int> process_rank (default = 0)\n");
@@ -1432,6 +1434,8 @@ int main(int argc, char *argv[]) {
int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
int hellaswag_eval = 0;
+// single-node settings
+int device = 0; // which GPU to use
// multi-node settings
int num_processes = 1; // this should be set by the slurm environment
int process_rank = 0; // this should be set by the slurm environment
@@ -1452,7 +1456,8 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'y') { resume = atoi(argv[i+1]); }
else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size
else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }
-else if (argv[i][1] == 'd') { total_batch_size = atoi(argv[i+1]); }
+else if (argv[i][1] == 'd' && argv[i][2] == '\0') { total_batch_size = atoi(argv[i+1]); }
+else if (argv[i][1] == 'd' && argv[i][2] == 'n') { device = atoi(argv[i+1]); }
else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); }
else if (argv[i][1] == 'u') { warmup_iterations = atoi(argv[i+1]); }
else if (argv[i][1] == 'q') { final_learning_rate_frac = atof(argv[i+1]); }
@@ -1481,7 +1486,7 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'n' && argv[i][2] == 'm') { major_checkpoint_every = atoi(argv[i+1]); }
else { error_usage(); }
}
-multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);
+multi_gpu_config = multi_gpu_config_init(device, num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method);

// should do a bit more error checking here
assert(warmup_iterations >= 0);
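
The parsing change above is the subtle half of this diff: -d (total batch size) and the new -dn (device index) share a first letter, so the extra check on argv[i][2] is what keeps -d from also matching -dn. A standalone toy parser demonstrating the same disambiguation (illustrative only, not llm.c code):

#include <stdio.h>
#include <stdlib.h>

int main(int argc, char *argv[]) {
    int total_batch_size = 0;
    int device = 0; // default GPU, matching the commit's default
    // flags arrive as (-x value) pairs, mirroring the loop in train_gpt2.cu
    for (int i = 1; i + 1 < argc; i += 2) {
        if (argv[i][0] != '-') { fprintf(stderr, "expected a flag\n"); return 1; }
        if (argv[i][1] == 'd' && argv[i][2] == '\0') { total_batch_size = atoi(argv[i+1]); } // -d
        else if (argv[i][1] == 'd' && argv[i][2] == 'n') { device = atoi(argv[i+1]); }       // -dn
    }
    printf("total_batch_size=%d device=%d\n", total_batch_size, device);
    return 0;
}

Run as, e.g., ./a.out -d 524288 -dn 1 to select GPU 1 while -d keeps its old meaning; after this commit the training binary accepts the same pair of flags.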
