-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
train_gpt2.cu
1904 lines (1751 loc) · 102 KB
/
train_gpt2.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
GPT-2 Transformer Neural Net training loop. See README.md for usage.
*/
#include <unistd.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string>
#include <string_view>
#include <sys/stat.h>
#include <sys/types.h>
// ----------- CPU utilities -----------
// defines: fopenCheck, freadCheck, fcloseCheck, fseekCheck, mallocCheck
// defines: create_dir_if_not_exists, find_max_step, ends_with_bin
#include "llmc/utils.h"
// defines: tokenizer_init, tokenizer_decode, tokenizer_free
#include "llmc/tokenizer.h"
// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free
// defines: evalloader_init, evalloader_reset, evalloader_next_batch, evalloader_free
#include "llmc/dataloader.h"
// defines: manual_seed, normal_ (same as torch.manual_seed and torch.normal)
#include "llmc/rand.h"
// defines: lr_scheduler_init, get_learning_rate
#include "llmc/schedulers.h"
// defines: sample_softmax, random_f32
#include "llmc/sampler.h"
// defines: logger_init, logger_log_eval, logger_log_val, logger_log_train
#include "llmc/logger.h"
// defines: get_flops_promised
#include "llmc/mfu.h"
// defines: OutlierDetector, init_detector, update_detector
#include "llmc/outlier_detector.h"
// ----------- GPU utilities -----------
// defines:
// WARP_SIZE, MAX_1024_THREADS_BLOCKS, CEIL_DIV, cudaCheck, PRECISION_MODE
// NVTX_RANGE_FN
#include "llmc/cuda_common.h"
// defines:
// Packed128, f128, x128
// warpReduceSum, warpReduceMax, blockReduce, copy_and_cast_kernel, cudaMallocConditionallyManaged
#include "llmc/cuda_utils.cuh"
// defines: CUBLAS_LOWP, cublasCheck, cublaslt_workspace_size, cublaslt_workspace
// defines: cublas_compute, cublaslt_handle, cublas_handle
#include "llmc/cublas_common.h"
// ----------- Layer implementations in CUDA -----------
// defines: encoder_forward, encoder_backward
#include "llmc/encoder.cuh"
// defines: layernorm_forward, residual_forward, fused_residual_forward5, layernorm_backward
#include "llmc/layernorm.cuh"
// defines: matmul_cublaslt, matmul_forward, matmul_backward, gelu_forward, gelu_backward_inplace
#include "llmc/matmul.cuh"
#ifdef ENABLE_CUDNN
// defines: create_cudnn, destroy_cudnn, attention_forward_cudnn, attention_backward_cudnn
#include "llmc/cudnn_att.h"
#else
// defines: attention_forward, attention_backward
#include "llmc/attention.cuh"
#endif
// defines: fused_classifier
#include "llmc/fused_classifier.cuh"
// defines: adamw_kernel3
#include "llmc/adamw.cuh"
// defines: global_norm_squared
#include "llmc/global_norm.cuh"
// ----------- Multi-GPU support -----------
// defines: ncclFloatX, ncclCheck, MultiGpuConfig, ShardInfo
// defines: printf0, multi_gpu_config
// defines: multi_gpu_config_init, multi_gpu_config_free
// defines: set_zero_configs, multi_gpu_cpu_float_sum, multi_gpu_barrier
// defines: multi_gpu_get_shard_offset, multi_gpu_async_reduce_gradient
#include "llmc/zero.cuh"
// ----------------------------------------------------------------------------
// global vars for I/O
// NOTE(review): not referenced in this part of the file; presumably a shared scratch
// buffer for formatting file paths elsewhere -- confirm against its users.
char filename_buffer[512];
// ----------------------------------------------------------------------------
// global vars containing information about the GPU this process is running on
cudaDeviceProp deviceProp; // filled in by common_start() (defined elsewhere)
// stream on which this file's device work is issued; it is passed to
// device_to_file / file_to_device and to the layer kernels (e.g. encoder_forward)
cudaStream_t main_stream;
// buffer size to use for device <-> disk io (32 MiB); passed as the chunk size to
// device_to_file / file_to_device when writing and reading checkpoints
constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024;
// ----------------------------------------------------------------------------
// GPT-2 model definition
// Architecture hyperparameters of the model. Populated either from a checkpoint
// header (gpt2_build_from_checkpoint) or from a model descriptor string
// (gpt2_set_hyperparameters / gpt3_set_hyperparameters + gpt_build_from_descriptor).
typedef struct {
    int max_seq_len; // max sequence length, e.g. 1024 (sizes the wpe table)
    int vocab_size; // vocab size, e.g. 50257 (valid token ids are [0, vocab_size))
    int padded_vocab_size; // padded to e.g. %128==0, 50304; used for wte and logits sizing
    int num_layers; // number of layers, e.g. 12
    int num_heads; // number of heads in attention, e.g. 12
    int channels; // number of channels, e.g. 768
} GPT2Config;
// the parameters of the model
// number of distinct parameter tensors; must equal the number of fields below
constexpr const int NUM_PARAMETER_TENSORS = 16;
// Pointers to every parameter tensor of the model. All of them point into one
// contiguous allocation made by malloc_and_point_parameters(), laid out in the
// field order below. The same layout is used for the weights (GPT2::params) and
// their gradients (GPT2::grads).
// Shape legend: Vp = padded vocab, maxT = max sequence length, L = layers, C = channels.
// Per-layer tensors are stacked along a leading L dimension.
typedef struct {
    floatX* wte; // (Vp, C) token embeddings; only the first vocab_size rows are randomly initialized
    floatX* wpe; // (maxT, C) position embeddings
    floatX* ln1w; // (L, C)
    floatX* ln1b; // (L, C)
    floatX* qkvw; // (L, 3*C, C)
    floatX* qkvb; // (L, 3*C)
    floatX* attprojw; // (L, C, C)
    floatX* attprojb; // (L, C)
    floatX* ln2w; // (L, C)
    floatX* ln2b; // (L, C)
    floatX* fcw; // (L, 4*C, C)
    floatX* fcb; // (L, 4*C)
    floatX* fcprojw; // (L, C, 4*C)
    floatX* fcprojb; // (L, C)
    floatX* lnfw; // (C) final layernorm weight
    floatX* lnfb; // (C) final layernorm bias
} ParameterTensors;
// the struct must consist of exactly NUM_PARAMETER_TENSORS pointers, because code
// iterates over it as an array of that many tensors
static_assert(sizeof(ParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!");
// Compute the element count and bytes-per-element of each of the
// NUM_PARAMETER_TENSORS parameter tensors, in ParameterTensors field order.
//   param_sizes  - out: number of elements of each tensor
//   param_sizeof - out: bytes per element of each tensor (sizeof(floatX) for all today,
//                  kept per-tensor so mixed dtypes can be introduced later)
//   config       - model hyperparameters (padded vocab, max seq len, layers, channels)
void fill_in_parameter_sizes(size_t* param_sizes, size_t* param_sizeof, GPT2Config config) {
    const size_t C = config.channels;
    const size_t L = config.num_layers;
    const size_t layer_vec = L * C;   // one C-vector per layer
    const size_t qkv_width = 3 * C;
    const size_t mlp_width = 4 * C;

    // embedding tables
    param_sizes[0] = (size_t)config.padded_vocab_size * C; // wte      (Vp, C)
    param_sizes[1] = (size_t)config.max_seq_len * C;       // wpe      (maxT, C)
    // transformer blocks, stacked over the L layers
    param_sizes[2]  = layer_vec;             // ln1w     (L, C)
    param_sizes[3]  = layer_vec;             // ln1b     (L, C)
    param_sizes[4]  = L * qkv_width * C;     // qkvw     (L, 3C, C)
    param_sizes[5]  = L * qkv_width;         // qkvb     (L, 3C)
    param_sizes[6]  = L * C * C;             // attprojw (L, C, C)
    param_sizes[7]  = layer_vec;             // attprojb (L, C)
    param_sizes[8]  = layer_vec;             // ln2w     (L, C)
    param_sizes[9]  = layer_vec;             // ln2b     (L, C)
    param_sizes[10] = L * mlp_width * C;     // fcw      (L, 4C, C)
    param_sizes[11] = L * mlp_width;         // fcb      (L, 4C)
    param_sizes[12] = L * C * mlp_width;     // fcprojw  (L, C, 4C)
    param_sizes[13] = layer_vec;             // fcprojb  (L, C)
    // final layernorm
    param_sizes[14] = C;                     // lnfw     (C)
    param_sizes[15] = C;                     // lnfb     (C)

    // every tensor currently stores floatX elements
    for (size_t* slot = param_sizeof; slot != param_sizeof + NUM_PARAMETER_TENSORS; ++slot) {
        *slot = sizeof(floatX);
    }
}
// Allocate one contiguous device buffer for all parameter tensors and point each
// field of `params` at its slice of that buffer.
//   params         - out: every tensor pointer is set into the new allocation
//   param_elements - element count of each of the NUM_PARAMETER_TENSORS tensors
//   param_sizeof   - bytes per element of each tensor
// Tensors are laid out back-to-back in ParameterTensors field order, so the whole
// set can be copied to/from disk or host as a single buffer. Returns the base
// pointer of the device allocation; the caller owns it.
void* malloc_and_point_parameters(ParameterTensors* params, size_t* param_elements, size_t *param_sizeof) {
    // calculate the total number of bytes across all tensors
    size_t num_parameters_bytes = 0;
    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        num_parameters_bytes += param_elements[i] * param_sizeof[i];
    }
    // malloc all parameters all at once on the device
    void* params_memory;
    cudaCheck(cudaMalloc((void**)&params_memory, num_parameters_bytes));
    // addresses of the tensor fields, in the same order as param_elements/param_sizeof
    floatX** ptrs[] = {
        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
    };
    static_assert(sizeof(ptrs) / sizeof(ptrs[0]) == NUM_PARAMETER_TENSORS,
                  "pointer table must list every parameter tensor");
    // assign all the tensors their place in the array
    char* params_memory_iterator = (char*)params_memory;
    for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        *(ptrs[i]) = (floatX*)params_memory_iterator;
        params_memory_iterator += param_elements[i] * param_sizeof[i];
    }
    return params_memory;
}
// number of distinct activation tensors; must equal the number of fields below
constexpr int NUM_ACTIVATION_TENSORS = 21;
// Pointers to every activation / scratch tensor of the model. All of them point into
// one allocation made by malloc_and_point_activations() from the sizes produced by
// fill_in_activation_sizes(); a tensor given size 0 there is left as NULL.
// Shape legend: L = layers, B = batch, T = sequence length, C = channels, NH = heads.
typedef struct {
    floatX* encoded; // (B, T, C)
    floatX* ln1; // (L, B, T, C); NULL when layernorm is recomputed (recompute >= 2)
    float* ln1_mean; // (L, B, T)
    float* ln1_rstd; // (L, B, T)
    floatX* atty; // (L, B, T, C)
    // cuDNN saves only some statistics information.
    // Use #ifdef (not #if) so this agrees with the cuDNN include guard and with
    // fill_in_activation_sizes(): `#if` picks the other branch when ENABLE_CUDNN is
    // defined as 0, and fails to compile when it is defined with an empty value.
#ifdef ENABLE_CUDNN
    float* att; // (L, B, NH, T)
#else
    floatX* att; // (L, B, NH, T, T)
#endif
    floatX* residual2; // (L, B, T, C)
    floatX* ln2; // (L, B, T, C); NULL when layernorm is recomputed (recompute >= 2)
    float* ln2_mean; // (L, B, T)
    float* ln2_rstd; // (L, B, T)
    floatX* fch; // (L, B, T, 4*C)
    floatX* fch_gelu; // (L, B, T, 4*C), or a single-layer (B, T, 4*C) scratch when recompute >= 1
    floatX* residual3; // (L, B, T, C)
    floatX* lnf; // (B, T, C); if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms
    float* lnf_mean; // (B, T)
    float* lnf_rstd; // (B, T)
    float* losses; // (B, T), will be accumulated in micro-steps
    // adding these two compared to the CPU .c code, needed for attention kernel as buffers
    floatX* qkvr; // (L, B, T, 3*C)
    // in inference mode, this buffer will store the logits
    // in training mode, this buffer will contain the *gradients* of the logits.
    // during the processing of transformer blocks, we will also use this as a
    // general scratchpad buffer. Allocation is made large enough to hold (B, T, 3C),
    // (B, NH, T, T), and (B, T, V) shaped tensors.
    floatX* output;
    // some additional scratch buffers
    floatX* scratch_bt4c; // (B, T, 4*C)
    floatX* scratch_btc; // (B, T, C)
} ActivationTensors;
// the struct must consist of exactly NUM_ACTIVATION_TENSORS pointers, matching the
// TensorSpec array filled in by fill_in_activation_sizes()
static_assert(sizeof(ActivationTensors) == NUM_ACTIVATION_TENSORS * sizeof(void*), "Inconsistent sizes!");
// Describes one activation tensor for malloc_and_point_activations():
//   ptr  - address of the pointer field that will be pointed into the allocation
//   size - number of elements; 0 means "do not allocate" and the field is set to NULL
//   type - element dtype, converted to bytes with sizeof_dtype()
struct TensorSpec {
    void** ptr;
    size_t size;
    DType type;
};
// Build a TensorSpec from a pointer field and its element count; the dtype is deduced
// from the field's pointer type. Expands to a plain expression (no trailing semicolon)
// so it can be used anywhere a TensorSpec value is expected, and the field argument is
// parenthesized so compound expressions have their address taken correctly.
#define TENSOR_SPEC(pointer, size) TensorSpec{(void**)(&(pointer)), (size), dtype_of(pointer)}
// Fill in the element count and dtype of every activation tensor needed for batch
// size B and sequence length T. `data` only supplies the addresses of the
// ActivationTensors fields that malloc_and_point_activations() will later set.
//   recompute - 0: store every activation;
//               >= 1: GeLU is recomputed in backward, so fch_gelu is a single-layer
//                     (B, T, 4C) scratch buffer instead of (L, B, T, 4C);
//               >= 2: layernorm outputs are also recomputed (into lnf), so ln1/ln2
//                     get size 0 and are left NULL.
void fill_in_activation_sizes(const ActivationTensors* data, TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS], size_t B, size_t T, GPT2Config config, int recompute) {
    size_t Vp = config.padded_vocab_size;
    size_t L = config.num_layers;
    size_t NH = config.num_heads;
    size_t C = config.channels;
    tensors[0] = TENSOR_SPEC(data->encoded, B * T * C);
    // ln1 is only stored when recompute < 2; at recompute >= 2 the layernorm forward
    // activation is recomputed during the backward pass instead
    tensors[1] = TENSOR_SPEC(data->ln1, (recompute < 2) ? L * B * T * C : 0);
    tensors[2] = TENSOR_SPEC(data->ln1_mean, L * B * T);
    tensors[3] = TENSOR_SPEC(data->ln1_rstd, L * B * T);
    tensors[4] = TENSOR_SPEC(data->atty, L * B * T * C);
#ifdef ENABLE_CUDNN
    // FP32 stats tensor for cuDNN to be passed to backward pass
    tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T);
#else
    // full (T, T) attention matrix per head
    tensors[5] = TENSOR_SPEC(data->att, L * B * NH * T * T);
#endif
    tensors[6] = TENSOR_SPEC(data->residual2, L * B * T * C);
    // ln2 is only stored when recompute < 2 (see ln1 above)
    tensors[7] = TENSOR_SPEC(data->ln2, (recompute < 2) ? L * B * T * C : 0);
    tensors[8] = TENSOR_SPEC(data->ln2_mean, L * B * T);
    tensors[9] = TENSOR_SPEC(data->ln2_rstd, L * B * T);
    tensors[10] = TENSOR_SPEC(data->fch, L * B * T * 4*C);
    // if recompute >= 1 then we will recompute gelu_forward during backward and use this as scratch buffer
    tensors[11] = TENSOR_SPEC(data->fch_gelu, (recompute < 1) ? L * B * T * 4*C : B * T * 4*C);
    tensors[12] = TENSOR_SPEC(data->residual3, L * B * T * C);
    tensors[13] = TENSOR_SPEC(data->lnf, B * T * C);
    tensors[14] = TENSOR_SPEC(data->lnf_mean, B * T);
    tensors[15] = TENSOR_SPEC(data->lnf_rstd, B * T);
    tensors[16] = TENSOR_SPEC(data->losses, B * T);
    tensors[17] = TENSOR_SPEC(data->qkvr, L * B * T * 3*C);
    // shared scratch: per (b, t) it must hold 3C (qkv), NH*T (one attention row per
    // head, i.e. (B, NH, T, T) in total) or Vp (logits / logit gradients)
    tensors[18] = TENSOR_SPEC(data->output, B * T * max(3*C, max(NH*T, Vp)));
    tensors[19] = TENSOR_SPEC(data->scratch_bt4c, B * T * 4 * C);
    tensors[20] = TENSOR_SPEC(data->scratch_btc, B * T * C);
}
// Allocate one zero-filled device buffer large enough for every activation tensor
// described by `tensors`, then point each spec's target field at its slice.
// Specs with size 0 get a NULL pointer and take no space. Returns the base pointer
// of the allocation; the caller owns it.
void* malloc_and_point_activations(TensorSpec (&tensors)[NUM_ACTIVATION_TENSORS]) {
    size_t bytes = 0;
    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
        bytes += tensors[i].size * sizeof_dtype(tensors[i].type);
    }
    // report whole MiB (floor) as size_t, like the other allocation reports; rounding
    // an integer quotient with round() did nothing and the int cast could truncate
    printf0("allocating %zu MiB for activations\n", bytes >> 20);
    void* acts_memory;
    cudaCheck(cudaMalloc((void**)&acts_memory, bytes));
    // cudaMalloc does not guarantee initial memory values so we memset the allocation here
    // this matters because e.g. non-cuDNN attention assumes the attention buffer is zeroed
    // todo - up to ~100ms on slow GPUs, could theoretically be more selective, but this is safer
    cudaCheck(cudaMemset(acts_memory, 0, bytes));
    char* acts_memory_iterator = (char*)acts_memory;
    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
        if (tensors[i].size == 0) {
            // extra protection so we don't accidentally use an empty buffer
            *(tensors[i].ptr) = NULL;
        } else {
            *(tensors[i].ptr) = acts_memory_iterator;
            acts_memory_iterator += tensors[i].size * sizeof_dtype(tensors[i].type);
        }
    }
    return acts_memory;
}
// Complete model + training state. Construct with gpt2_init_common(), then allocate
// weights (gpt2_allocate_weights, via a checkpoint or descriptor) and training state
// (gpt2_allocate_state).
typedef struct {
    GPT2Config config;
    // the weights of the model, and their sizes
    ParameterTensors params;
    size_t param_elements[NUM_PARAMETER_TENSORS]; // element count per parameter tensor
    size_t param_sizeof[NUM_PARAMETER_TENSORS]; // bytes per element per parameter tensor
    void* params_memory; // single device allocation backing `params`
    size_t num_parameters; // sum of param_elements
    size_t num_parameters_bytes; // sum of param_elements[i] * param_sizeof[i]
    // gradients of the weights (same layout as params)
    ParameterTensors grads;
    void* grads_memory;
    // buffers for the AdamW optimizer; each holds this rank's shard of
    // multi_gpu_config.shard_num_parameters floats
    float* m_memory;
    float* v_memory;
    float* master_weights; // NULL unless use_master_weights == 1 (fp32 master copy of the shard)
    // the activations of the model, and their sizes
    ActivationTensors acts;
    TensorSpec acts_specs[NUM_ACTIVATION_TENSORS];
    void* acts_memory; // single device allocation backing `acts`
    // other run state configuration
    int batch_size; // the batch size (B) of current forward pass
    int seq_len; // the sequence length (T) of current forward pass
    int* inputs; // the input tokens for the current forward pass (device, B*T ints)
    int* targets; // the target tokens for the current forward pass (device, B*T ints)
    float mean_loss; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps
    float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps
    float* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
    unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
    unsigned long long rng_state_last_update; // RNG before last gpt2_update() to re-round identically from master weights
    int use_master_weights; // keep master weights copy in float for optim update? 0|1
    bool init_state; // set to true if master weights need to be initialized
    int gelu_fusion; // fuse gelu via cuBLASLt (0=none, 1=forward, 2=forward+backward)
    int recompute; // recompute gelu | layernorm forward during model backward? 0|1|2
    // todo - if other functions need cpu scratch buffers in the future, reuse as generic scratch?
    int* workload_indices; // host buffer for encoder_backward, B*T*num_c_groups (int)
    int4* bucket_info; // host buffer for encoder_backward, B*T*num_c_groups (int4) - size for worst case
} GPT2;
// Put a freshly declared GPT2 into a known state before any memory exists: every
// lazily-allocated buffer pointer is NULL, the batch shape is unset, and the run-time
// options take their defaults.
void gpt2_init_common(GPT2 *model) {
    // buffers created later by gpt2_allocate_weights() / gpt2_allocate_state()
    model->params_memory = NULL;
    model->grads_memory = NULL;
    model->acts_memory = NULL;
    model->m_memory = NULL;
    model->v_memory = NULL;
    model->master_weights = NULL;
    model->inputs = NULL;
    model->targets = NULL;
    model->accumulated_mean_loss = NULL;
    model->cpu_losses = NULL;
    model->workload_indices = NULL; // host scratch for encoder_backward
    model->bucket_info = NULL;      // host scratch for encoder_backward

    // B and T are fixed when the training state is allocated
    model->batch_size = 0;
    model->seq_len = 0;
    model->mean_loss = -1.0f; // sentinel: no loss has been computed yet

    // default options
    model->rng_state = 13371337 + multi_gpu_config.process_rank; // per-rank seed for stochastic rounding
    model->use_master_weights = 1; // safe default: keep an fp32 master copy of the weights
    model->init_state = true;
    model->recompute = 1;          // recompute gelu in backward, but not layernorm
    model->gelu_fusion = 0;        // off for now (candidate: deviceProp.major >= 9 ? 2 : 0); must match main()
}
// Derive every parameter tensor's shape from model->config, record the total element
// and byte counts, and allocate the device buffer the parameter pointers live in.
// params_memory must still be NULL (weights are allocated at most once).
void gpt2_allocate_weights(GPT2 *model) {
    fill_in_parameter_sizes(model->param_elements, model->param_sizeof, model->config);

    size_t total_elements = 0;
    size_t total_bytes = 0;
    for (int t = 0; t < NUM_PARAMETER_TENSORS; ++t) {
        const size_t count = model->param_elements[t];
        total_elements += count;
        total_bytes += count * model->param_sizeof[t];
    }
    model->num_parameters = total_elements;
    model->num_parameters_bytes = total_bytes;

    assert(model->params_memory == nullptr);
    model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof);
}
// Allocate all training state for batch size B and sequence length T:
//   - parameter gradients (same layout as the weights)
//   - activations (sized by fill_in_activation_sizes for model->recompute)
//   - device input/target token buffers, a device loss accumulator, pinned host losses
//   - host scratch buffers for encoder_backward
//   - this rank's shard of the AdamW m/v state and, if enabled, fp32 master weights
// Then prints device memory usage and a rough estimate of the largest batch size that
// would still fit. Expects gpt2_allocate_weights() to have run (it reads
// param_elements / param_sizeof) and must only be called once.
void gpt2_allocate_state(GPT2 *model, int B, int T) {
    assert(B > 0 && T > 0);
    // gradients mirror the parameter layout, so they take exactly num_parameters_bytes
    printf0("allocating %zu MiB for parameter gradients\n", model->num_parameters_bytes >> 20);
    assert(model->grads_memory == nullptr);
    model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_elements, model->param_sizeof);

    // record the current B,T as well
    model->batch_size = B;
    model->seq_len = T;

    // activations
    fill_in_activation_sizes(&model->acts, model->acts_specs, B, T, model->config, model->recompute);
    model->acts_memory = malloc_and_point_activations(model->acts_specs);

    // inputs, targets and losses; B*T is formed in size_t so the byte counts cannot overflow int
    const size_t BT = (size_t)B * (size_t)T;
    cudaCheck(cudaMalloc((void**)&model->inputs, BT * sizeof(int)));
    cudaCheck(cudaMalloc((void**)&model->targets, BT * sizeof(int)));
    cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float)));
    cudaCheck(cudaMallocHost((void**)&model->cpu_losses, BT * sizeof(float)));

    // cpu scratch buffers for encoder backward: one entry per (b, t, channel group).
    // The count is computed entirely in size_t so the bound check itself cannot overflow;
    // the bound exists because the indices are stored as int.
    const size_t num_c_groups = CEIL_DIV(model->config.channels, (WARP_SIZE * x128::size));
    const size_t num_workloads = BT * num_c_groups;
    assert(num_workloads < (1ULL << 31ULL)); // todo - maybe an issue for llama3-400B(?)
    model->workload_indices = (int*)mallocCheck(sizeof(int) * num_workloads);
    model->bucket_info = (int4*)mallocCheck(sizeof(int4) * num_workloads);

    // cudaMallocConditionallyManaged can fall back to cudaMallocManaged if not enough memory on device
    // and returns a status code of 1 if it had to fall back, in that case we want to print warning.
    int memory_status = 0;
    // optimizer state and master weights: usually a substantial allocation
    const size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; // num parameters we are responsible for
    const size_t shard_bytes = shard_num_parameters * sizeof(float);
    printf0("allocating %zu MiB for AdamW optimizer state m\n", shard_bytes >> 20);
    printf0("allocating %zu MiB for AdamW optimizer state v\n", shard_bytes >> 20);
    assert(model->m_memory == nullptr);
    assert(model->v_memory == nullptr);
    memory_status |= cudaMallocConditionallyManaged((void**)&model->m_memory, shard_bytes);
    memory_status |= cudaMallocConditionallyManaged((void**)&model->v_memory, shard_bytes);
    if (model->use_master_weights == 1) {
        assert(model->master_weights == nullptr);
        printf0("allocating %zu MiB for master copy of params\n", shard_bytes >> 20);
        memory_status |= cudaMallocConditionallyManaged((void**) &model->master_weights, shard_bytes);
    }

    // report on mixed memory allocation status (re-using our float reduce function, bit awk ok)
    int reduced_memory_status = (int) multi_gpu_cpu_float_sum((float)memory_status, &multi_gpu_config);
    if (reduced_memory_status >= 1) {
        printf0("WARNING: Fell back to cudaMallocManaged when initializing m,v,master_weights on %d GPUs\n", reduced_memory_status);
        printf0(" Prevents an OOM, but code may run much slower due to device <-> host memory movement\n");
    }

    // report on device memory usage (named free_bytes to avoid shadowing free())
    size_t free_bytes, total_bytes;
    cudaCheck(cudaMemGetInfo(&free_bytes, &total_bytes));
    printf0("device memory usage: %zu MiB / %zu MiB\n", (total_bytes - free_bytes) >> 20, total_bytes >> 20);

    // give an estimate of the maximum batch size
    size_t bytes_per_sequence = 0;
    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
        bytes_per_sequence += model->acts_specs[i].size * sizeof_dtype(model->acts_specs[i].type) / B;
    }
    printf0("memory per sequence: %zu MiB\n", bytes_per_sequence / 1024 / 1024);
    if (bytes_per_sequence > 0) {
        printf0(" -> estimated maximum batch size: %zu\n", (size_t)B + free_bytes / bytes_per_sequence);
    }
}
// Save the model to `checkpoint_path` as a 256-int header followed by the raw
// parameter bytes copied from the device.
// Header: [0] magic 20240326, [1] version (3 = fp32, 5 = bf16), then max_seq_len,
// vocab_size, num_layers, num_heads, channels, padded_vocab_size; the rest are zero.
void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
    printf0("Writing model to %s\n", checkpoint_path);
    FILE *model_file = fopenCheck(checkpoint_path, "wb");

    // header (unused slots stay zero)
    int model_header[256] = {0};
    assert(PRECISION_MODE == PRECISION_FP32 || PRECISION_MODE == PRECISION_BF16);
    const auto& cfg = model->config;
    const int version = (PRECISION_MODE == PRECISION_FP32) ? 3 : 5;
    model_header[0] = 20240326; // magic number
    model_header[1] = version;
    model_header[2] = cfg.max_seq_len;
    model_header[3] = cfg.vocab_size;
    model_header[4] = cfg.num_layers;
    model_header[5] = cfg.num_heads;
    model_header[6] = cfg.channels;
    model_header[7] = cfg.padded_vocab_size;
    fwriteCheck(model_header, sizeof(int), 256, model_file);

    // payload: all parameters, streamed from the single device buffer
    device_to_file(model_file, model->params_memory, model->num_parameters_bytes, IO_BUF_SIZE, main_stream);

    fcloseCheck(model_file);
}
// Build the model from a .bin checkpoint written by gpt2_write_to_checkpoint (or the
// Python exporter). Header: 256 ints -- [0] magic 20240326, [1] version (3 = fp32,
// 5 = bf16), [2..7] max_seq_len, vocab_size, num_layers, num_heads, channels,
// padded_vocab_size -- followed by num_parameters_bytes of raw parameters.
//   weight_init - if true, load the weights from this file onto the GPU. If false,
//                 only the hyperparameters are read (e.g. the weights will come from
//                 master weights in a state .bin file); the parameter buffer is still
//                 allocated.
// Exits the process on an unsupported precision, bad magic, bad version, or a
// precision mismatch between the build and the checkpoint.
void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path, bool weight_init=true) {
    if (PRECISION_MODE == PRECISION_FP16) {
        // TODO for later perhaps, would require us dynamically converting the
        // model weights from fp32 to fp16 online, here in this function, or writing
        // the fp16 weights directly from Python, which we only do for fp32/bf16 atm.
        fprintf(stderr, "build_from_checkpoint() does not support fp16 right now.\n");
        exit(EXIT_FAILURE);
    }

    // read in the header
    FILE *model_file = fopenCheck(checkpoint_path, "rb");
    int model_header[256];
    freadCheck(model_header, sizeof(int), 256, model_file);
    if (model_header[0] != 20240326) {
        // report on stderr, like every other failure in this function
        fprintf(stderr, "Bad magic model file\n");
        exit(EXIT_FAILURE);
    }
    int version = model_header[1];
    if (!(version == 3 || version == 5)) {
        // 3 = fp32, padded vocab
        // 5 = bf16, padded vocab, layernorms also in bf16
        fprintf(stderr, "Bad version in model file\n");
        fprintf(stderr, "---> HINT: try to re-run `python train_gpt2.py`\n");
        exit(EXIT_FAILURE);
    }

    // the checkpoint precision must match the build precision if we load its weights
    if (weight_init) {
        if (PRECISION_MODE == PRECISION_BF16 && version != 5) {
            fprintf(stderr, "Precision is configured as BF16 but model at %s is not.\n", checkpoint_path);
            fprintf(stderr, "---> HINT: are you sure you're loading a _bf16.bin file?\n");
            exit(EXIT_FAILURE);
        }
        if (PRECISION_MODE == PRECISION_FP32 && version != 3) {
            fprintf(stderr, "Precision is configured as FP32 but model at %s is not.\n", checkpoint_path);
            fprintf(stderr, "---> HINT: to turn on FP32 you have to compile like: `make train_gpt2cu PRECISION=FP32`\n");
            fprintf(stderr, "---> HINT: are you sure you're loading a .bin file without any _bf16 in the name?\n");
            exit(EXIT_FAILURE);
        }
    }

    // read in hyperparameters
    model->config.max_seq_len = model_header[2];
    model->config.vocab_size = model_header[3];
    model->config.num_layers = model_header[4];
    model->config.num_heads = model_header[5];
    model->config.channels = model_header[6];
    model->config.padded_vocab_size = model_header[7];

    // allocate memory for the model parameters
    gpt2_allocate_weights(model);

    // read in the parameters if weight_init is true
    if (weight_init) {
        assert(model->params_memory != NULL);
        file_to_device(model->params_memory, model_file, model->num_parameters_bytes, IO_BUF_SIZE, main_stream);
    }
    fcloseCheck(model_file);

    // only return from this function once we are certain the params are ready on the GPU
    cudaCheck(cudaDeviceSynchronize());
}
// Configure a GPT-2 architecture from its depth given as a decimal string
// (the number after "d" in a descriptor). Width and head count come from a fixed
// table; an unknown depth prints an error and exits. max_seq_len is always 1024.
// vocab sizes are not touched here.
void gpt2_set_hyperparameters(GPT2Config* config, const char* depth_str) {
    const int depth = atoi(depth_str);
    assert(depth > 0); // atoi returns 0 if not a number

    int channels = 0;
    int num_heads = 0;
    switch (depth) {
        case 6:  channels = 384;  num_heads = 6;  break; // (unofficial) gpt2-tiny (30M)
        case 12: channels = 768;  num_heads = 12; break; // gpt2 (124M)
        case 24: channels = 1024; num_heads = 16; break; // gpt2-medium (350M)
        case 36: channels = 1280; num_heads = 20; break; // gpt2-large (774M)
        case 48: channels = 1600; num_heads = 25; break; // gpt2-xl (1558M)
        case 60: channels = 1920; num_heads = 30; break; // (unofficial) 2.7B
        case 72: channels = 2880; num_heads = 30; break; // (unofficial) 7.3B
        case 84: channels = 3456; num_heads = 36; break; // (unofficial) 12.2B
        default:
            fprintf(stderr, "Unsupported GPT-2 depth: %d\n", depth);
            exit(EXIT_FAILURE);
    }

    config->num_layers = depth;
    config->channels = channels;
    config->num_heads = num_heads;
    config->max_seq_len = 1024;
}
// Configure a GPT-3 style architecture from its channel count (width) given as a
// decimal string (the number after "c" in a descriptor). We key on channels instead
// of depth because GPT-3 model depths are not one-to-one with sizes.
// Note that our models are not necessarily identical to GPT-3 because we use dense
// attention, not the alternating dense/banded attention of GPT-3.
// num_heads = channels / head_size; max_seq_len is always 2048. An unknown width prints
// an error and exits.
void gpt3_set_hyperparameters(GPT2Config* config, const char* channels_str) {
    int channels = atoi(channels_str);
    assert(channels > 0); // atoi returns 0 if not a number
    int depth, head_size;
    if (channels == 384) { depth = 6; head_size = 64; } // (unofficial) gpt3-tiny (31M)
    else if (channels == 768) { depth = 12; head_size = 64; } // gpt3-small (125M)
    else if (channels == 1024) { depth = 24; head_size = 64; } // gpt3-medium (350M)
    else if (channels == 1536) { depth = 24; head_size = 96; } // gpt3-large (760M)
    else if (channels == 2048) { depth = 24; head_size = 128; } // gpt3-xl (1.3B) [heads fixed]
    else if (channels == 2560) { depth = 32; head_size = 80; } // gpt3-2.7B
    else if (channels == 4096) { depth = 32; head_size = 128; } // gpt3-6.7B
    // gpt3-13B [width fixed]: the GPT-3 paper lists d_model = 5140 with 40 heads of 128,
    // but 5140 is not divisible by 128; 40 * 128 = 5120 is the consistent width.
    else if (channels == 5120) { depth = 40; head_size = 128; } // gpt3-13B
    else if (channels == 12288) { depth = 96; head_size = 128; } // gpt3 (175B)
    else { fprintf(stderr, "Unsupported GPT-3 channels: %d\n", channels); exit(EXIT_FAILURE); }
    // hard check (not only assert) so an inconsistent table entry can never silently
    // produce a head count that does not evenly split the channels under NDEBUG
    if (channels % head_size != 0) {
        fprintf(stderr, "GPT-3 channels %d not divisible by head size %d\n", channels, head_size);
        exit(EXIT_FAILURE);
    }
    config->num_layers = depth;
    config->channels = channels;
    config->num_heads = channels / head_size;
    config->max_seq_len = 2048; // NOTE: GPT-3 uses context length of 2048 tokens, up from 1024 in GPT-2
}
// Create a freshly initialized model from a descriptor string: set the architecture,
// allocate the device weights, and fill them with the GPT-2 initialization scheme
// (drawn on the host with a fixed seed of 42, then copied to the GPU in one cudaMemcpy).
// Exits the process on an unrecognized descriptor.
void gpt_build_from_descriptor(GPT2 *model, const char* descriptor) {
    // The model descriptor can be:
    // - legacy format "dX", where X is number, e.g. "d12". This creates GPT-2 model with 12 layers.
    // - new explicit format "gpt2:dX", same as above, e.g. "gpt2:d48" for GPT-2 with 48 layers.
    // - "gpt3:cX", where X is now the channel count, e.g. "gpt3:c768" is the smallest GPT-3 model.
    // check the valid prefixes and dispatch to the right setup function
    assert(descriptor != NULL);
    size_t len = strlen(descriptor);
    if (len > 1 && descriptor[0] == 'd') {
        gpt2_set_hyperparameters(&model->config, descriptor + 1); // pass along the depth str without the 'd'
    } else if (len > 6 && strncmp(descriptor, "gpt2:d", 6) == 0) {
        gpt2_set_hyperparameters(&model->config, descriptor + 6); // pass along the depth str without the 'gpt2:d'
    } else if (len > 6 && strncmp(descriptor, "gpt3:c", 6) == 0) {
        gpt3_set_hyperparameters(&model->config, descriptor + 6); // pass along the channels str without the 'gpt3:c'
    } else {
        fprintf(stderr, "Unsupported model descriptor: %s\n", descriptor); exit(EXIT_FAILURE);
    }

    // both GPT-2 and GPT-3 use the same tokenizer with 50257 tokens
    model->config.vocab_size = 50257;
    model->config.padded_vocab_size = 50304; // padded to 128 for CUDA kernel efficiency

    gpt2_allocate_weights(model);

    // allocate and random init the memory for all the parameters with GPT-2 schema
    // weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5)
    // NOTE: assuming all parameters are of the type floatX, could be relaxed later
    mt19937_state init_rng;
    manual_seed(&init_rng, 42);
    floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes);
    // zero everything first: biases, and the padded wte rows beyond vocab_size, stay 0
    memset(params_memory_cpu, 0, model->num_parameters_bytes);

    // fill in all the weights with random values
    float residual_scale = 1.0f / sqrtf(2.0f * model->config.num_layers);
    // we have to init all these tensors exactly in the order that PyTorch initializes them
    // so that we can match them up and get correctness and exactly the same initial conditions
    // (i.e. layer by layer, and within a layer in ParameterTensors order, because every
    // normal_() call advances init_rng)
    size_t L = model->config.num_layers;
    size_t offset = 0;
    for (int l = 0; l < L; l++) {
        // offset walks over all tensors of the flat buffer again for every layer; it
        // always points at the start of tensor i (all L layers of it)
        offset = 0;
        for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
            // the layernorm weights (ln1w, ln2w, lnfw) are all initialized to 1
            if (l == 0 && (i == 2 || i == 8 || i == 14)) { // only at l = 0 to init these just once (all layers at once)
                for (size_t j = 0; j < model->param_elements[i]; j++) {
                    params_memory_cpu[offset + j] = 1.0f;
                }
            }
            // weights tensors are handled here: wte/wpe once, and qkvw/attprojw/fcw/fcprojw per layer
            if ((l == 0 && (i == 0 || i == 1)) // only at l = 0, init the wte and wpe tensors
                || i == 4 || i == 6 || i == 10 || i == 12) {
                size_t n = model->param_elements[i];
                size_t layer_offset = 0;
                if (i == 0) {
                    // for wte tensor (padded vocab) override to init V instead of Vp rows
                    n = model->config.vocab_size * model->config.channels;
                }
                if (i == 4 || i == 6 || i == 10 || i == 12) {
                    // weight tensors, we are only initializing layer l: the tensor is
                    // stacked as (L, ...), so layer l's slice starts at l * (n / L)
                    assert(n % L == 0);
                    n = n / L;
                    layer_offset = l * n;
                }
                // in GPT-2, the projections back into the residual stream are additionally
                // scaled by 1/sqrt(2*L) for training stability
                float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;
                // okay let's draw the random numbers and write them
                float *fp32_buffer = (float*)mallocCheck(n * sizeof(float));
                normal_(fp32_buffer, n, 0.0f, scale, &init_rng);
                for (size_t j = 0; j < n; j++) {
                    params_memory_cpu[offset + layer_offset + j] = (floatX)fp32_buffer[j];
                }
                free(fp32_buffer);
            }
            offset += model->param_elements[i];
        }
    }

    // copy them to GPU
    cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
    free(params_memory_cpu);
}
// propagate inputs through the network to produce logits.
// right now, this function is fully synchronous with the host
//
// Runs the full GPT-2 forward pass: token+position encoding, L transformer blocks,
// the final layernorm (fused into the last block's residual add), and the LM-head
// matmul against params.wte (the same tensor used by the encoder). The logits are
// written into acts.output with Vp = padded_vocab_size columns per (b, t) position.
//
// model:  must already hold parameters (params_memory != NULL), otherwise the process exits.
// inputs: host pointer to B*T token ids; each must lie in [0, vocab_size) (checked by tokenCheck).
// B, T:   must not exceed the batch_size / seq_len the model was set up with, otherwise
//         the process exits. Smaller values are accepted (intended for inference only).
// The function ends with cudaDeviceSynchronize(), so all queued GPU work is complete on return.
void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
NVTX_RANGE_FN();
// we must be careful and use size_t instead of int, otherwise
// we could overflow int. E.g. l * B * NH * T * T overflows int at B 16.
// ensure the model was initialized or error out
if (model->params_memory == NULL) {
printf("Error: model was not initialized properly.\n");
exit(EXIT_FAILURE);
}
// convenience parameters
const size_t V = model->config.vocab_size;
const size_t Vp = model->config.padded_vocab_size;
const size_t L = model->config.num_layers;
const size_t NH = model->config.num_heads;
const size_t C = model->config.channels;
// validate B,T are not larger than the values used at initialisation
// (smaller B,T are okay for inference only)
if (B > model->batch_size || T > model->seq_len) {
printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, (int)B, (int)T);
exit(EXIT_FAILURE);
}
// copy inputs/targets to the model
cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice));
// validate inputs, all indices must be in the range [0, V)
// NOTE(review): this is a blocking cudaMemcpy (not cudaMemcpyAsync), so the overlap with
// tokenCheck is limited to whatever DMA is still in flight when the call returns
tokenCheck(inputs, B*T, V);
// forward pass
ParameterTensors params = model->params; // for brevity
ActivationTensors acts = model->acts;
encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C, main_stream); // encoding goes into residual[0]
// first layernorm isn't fused
// (with recompute >= 2 the per-layer ln1 buffers are not kept; acts.lnf is used as a shared buffer)
layernorm_forward((model->recompute < 2) ? acts.ln1 : acts.lnf, acts.ln1_mean, acts.ln1_rstd, acts.encoded, params.ln1w, params.ln1b, B, T, C, main_stream);
for (int l = 0; l < L; l++) {
NvtxRange layer_range("Layer", l);
// input residual of this block: the encoder output for layer 0, else the previous block's residual3
floatX* residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;
// get the pointers of the weights for this layer
floatX* l_qkvw = params.qkvw + l * 3*C * C;
floatX* l_qkvb = params.qkvb + l * 3*C;
floatX* l_attprojw = params.attprojw + l * C * C;
floatX* l_attprojb = params.attprojb + l * C;
floatX* l_ln2w = params.ln2w + l * C;
floatX* l_ln2b = params.ln2b + l * C;
floatX* l_fcw = params.fcw + l * 4*C * C;
floatX* l_fcb = params.fcb + l * 4*C;
floatX* l_fcprojw = params.fcprojw + l * C * 4*C;
floatX* l_fcprojb = params.fcprojb + l * C;
// get the pointers of the activations for this layer
// recompute < 2: ln1/ln2 are stored per layer; otherwise they alias the single acts.lnf buffer
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
floatX* l_atty = acts.atty + l * B * T * C;
floatX* l_residual2 = acts.residual2 + l * B * T * C;
floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
float* l_ln2_mean = acts.ln2_mean + l * B * T;
float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
floatX* l_fch = acts.fch + l * B * T * 4*C;
// reuse the same activation buffer at each layer, as we'll re-compute the gelu during backward
// very useful because we dramatically reduce VRAM usage, and may be able to fit larger batch size
// (only when recompute >= 1; with recompute == 0 every layer keeps its own gelu output)
floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;
floatX* l_residual3 = acts.residual3 + l * B * T * C;
// acts.output is free until the final logits matmul overwrites it, so it doubles as scratch
floatX* scratch = (floatX*)acts.output; // used for non-cudnn attention, fcproj, attproj, etc.
// now do the forward pass
#ifdef ENABLE_CUDNN
float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
matmul_forward_cublaslt(l_qkvr, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
attention_forward_cudnn(l_atty, (float*)l_att, l_qkvr, B, T, NH, C, main_stream);
#else
floatX* l_att = acts.att + l * B * NH * T * T;
if (T != model->seq_len) { // unused parts of attention buffer must be zeroed (T-dependent)
cudaCheck(cudaMemset(l_att, 0, B * NH * T * T * sizeof(floatX)));
}
// these are only needed as scratchpads for the forward pass, but
// need not be stored for backward
// (the raw QKV projection goes to scratch; attention_forward also receives l_qkvr,
//  presumably to store a rearranged copy for backward — TODO confirm)
matmul_forward_cublaslt(scratch, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C, main_stream);
attention_forward(l_atty, l_qkvr, l_att, scratch, B, T, C, NH, main_stream);
#endif
// attention output projection into scratch, then residual add + ln2 in one fused kernel
matmul_forward_cublaslt(scratch, l_atty, l_attprojw, l_attprojb, B, T, C, C, main_stream);
fused_residual_forward5(l_residual2, l_ln2, l_ln2_mean, l_ln2_rstd, residual, scratch, l_ln2w, l_ln2b, B*T, C, main_stream);
// MLP: fc (+ gelu, writing the pre-gelu values to l_fch), then fcproj into scratch
matmul_forward_cublaslt(l_fch_gelu, l_ln2, l_fcw, l_fcb, B, T, C, 4*C, main_stream, l_fch, model->gelu_fusion);
matmul_forward_cublaslt(scratch, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C, main_stream);
// OK, fusion across blocks.
// The residual add of this block is fused with the NEXT block's ln1 (or with the final
// layernorm lnf for the last block), which is why the loop started with only one unfused layernorm.
if(l+1 != L) {
floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + (l + 1) * B * T * C : acts.lnf;
float* l_ln1_mean = acts.ln1_mean + (l + 1) * B * T;
float* l_ln1_rstd = acts.ln1_rstd + (l + 1) * B * T;
const floatX* l_ln1w = params.ln1w + (l + 1) * C;
const floatX* l_ln1b = params.ln1b + (l + 1) * C;
fused_residual_forward5(l_residual3, l_ln1, l_ln1_mean, l_ln1_rstd, l_residual2, scratch, l_ln1w, l_ln1b,
B * T, C, main_stream);
} else {
fused_residual_forward5(l_residual3, acts.lnf, acts.lnf_mean, acts.lnf_rstd, l_residual2, scratch,
params.lnfw, params.lnfb,
B * T, C, main_stream);
}
}
// LM head: logits = lnf · wte^T over the padded vocab (Vp columns), overwriting the scratch contents
matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
cudaCheck(cudaDeviceSynchronize());
}
// Forwards both the model and the loss and is used for validation splits and evals.
// In particular it populates cpu_losses with loss at each token.
// Some of the evals (e.g. HellaSwag) require the per-token losses, which are produced here.
//
// inputs:  host pointer to B*T token ids (validated inside gpt2_forward).
// targets: host pointer to B*T target ids; must be non-NULL and each in [0, vocab_size).
// B, T:    must not exceed the model's configured batch_size / seq_len.
// Returns the mean loss over all B*T positions. On return, model->cpu_losses[0..B*T)
// holds the per-token losses copied back from the device.
float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T) {
    assert(targets != NULL);
    // forward the model itself (logits end up in acts.output)
    gpt2_forward(model, inputs, B, T);
    // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
    const size_t V = model->config.vocab_size;
    const size_t Vp = model->config.padded_vocab_size;
    const size_t BT = B * T;

    NvtxRange classifier_and_loss_range("classifier_and_loss");
    ActivationTensors acts = model->acts;
    // fused classifier: does the forward pass and first part of the backward pass
    const float dloss = 1.0f / BT; // results in the uniform average loss over all elements
    // note: we don't need to generate dlogits here, hence False below
    cudaCheck(cudaMemset(acts.losses, 0, BT * sizeof(float)));
    cudaCheck(cudaMemcpy(model->targets, targets, BT * sizeof(int), cudaMemcpyHostToDevice));
    tokenCheck(targets, BT, V); // validate on the host that every target lies in [0, V)
    fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, False, main_stream);
    cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, BT * sizeof(float), cudaMemcpyDeviceToHost));
    // Sum in double with a size_t index: a float accumulator loses precision over many
    // per-token losses, and an int index mismatches (and can overflow against) the size_t bound.
    double loss_sum = 0.0;
    for (size_t i = 0; i < BT; i++) {
        loss_sum += model->cpu_losses[i];
    }
    const float mean_loss = (float)(loss_sum / (double)BT);
    cudaCheck(cudaDeviceSynchronize());
    return mean_loss;
}
// Backward pass for one micro-step of gradient accumulation. On the last micro-step it also
// reduces the gradients and the mean loss across GPUs.
//
// Must follow gpt2_forward on the same micro-batch: it reads the logits in acts.output and the
// saved activations. It always uses the model's full batch_size / seq_len.
// inputs:  host token ids of the micro-batch (passed on to encoder_backward).
// targets: host pointer to B*T target ids; each must lie in [0, vocab_size).
// grad_accum_steps / micro_step: when micro_step == 0, the accumulated losses and gradients are
//   zeroed. Both then accumulate (+=) across micro-steps. The reductions and the loss readback
//   only happen when micro_step == grad_accum_steps - 1.
// On return, model->mean_loss holds the mean loss over all micro-steps (averaged across GPUs
// when MULTI_GPU is set) on the last step, and -1.0f otherwise. Ends with cudaDeviceSynchronize().
void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, int micro_step) {
    if(model->grads_memory == nullptr) {
        fprintf(stderr, "Need to allocate gradients before backward\n");
        exit(EXIT_FAILURE);
    }
    NVTX_RANGE_FN();
    bool last_step = micro_step == grad_accum_steps - 1;
    // on the first micro-step zero the gradients, as we're about to += accumulate into them
    if (micro_step == 0) {
        // there are currently two state vars during the gradient accumulation inner loop:
        // 1) the losses accumulate += into acts.losses, reset here
        // 2) the gradients accumulate += into grads_memory, reset here
        cudaCheck(cudaMemsetAsync(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(float), main_stream));
        cudaCheck(cudaMemsetAsync(model->grads_memory, 0, model->num_parameters * sizeof(floatX), main_stream));
    }
    // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
    const size_t B = model->batch_size;
    const size_t T = model->seq_len;
    const size_t V = model->config.vocab_size;
    const size_t Vp = model->config.padded_vocab_size;
    const size_t L = model->config.num_layers;
    const size_t NH = model->config.num_heads;
    const size_t C = model->config.channels;
    ParameterTensors params = model->params; // for brevity
    ParameterTensors grads = model->grads;
    ActivationTensors acts = model->acts;
    // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier
    NvtxRange classifier_and_loss_range("classifier_and_loss");
    const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements
    cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));
    tokenCheck(targets, B*T, V);
    fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream);
    // backward pass: go in the reverse order of the forward pass, and call backward() functions
    // reset residual stream gradients (put here to work with gradient accumulation)
    floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass
    cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX)));
    // re-use the output buffer of the forward pass as a scratchpad during backward pass
    float* scratchF = (float*)acts.output;
    floatX* scratchX = (floatX*)acts.output;
    // we kick off the chain rule by filling in dlosses with 1.0f/(B*T)
    // this was done in the fused classifier kernel as last step of forward pass
    // technically that is a small, inline backward() pass of calculating
    // total, final loss as the mean over all losses over all (B,T) positions in the batch
    // next: backward the classifier matmul
    matmul_backward(model->acts.scratch_bt4c, grads.wte, NULL, acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
    // backward the final layernorm
    floatX* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
    layernorm_backward(dresidual, grads.lnfw, grads.lnfb, scratchF, model->acts.scratch_bt4c, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C, main_stream);
    // from this point on, we no longer need the values stored in the last residual, so we can reuse that memory as generic
    // scratch for backward computations
    floatX* dl_btc = residual;
    // now backward all the layers
    for (int l = L-1; l >= 0; l--) {
        NvtxRange layer_range("Layer", l);
        residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;
        // get the pointers of the weights for this layer
        floatX* l_ln1w = params.ln1w + l * C;
        floatX* l_ln1b = params.ln1b + l * C;
        floatX* l_qkvw = params.qkvw + l * 3*C * C;
        floatX* l_attprojw = params.attprojw + l * C * C;
        floatX* l_ln2w = params.ln2w + l * C;
        floatX* l_ln2b = params.ln2b + l * C;
        floatX* l_fcw = params.fcw + l * 4*C * C;
        floatX* l_fcprojw = params.fcprojw + l * C * 4*C;
        // get the pointers of the gradients of the weights for this layer
        floatX* dl_ln1w = grads.ln1w + l * C;
        floatX* dl_ln1b = grads.ln1b + l * C;
        floatX* dl_qkvw = grads.qkvw + l * 3*C * C;
        floatX* dl_qkvb = grads.qkvb + l * 3*C;
        floatX* dl_attprojw = grads.attprojw + l * C * C;
        floatX* dl_attprojb = grads.attprojb + l * C;
        floatX* dl_ln2w = grads.ln2w + l * C;
        floatX* dl_ln2b = grads.ln2b + l * C;
        floatX* dl_fcw = grads.fcw + l * 4*C * C;
        floatX* dl_fcb = grads.fcb + l * 4*C;
        floatX* dl_fcprojw = grads.fcprojw + l * C * 4*C;
        floatX* dl_fcprojb = grads.fcprojb + l * C;
        // get the pointers of the activations for this layer
        floatX* l_ln1 = (model->recompute < 2) ? acts.ln1 + l * B * T * C : acts.lnf;
        float* l_ln1_mean = acts.ln1_mean + l * B * T;
        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
        floatX* l_qkvr = acts.qkvr + l * B * T * 3*C;
        floatX* l_atty = acts.atty + l * B * T * C;
        floatX* l_residual2 = acts.residual2 + l * B * T * C;
        floatX* l_ln2 = (model->recompute < 2) ? acts.ln2 + l * B * T * C : acts.lnf;
        float* l_ln2_mean = acts.ln2_mean + l * B * T;
        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
        floatX* l_fch_pre_gelu = acts.fch + l * B * T * 4*C;
        floatX* l_fch_gelu = (model->recompute < 1) ? acts.fch_gelu + l * B * T * 4*C : acts.fch_gelu;
        // get the pointers of the gradients of the activations for this layer
        // notice that there is no l *, because we just have a single copy, and keep
        // re-using this memory in every Transformer block as we calculate backward pass
        floatX* dl_bt4c = (floatX*)model->acts.scratch_bt4c;
        // start the backward pass for this layer
        if(model->recompute >= 1) {
            // recompute >= 1 means we recompute gelu. in this case,
            // l_fch_gelu is just a buffer, so re-compute the gelu from l_fch here
            gelu_forward(l_fch_gelu, l_fch_pre_gelu, B*T*4*C, main_stream);
        }
        matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C, main_stream, l_fch_pre_gelu, model->gelu_fusion);
        if(model->recompute >= 2) {
            // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand
            layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C, main_stream);
        }
        matmul_backward(dl_btc, dl_fcw, dl_fcb, dl_bt4c, l_ln2, l_fcw, scratchF, B, T, C, 4 * C, main_stream);
        // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above
        layernorm_backward(dresidual, dl_ln2w, dl_ln2b, scratchF, dl_btc, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C, main_stream);
        matmul_backward(dl_btc, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, scratchF, B, T, C, C, main_stream);
        #ifdef ENABLE_CUDNN
        float* l_att = (float*)acts.att + l * B * NH * T; // cuDNN needs a smaller FP32 tensor
        attention_backward_cudnn(dl_bt4c, dl_btc, l_qkvr, l_atty, (float*)l_att, B, T, NH, C, main_stream);
        #else
        floatX* l_att = acts.att + l * B * NH * T * T;
        // we need B x T x (4)C buffers. l_atty and l_fch aren't needed anymore at this point, so reuse their memory
        floatX* buffer_a = l_atty;
        floatX* buffer_b = l_fch_pre_gelu; // this is B x T x 4C, so even larger than what we need
        attention_backward(dl_bt4c, buffer_b, scratchX, buffer_a, dl_btc, l_qkvr, l_att, B, T, C, NH, main_stream);
        #endif
        if(model->recompute >= 2) {
            layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C, main_stream);
        }
        // QKV parameter gradients
        matmul_backward(dl_btc, dl_qkvw, dl_qkvb, dl_bt4c, l_ln1, l_qkvw, scratchF, B, T, C, 3 * C, main_stream);
        // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above
        layernorm_backward(dresidual, dl_ln1w, dl_ln1b, scratchF, dl_btc, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C, main_stream);
        // On the last micro-step, start the (multi-GPU) reduction of this layer's gradients.
        // The call is ordered after main_stream; whether it overlaps the remaining layers
        // depends on multi_gpu_async_reduce_gradient's internals.
        if(last_step) {
            floatX* const pointers[] = {
                dl_ln1w, dl_ln1b,
                dl_qkvw, dl_qkvb,
                dl_attprojw, dl_attprojb,
                dl_ln2w, dl_ln2b,
                dl_fcw, dl_fcb,
                dl_fcprojw, dl_fcprojb
            };
            const size_t nelem[] = {
                C, C,
                3 * C * C, 3 * C,
                C * C, C,
                C, C,
                4 * C * C, 4 * C,
                C * 4 * C, C
            };
            multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
        }
    }
    encoder_backward(grads.wte, grads.wpe, scratchX, model->workload_indices, model->bucket_info,
                     dresidual, model->inputs, inputs, B, T, C, random_u32(&model->rng_state), main_stream);
    // Aggregate all gradients that are not part of the transformer blocks
    if(last_step) {
        // reduce all the losses within the current GPU (across all microsteps)
        global_sum_deterministic(model->accumulated_mean_loss, acts.losses, B*T, main_stream);
        // reduce loss across GPUs to a single, final float across all microsteps and GPUs
        #if MULTI_GPU
        // NCCL counts are in ELEMENTS, not bytes: the accumulated loss is a single float.
        // (Passing sizeof(float) here would reduce 4 floats, touching 3 beyond the loss value.)
        ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, 1, ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream));
        #endif
        cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream));
        // reduce the gradients for non-transformer block parameters
        floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb};
        const size_t nelem[] = {Vp * C, T * C, C, C};
        multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
    }
    // the host-side mean_loss written by the async copy above is only valid after this sync
    cudaCheck(cudaDeviceSynchronize());
    if(last_step) {
        model->mean_loss /= B*T*grad_accum_steps;
    } else {
        model->mean_loss = -1.f; // no loss available yet
    }
}
// Locates one parameter tensor inside the flat parameter buffer and returns it as
// {element offset, element count}.
// Tensors with ids 2..13 are stored as num_layers equal slices placed back to back.
// For those tensors the result is narrowed to the slice belonging to `layer_id`.
// For every other tensor, `layer_id` is not used and the whole tensor is returned.
ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) {
    // the tensor begins right after all tensors that precede it in the buffer
    ptrdiff_t start = 0;
    int t = 0;
    while (t < param_tensor_id) {
        start += (ptrdiff_t)model->param_elements[t];
        ++t;
    }
    size_t count = model->param_elements[param_tensor_id];
    const bool per_layer = param_tensor_id >= 2 && param_tensor_id <= 13;
    if (!per_layer) {
        return {start, count};
    }
    // shrink to a single layer's slice and skip the slices of the layers before it
    count /= model->config.num_layers;
    start += (ptrdiff_t)(layer_id * count);
    return {start, count};
}
float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) {
NVTX_RANGE_FN();
floatX* grads_memory = (floatX*)model->grads_memory;
// repurposing this buffer (which isn't needed now) to write grad norm into it
float* grad_norm_squared = (float*)model->acts.output;
float grad_norm_squared_cpu = 0.0f;
int num_slices[2] = {1, model->config.num_layers};