Add learning rate schedulers
gordicaleksa committed Jun 17, 2024
1 parent fa7f2b6 commit ff05329
Showing 2 changed files with 51 additions and 12 deletions.
43 changes: 43 additions & 0 deletions llmc/schedulers.h
@@ -0,0 +1,43 @@


// Cosine learning rate scheduler

#ifndef SCHEDULERS_H

#define SCHEDULERS_H

#include <assert.h>
#include <math.h>

typedef struct {
    float learning_rate;
    int warmup_iterations;
    int train_num_batches;
    float final_learning_rate_frac;
} CosineLearningRateScheduler;


// learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
float get_learning_rate(CosineLearningRateScheduler *scheduler, int step) {
    float step_learning_rate = scheduler->learning_rate;
    if (step < scheduler->warmup_iterations) {
        step_learning_rate = scheduler->learning_rate * ((float)(step + 1)) / scheduler->warmup_iterations;
    } else {
        float decay_ratio = ((float)(step - scheduler->warmup_iterations)) / (scheduler->train_num_batches - scheduler->warmup_iterations);
        assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
        float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
        assert(0.0f <= coeff && coeff <= 1.0f);
        float min_lr = scheduler->learning_rate * scheduler->final_learning_rate_frac;
        step_learning_rate = min_lr + coeff * (scheduler->learning_rate - min_lr);
    }
    return step_learning_rate;
}

void lr_scheduler_init(CosineLearningRateScheduler *scheduler, float learning_rate, int warmup_iterations, int train_num_batches, float final_learning_rate_frac) {
    scheduler->learning_rate = learning_rate;
    scheduler->warmup_iterations = warmup_iterations;
    scheduler->train_num_batches = train_num_batches;
    scheduler->final_learning_rate_frac = final_learning_rate_frac;
}

#endif // SCHEDULERS_H
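
In equation form, with eta_max the configured learning_rate, f = final_learning_rate_frac, W = warmup_iterations, T = train_num_batches, and t the current step, get_learning_rate above amounts to:

\eta(t) =
\begin{cases}
\eta_{\max} \cdot \dfrac{t + 1}{W}, & t < W, \\[1.2ex]
f\,\eta_{\max} + \dfrac{1 + \cos\!\left(\pi \cdot \frac{t - W}{T - W}\right)}{2}\,\bigl(\eta_{\max} - f\,\eta_{\max}\bigr), & W \le t \le T.
\end{cases}
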
20 changes: 8 additions & 12 deletions train_gpt2.cu
@@ -21,6 +21,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage.
#include "llmc/dataloader.h"
// defines: manual_seed, normal_ (same as torch.manual_seed and torch.normal)
#include "llmc/rand.h"
+// defines learning rate schedulers
+#include "llmc/schedulers.h"
// defines: sample_softmax, random_f32
#include "llmc/sampler.h"
// defines: logger_init, logger_log_eval, logger_log_val, logger_log_train
@@ -1506,6 +1508,10 @@ int main(int argc, char *argv[]) {
    Tokenizer tokenizer;
    tokenizer_init(&tokenizer, "gpt2_tokenizer.bin");

+   // set up learning rate scheduler
+   CosineLearningRateScheduler lr_scheduler;
+   lr_scheduler_init(&lr_scheduler, learning_rate, warmup_iterations, train_num_batches, final_learning_rate_frac);

    // some memory for generating samples from the model
    int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
    floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX));
@@ -1664,18 +1670,8 @@ int main(int argc, char *argv[]) {
        model.mean_loss = lossf;
        // average the loss and the gradients between all processes
        gpt2_multi_gpu_loss_reduce(&model, &multi_gpu_config);
-       // learning rate schedule: warmup linearly to max LR, then cosine decay to LR * final_learning_rate_frac
-       float step_learning_rate = learning_rate;
-       if (step < warmup_iterations) {
-           step_learning_rate = learning_rate * ((float)(step + 1)) / warmup_iterations;
-       } else {
-           float decay_ratio = ((float)(step - warmup_iterations)) / (train_num_batches - warmup_iterations);
-           assert(0.0f <= decay_ratio && decay_ratio <= 1.0f);
-           float coeff = 0.5f * (1.0f + cosf(M_PI * decay_ratio)); // coeff starts at 1 and goes to 0
-           assert(0.0f <= coeff && coeff <= 1.0f);
-           float min_lr = learning_rate * final_learning_rate_frac;
-           step_learning_rate = min_lr + coeff * (learning_rate - min_lr);
-       }
+       // fetch the next learning rate
+       float step_learning_rate = get_learning_rate(&lr_scheduler, step);
        // update the model parameters
        float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config);
        // zero out the gradients for the next iteration

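As a sanity check, a minimal standalone driver (hypothetical, not part of this commit) can print the schedule produced by get_learning_rate in isolation; the file name, compile command, and hyperparameter values below are illustrative assumptions only:

// sketch: compile from the repo root, e.g. gcc -O2 lr_demo.c -lm -o lr_demo
// (file name and hyperparameters are made up for illustration)
#include <stdio.h>
#include "llmc/schedulers.h"

int main(void) {
    CosineLearningRateScheduler lr_scheduler;
    // max LR 3e-4, 10 warmup steps, 100 total steps, decay to 10% of the max LR
    lr_scheduler_init(&lr_scheduler, 3e-4f, 10, 100, 0.1f);
    for (int step = 0; step <= 100; step += 10) {
        // linear warmup while step < 10, then cosine decay down to 3e-5 at step 100
        printf("step %3d: lr = %.6f\n", step, get_learning_rate(&lr_scheduler, step));
    }
    return 0;
}

This reproduces the schedule that the commit removes from the training loop in train_gpt2.cu, now accessed through the scheduler struct instead.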