Skip to content

Commit

Permalink
Add RoPE — enable support via a command-line flag
Browse files Browse the repository at this point in the history
  • Loading branch information
gordicaleksa committed Jul 28, 2024
1 parent 362c6a8 commit 242566b
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ typedef struct {
// todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
int* workload_indices; // encoder_backward, B*T*num_c_groups (int)
int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case
int use_rope; // use rope position encoding
} GPT2;

void gpt2_init_common(GPT2 *model) {
Expand Down Expand Up @@ -347,6 +348,8 @@ void gpt2_init_common(GPT2 *model) {
model->init_state = true;
model->recompute = 1; // good default: recompute gelu but not layernorm
model->gelu_fusion = 0; //deviceProp.major >= 9 ? 2 : 0; // default: off for now (default must match main())
// architecture specific settings
model->use_rope = 0; // use rope position encoding
}

void gpt2_allocate_weights(GPT2 *model) {
Expand Down Expand Up @@ -1364,6 +1367,8 @@ void error_usage() {
// memory management
fprintf(stderr, " -z <int> zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\n");
fprintf(stderr, " -r <int> recompute: less memory but less speed. (default = 1), 0|1|2 = none,gelu,gelu+ln\n");
// architectural settings
fprintf(stderr, " -er <int> enable RoPE positional embeddings? (default = 0)\n");
// multi-node settings
fprintf(stderr, " -pn <int> num_processes (default = 1)\n");
fprintf(stderr, " -pr <int> process_rank (default = 0)\n");
Expand Down Expand Up @@ -1408,6 +1413,8 @@ int main(int argc, char *argv[]) {
int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
int hellaswag_eval = 0;
// architectural settings
int use_rope = 0; // use RoPE positional embeddings
// multi-node settings
int num_processes = 1; // this should be set by the slurm environment
int process_rank = 0; // this should be set by the slurm environment
Expand All @@ -1422,7 +1429,7 @@ int main(int argc, char *argv[]) {
// read in the args
if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }
else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }
else if (argv[i][1] == 'e') { load_filename = argv[i+1]; }
else if (argv[i][1] == 'e' && argv[i][2] == '\0') { load_filename = argv[i+1]; }
else if (argv[i][1] == 'o') { output_log_dir = argv[i+1]; }
else if (argv[i][1] == 'n' && argv[i][2] == '\0') { checkpoint_every = atoi(argv[i+1]); }
else if (argv[i][1] == 'y') { resume = atoi(argv[i+1]); }
Expand Down Expand Up @@ -1456,6 +1463,7 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 's' && argv[i][2] == 'g') { skip_update_gradz = atof(argv[i+1]); }
else if (argv[i][1] == 'n' && argv[i][2] == 'k') { checkpoints_keep = atoi(argv[i+1]); }
else if (argv[i][1] == 'n' && argv[i][2] == 'm') { major_checkpoint_every = atoi(argv[i+1]); }
else if (argv[i][1] == 'e' && argv[i][2] == 'r') { use_rope = atoi(argv[i+1]); }
else { error_usage(); }
}

Expand Down Expand Up @@ -1530,6 +1538,8 @@ int main(int argc, char *argv[]) {
// build the GPT-2 model
GPT2 model;
gpt2_init_common(&model);
// architectural modifications
model.use_rope = use_rope;
if (resuming == 1) {
// if `-y 1` was set, then we are resuming from the latest checkpoint
gpt2_build_from_checkpoint(&model, filename_buffer);
Expand Down

0 comments on commit 242566b

Please sign in to comment.