forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_gpt2.c
1116 lines (1018 loc) · 44.1 KB
/
train_gpt2.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
This file trains the GPT-2 model.
This version is the clean, minimal, reference. As such:
- it runs on CPU.
- it does not make the code too complex; it is readable.
- it does not use any processor-specific instructions, intrinsics and such.
- it _does_ use a few OpenMP pragmas because this is a large speedup at very low cost
There will be other versions of this code that specialize it and make it fast.
*/
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#ifdef OMP
#include <omp.h>
#endif
// ----------------------------------------------------------------------------
// all the individual layers' forward and backward passes
// Embedding lookup: out[b,t,:] = wte[inp[b,t],:] + wpe[t,:]
//   out : (B,T,C) destination buffer
//   inp : (B,T) token indices; each one selects a C-wide row of wte
//   wte : token embedding table, rows of width C
//   wpe : position embedding table, rows of width C (row t serves position t)
void encoder_forward(float* out,
                     int* inp, float* wte, float* wpe,
                     int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // flat index of position (b,t)
            int pos = b * T + t;
            float* dst = out + pos * C;
            // row of the token table chosen by the token id at (b,t)
            const float* tok_row = wte + inp[pos] * C;
            // row of the position table for time step t
            const float* pos_row = wpe + t * C;
            for (int c = 0; c < C; c++) {
                dst[c] = tok_row[c] + pos_row[c];
            }
        }
    }
}
// Backward of the embedding lookup. The upstream gradient dout[b,t,:] is
// accumulated (+=) into the token-table row picked by inp[b,t] and into the
// position-table row t.
//   dwte, dwpe : gradient accumulators, rows of width C
//   dout       : (B,T,C) gradient w.r.t. the encoder output
//   inp        : (B,T) token indices used in the forward pass
void encoder_backward(float* dwte, float* dwpe,
                      float* dout, int* inp,
                      int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            int pos = b * T + t;
            float* grad_out = dout + pos * C;
            float* grad_tok = dwte + inp[pos] * C;
            float* grad_pos = dwpe + t * C;
            for (int c = 0; c < C; c++) {
                float g = grad_out[c];
                grad_tok[c] += g;
                grad_pos[c] += g;
            }
        }
    }
}
void layernorm_forward(float* out, float* mean, float* rstd,
float* inp, float* weight, float* bias,
int B, int T, int C) {
float eps = 1e-5f;
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// seek to the input position inp[b,t,:]
float* x = inp + b * T * C + t * C;
// calculate the mean
float m = 0.0f;
for (int i = 0; i < C; i++) {
m += x[i];
}
m = m/C;
// calculate the variance (without any bias correction)
float v = 0.0f;
for (int i = 0; i < C; i++) {
float xshift = x[i] - m;
v += xshift * xshift;
}
v = v/C;
// calculate the rstd
float s = 1.0f / sqrtf(v + eps);
// seek to the output position in out[b,t,:]
float* out_bt = out + b * T * C + t * C;
for (int i = 0; i < C; i++) {
float n = (s * (x[i] - m)); // normalized output
float o = n * weight[i] + bias[i]; // scale and shift it
out_bt[i] = o; // write
}
// cache the mean and rstd for the backward pass later
mean[b * T + t] = m;
rstd[b * T + t] = s;
}
}
}
// Backward of layer normalization. All outputs are accumulated with +=.
//   dinp    : (B,T,C) gradient w.r.t. the input
//   dweight : (C) gradient w.r.t. the scale
//   dbias   : (C) gradient w.r.t. the shift
//   dout    : (B,T,C) upstream gradient
//   inp, weight, mean, rstd : values from the forward pass
void layernorm_backward(float* dinp, float* dweight, float* dbias,
                        float* dout, float* inp, float* weight, float* mean, float* rstd,
                        int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            int pos = b * T + t;
            float* g_out = dout + pos * C;
            float* x = inp + pos * C;
            float* g_in = dinp + pos * C;
            float mu = mean[pos];
            float inv_std = rstd[pos];

            // pass 1: the two channel-wise averages the input gradient needs
            float avg_dnorm = 0.0f;       // mean of dnorm
            float avg_dnorm_norm = 0.0f;  // mean of dnorm * norm
            for (int c = 0; c < C; c++) {
                float normed = (x[c] - mu) * inv_std;
                float dnorm = weight[c] * g_out[c];
                avg_dnorm += dnorm;
                avg_dnorm_norm += dnorm * normed;
            }
            avg_dnorm = avg_dnorm / C;
            avg_dnorm_norm = avg_dnorm_norm / C;

            // pass 2: accumulate every gradient
            for (int c = 0; c < C; c++) {
                float normed = (x[c] - mu) * inv_std;
                float dnorm = weight[c] * g_out[c];
                // shift gradient
                dbias[c] += g_out[c];
                // scale gradient
                dweight[c] += normed * g_out[c];
                // input gradient: three terms, then the common rstd factor
                float g = 0.0f;
                g += dnorm;
                g -= avg_dnorm;
                g -= normed * avg_dnorm_norm;
                g *= inv_std;
                g_in[c] += g;
            }
        }
    }
}
// Dense layer: out[b,t,o] = bias[o] + sum_i inp[b,t,i] * weight[o,i]
//   out    : (B,T,OC)   OC = number of output channels
//   inp    : (B,T,C)
//   weight : (OC,C)
//   bias   : (OC), or NULL for a layer without bias
// Together with matmul_backward this is where most of the run time goes.
void matmul_forward(float* out,
                    float* inp, float* weight, float* bias,
                    int B, int T, int C, int OC) {
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            int pos = b * T + t;
            float* x = inp + pos * C;
            float* y = out + pos * OC;
            for (int o = 0; o < OC; o++) {
                float* w = weight + o*C;
                // start from the bias when there is one
                float acc = 0.0f;
                if (bias != NULL) {
                    acc = bias[o];
                }
                for (int i = 0; i < C; i++) {
                    acc += x[i] * w[i];
                }
                y[o] = acc;
            }
        }
    }
}
// Backward of the dense layer. All outputs are accumulated with +=.
//   dinp    : (B,T,C)  gradient w.r.t. the input
//   dweight : (OC,C)   gradient w.r.t. the weight
//   dbias   : (OC)     gradient w.r.t. the bias, or NULL to skip it
//   dout    : (B,T,OC) upstream gradient
//   inp, weight : values from the forward pass
// The work is split into two loop nests so that each can be parallelized
// over an axis its writes do not share: positions (b,t) for dinp, output
// channels o for dweight/dbias.
void matmul_backward(float* dinp, float* dweight, float* dbias,
                     float* dout, float* inp, float* weight,
                     int B, int T, int C, int OC) {
    // gradient into the input, parallel over (b,t)
    #pragma omp parallel for collapse(2)
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            int pos = b * T + t;
            float* g_out = dout + pos * OC;
            float* g_in = dinp + pos * C;
            for (int o = 0; o < OC; o++) {
                float* w = weight + o*C;
                float g = g_out[o];
                for (int i = 0; i < C; i++) {
                    g_in[i] += w[i] * g;
                }
            }
        }
    }
    // gradient into weight and bias, parallel over output channels
    #pragma omp parallel for
    for (int o = 0; o < OC; o++) {
        for (int b = 0; b < B; b++) {
            for (int t = 0; t < T; t++) {
                int pos = b * T + t;
                float* g_out = dout + pos * OC;
                float* x = inp + pos * C;
                float* g_w = dweight + o*C;
                float g = g_out[o];
                if (dbias != NULL) {
                    dbias[o] += g;
                }
                for (int i = 0; i < C; i++) {
                    g_w[i] += x[i] * g;
                }
            }
        }
    }
}
void attention_forward(float* out, float* preatt, float* att,
float* inp,
int B, int T, int C, int NH) {
// input is (B, T, 3C) Q,K,V
// preatt, att are (B, NH, T, T)
// output is (B, T, C)
int C3 = C*3;
int hs = C / NH; // head size
float scale = 1.0 / sqrtf(hs);
#pragma omp parallel for collapse(3)
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
for (int h = 0; h < NH; h++) {
float* query_t = inp + b * T * C3 + t * C3 + h * hs;
float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
float* att_bth = att + b*NH*T*T + h*T*T + t*T;
// pass 1: calculate query dot key and maxval
float maxval = -10000.0f; // TODO something better
for (int t2 = 0; t2 <= t; t2++) {
float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
// (query_t) dot (key_t2)
float val = 0.0f;
for (int i = 0; i < hs; i++) {
val += query_t[i] * key_t2[i];
}
val *= scale;
if (val > maxval) {
maxval = val;
}
preatt_bth[t2] = val;
}
// pass 2: calculate the exp and keep track of sum
float expsum = 0.0f;
for (int t2 = 0; t2 <= t; t2++) {
float expv = expf(preatt_bth[t2] - maxval);
expsum += expv;
att_bth[t2] = expv;
}
float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;
// pass 3: normalize to get the softmax
for (int t2 = 0; t2 < T; t2++) {
if (t2 <= t) {
att_bth[t2] *= expsum_inv;
} else {
// causal attention mask. not strictly necessary to set to zero here
// only doing this explicitly for debugging and checking to PyTorch
att_bth[t2] = 0.0f;
}
}
// pass 4: accumulate weighted values into the output of attention
float* out_bth = out + b * T * C + t * C + h * hs;
for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
for (int t2 = 0; t2 <= t; t2++) {
float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
float att_btht2 = att_bth[t2];
for (int i = 0; i < hs; i++) {
out_bth[i] += att_btht2 * value_t2[i];
}
}
}
}
}
}
void attention_backward(float* dinp, float* dpreatt, float* datt,
float* dout, float* inp, float* att,
int B, int T, int C, int NH) {
// inp/dinp are (B, T, 3C) Q,K,V
// att/datt/dpreatt are (B, NH, T, T)
// dout is (B, T, C)
int C3 = C*3;
int hs = C / NH; // head size
float scale = 1.0 / sqrtf(hs);
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
for (int h = 0; h < NH; h++) {
float* att_bth = att + b*NH*T*T + h*T*T + t*T;
float* datt_bth = datt + b*NH*T*T + h*T*T + t*T;
float* dpreatt_bth = dpreatt + b*NH*T*T + h*T*T + t*T;
float* dquery_t = dinp + b * T * C3 + t * C3 + h * hs;
float* query_t = inp + b * T * C3 + t * C3 + h * hs;
// backward pass 4, through the value accumulation
float* dout_bth = dout + b * T * C + t * C + h * hs;
for (int t2 = 0; t2 <= t; t2++) {
float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
float* dvalue_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C*2;
for (int i = 0; i < hs; i++) {
// in the forward pass this was:
// out_bth[i] += att_bth[t2] * value_t2[i];
// so now we have:
datt_bth[t2] += value_t2[i] * dout_bth[i];
dvalue_t2[i] += att_bth[t2] * dout_bth[i];
}
}
// backward pass 2 & 3, the softmax
// note that softmax (like e.g. tanh) doesn't need the input (preatt) to backward
for (int t2 = 0; t2 <= t; t2++) {
for (int t3 = 0; t3 <= t; t3++) {
float indicator = t2 == t3 ? 1.0f : 0.0f;
float local_derivative = att_bth[t2] * (indicator - att_bth[t3]);
dpreatt_bth[t3] += local_derivative * datt_bth[t2];
}
}
// backward pass 1, the query @ key matmul
for (int t2 = 0; t2 <= t; t2++) {
float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
float* dkey_t2 = dinp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
for (int i = 0; i < hs; i++) {
// in the forward pass this was:
// preatt_bth[t2] += (query_t[i] * key_t2[i]) * scale;
// so now we have:
dquery_t[i] += key_t2[i] * dpreatt_bth[t2] * scale;
dkey_t2[i] += query_t[i] * dpreatt_bth[t2] * scale;
}
}
}
}
}
}
void gelu_forward(float* out, float* inp, int N) {
float s = sqrtf(2.0f / M_PI);
for (int i = 0; i < N; i++) {
float x = inp[i];
float cube = 0.044715f * x * x * x;
out[i] = 0.5f * x * (1.0f + tanhf(s * (x + cube)));
}
}
void gelu_backward(float* dinp, float* inp, float* dout, int N) {
float s = sqrtf(2.0f / M_PI);
for (int i = 0; i < N; i++) {
float x = inp[i];
float cube = 0.044715f * x * x * x;
float tanh_arg = s * (x + cube);
float tanh_out = tanhf(tanh_arg);
float coshf_out = coshf(tanh_arg);
float sech_out = 1.0f / (coshf_out * coshf_out);
float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * s * (1.0f + 3.0f * 0.044715f * x * x);
dinp[i] += local_grad * dout[i];
}
}
// Residual connection: elementwise sum of two N-float buffers,
//   out[i] = inp1[i] + inp2[i]
void residual_forward(float* out, float* inp1, float* inp2, int N) {
    int i = 0;
    while (i < N) {
        out[i] = inp1[i] + inp2[i];
        i++;
    }
}
// Backward of the residual sum: the upstream gradient flows unchanged into
// both branches, accumulated with +=.
//   dinp1, dinp2 : N floats each, gradient accumulators of the two inputs
//   dout         : N floats, upstream gradient
void residual_backward(float* dinp1, float* dinp2, float* dout, int N) {
    int i = 0;
    while (i < N) {
        dinp1[i] += dout[i];
        dinp2[i] += dout[i];
        i++;
    }
}
void softmax_forward(float* probs, float* logits, int B, int T, int V) {
// output: probs are (B,T,V) of the probabilities
// input: logits is (B,T,V) of the unnormalized log probabilities
#pragma omp parallel for collapse(2)
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// probs <- softmax(logits)
float* logits_bt = logits + b * T * V + t * V;
float* probs_bt = probs + b * T * V + t * V;
float maxval = -10000.0f; // TODO something better
for (int i = 0; i < V; i++) {
if (logits_bt[i] > maxval) {
maxval = logits_bt[i];
}
}
float sum = 0.0f;
for (int i = 0; i < V; i++) {
probs_bt[i] = expf(logits_bt[i] - maxval);
sum += probs_bt[i];
}
for (int i = 0; i < V; i++) {
probs_bt[i] /= sum;
}
}
}
}
void crossentropy_forward(float* losses,
float* probs, int* targets,
int B, int T, int V) {
// output: losses is (B,T) of the individual losses at each position
// input: probs are (B,T,V) of the probabilities
// input: targets is (B,T) of integers giving the correct index in logits
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// loss = -log(probs[target])
float* probs_bt = probs + b * T * V + t * V;
int ix = targets[b * T + t];
losses[b * T + t] = -logf(probs_bt[ix]);
}
}
}
// Fused backward through cross-entropy and softmax:
//   dlogits[b,t,i] += (probs[b,t,i] - [i == targets[b,t]]) * dlosses[b,t]
//   dlogits : (B,T,V) gradient accumulator for the logits
//   dlosses : (B,T) upstream gradient of each per-position loss
//   probs   : (B,T,V) probabilities from the forward pass
//   targets : (B,T) index of the correct class at each position
void crossentropy_softmax_backward(float* dlogits,
                                   float* dlosses, float* probs, int* targets,
                                   int B, int T, int V) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            int pos = b * T + t;
            float* g_logits = dlogits + pos * V;
            float* p_row = probs + pos * V;
            float g_loss = dlosses[pos];
            int target = targets[pos];
            for (int i = 0; i < V; i++) {
                float p = p_row[i];
                // one-hot entry of the target distribution
                float onehot = 0.0f;
                if (i == target) {
                    onehot = 1.0f;
                }
                g_logits[i] += (p - onehot) * g_loss;
            }
        }
    }
}
// ----------------------------------------------------------------------------
// GPT-2 model definition
// the parameters of the model
// (V = vocab size, C = channels, L = number of layers, maxT = max sequence length)
#define NUM_PARAMETER_TENSORS 16
typedef struct {
    float* wte; // (V, C)
    float* wpe; // (maxT, C)
    float* ln1w; // (L, C)
    float* ln1b; // (L, C)
    float* qkvw; // (L, 3*C, C)
    float* qkvb; // (L, 3*C)
    float* attprojw; // (L, C, C)
    float* attprojb; // (L, C)
    float* ln2w; // (L, C)
    float* ln2b; // (L, C)
    float* fcw; // (L, 4*C, C)
    float* fcb; // (L, 4*C)
    float* fcprojw; // (L, C, 4*C)
    float* fcprojb; // (L, C)
    float* lnfw; // (C)
    float* lnfb; // (C)
} ParameterTensors;

// Allocates one contiguous float buffer large enough for all parameter
// tensors and points every field of *params at its slice of that buffer.
//   params      : struct whose NUM_PARAMETER_TENSORS pointers are filled in
//   param_sizes : element count of each tensor, in struct field order
// Returns the start of the buffer; the caller owns it and releases it with a
// single free(). The memory is not initialized. Prints an error and exits if
// the allocation fails.
float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes) {
    size_t num_parameters = 0;
    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        num_parameters += param_sizes[i];
    }
    // malloc all parameters all at once
    float* params_memory = (float*)malloc(num_parameters * sizeof(float));
    // malloc(0) may legitimately return NULL, so only treat NULL as failure
    // when memory was actually requested
    if (params_memory == NULL && num_parameters > 0) {
        printf("Error: failed to allocate memory for %zu parameters\n", num_parameters);
        exit(1);
    }
    // assign all the tensors; the order here must match param_sizes
    float** ptrs[] = {
        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
    };
    float* params_memory_iterator = params_memory;
    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        *(ptrs[i]) = params_memory_iterator;
        params_memory_iterator += param_sizes[i];
    }
    return params_memory;
}
// the activations of the model
// (B = batch size, T = sequence length, C = channels, L = number of layers,
//  NH = number of heads, V = vocab size)
#define NUM_ACTIVATION_TENSORS 23
typedef struct {
    float* encoded; // (B, T, C)
    float* ln1; // (L, B, T, C)
    float* ln1_mean; // (L, B, T)
    float* ln1_rstd; // (L, B, T)
    float* qkv; // (L, B, T, 3*C)
    float* atty; // (L, B, T, C)
    float* preatt; // (L, B, NH, T, T)
    float* att; // (L, B, NH, T, T)
    float* attproj; // (L, B, T, C)
    float* residual2; // (L, B, T, C)
    float* ln2; // (L, B, T, C)
    float* ln2_mean; // (L, B, T)
    float* ln2_rstd; // (L, B, T)
    float* fch; // (L, B, T, 4*C)
    float* fch_gelu; // (L, B, T, 4*C)
    float* fcproj; // (L, B, T, C)
    float* residual3; // (L, B, T, C)
    float* lnf; // (B, T, C)
    float* lnf_mean; // (B, T)
    float* lnf_rstd; // (B, T)
    float* logits; // (B, T, V)
    float* probs; // (B, T, V)
    float* losses; // (B, T)
} ActivationTensors;

// Allocates one contiguous float buffer large enough for all activation
// tensors and points every field of *acts at its slice of that buffer.
//   acts      : struct whose NUM_ACTIVATION_TENSORS pointers are filled in
//   act_sizes : element count of each tensor, in struct field order
// Returns the start of the buffer; the caller owns it and releases it with a
// single free(). The memory is not initialized. Prints an error and exits if
// the allocation fails.
float* malloc_and_point_activations(ActivationTensors* acts, size_t* act_sizes) {
    size_t num_activations = 0;
    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
        num_activations += act_sizes[i];
    }
    float* acts_memory = (float*)malloc(num_activations * sizeof(float));
    // malloc(0) may legitimately return NULL, so only treat NULL as failure
    // when memory was actually requested
    if (acts_memory == NULL && num_activations > 0) {
        printf("Error: failed to allocate memory for %zu activations\n", num_activations);
        exit(1);
    }
    // the order here must match act_sizes
    float** ptrs[] = {
        &acts->encoded, &acts->ln1, &acts->ln1_mean, &acts->ln1_rstd, &acts->qkv, &acts->atty,
        &acts->preatt, &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean,
        &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf,
        &acts->lnf_mean, &acts->lnf_rstd, &acts->logits, &acts->probs, &acts->losses
    };
    float* acts_memory_iterator = acts_memory;
    for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
        *(ptrs[i]) = acts_memory_iterator;
        acts_memory_iterator += act_sizes[i];
    }
    return acts_memory;
}
// Model hyperparameters. gpt2_build_from_checkpoint fills these from the
// integer header of the checkpoint file (header entries 2 through 6, in the
// order the fields are declared here).
typedef struct {
int max_seq_len; // max sequence length, e.g. 1024
int vocab_size; // vocab size, e.g. 50257
int num_layers; // number of layers, e.g. 12
int num_heads; // number of heads in attention, e.g. 12
int channels; // number of channels, e.g. 768
} GPT2Config;
// Complete state of one GPT-2 model: hyperparameters, weights, gradients,
// optimizer buffers, activations and the bookkeeping of the latest forward pass.
// Each *_memory pointer is the start of one contiguous buffer; the matching
// ParameterTensors / ActivationTensors struct holds pointers into that buffer.
// gpt2_build_from_checkpoint sets every buffer except params_memory to NULL;
// the others are allocated lazily by the functions that first need them.
typedef struct {
GPT2Config config;
// the weights of the model, and their sizes
ParameterTensors params;
size_t param_sizes[NUM_PARAMETER_TENSORS];
float* params_memory;
// total number of parameter floats (sum of param_sizes)
// NOTE(review): stored as int although it is computed as a size_t sum
int num_parameters;
// gradients of the weights
ParameterTensors grads;
float* grads_memory;
// buffers for the AdamW optimizer
float* m_memory;
float* v_memory;
// the activations of the model, and their sizes
ActivationTensors acts;
size_t act_sizes[NUM_ACTIVATION_TENSORS];
float* acts_memory;
// total number of activation floats (sum of act_sizes)
// NOTE(review): stored as int although it is computed as a size_t sum
int num_activations;
// gradients of the activations
ActivationTensors grads_acts;
float* grads_acts_memory;
// other run state configuration
// batch_size and seq_len are recorded by gpt2_forward when it allocates the
// activations (its first call); later calls only check B,T against them
int batch_size; // the batch size (B) the activations were allocated for
int seq_len; // the sequence length (T) the activations were allocated for
int* inputs; // the input tokens for the current forward pass
int* targets; // the target tokens for the current forward pass
float mean_loss; // after a forward pass with targets, will be populated with the mean loss; -1.0f means no loss
} GPT2;
// Initializes *model from a checkpoint file.
// File layout as read here: a header of 256 ints (entry 0 = magic number
// 20240326, entry 1 = version 1, entries 2..6 = max_seq_len, vocab_size,
// num_layers, num_heads, channels), followed by all parameter floats in the
// field order of ParameterTensors.
//   model           : struct to initialize; every field is overwritten
//   checkpoint_path : path of the checkpoint file
// On any failure (file cannot be opened, short read, bad magic/version,
// invalid hyperparameters) an error is printed and the process exits.
// Everything except the parameters is left unallocated (NULL); the other
// buffers are created lazily by the forward/backward/update functions.
void gpt2_build_from_checkpoint(GPT2 *model, char* checkpoint_path) {

    // read in model from a checkpoint file
    FILE *model_file = fopen(checkpoint_path, "rb");
    if (model_file == NULL) { printf("Error opening model file\n"); exit(1); }
    int model_header[256];
    // a short read would leave part of the header uninitialized
    if (fread(model_header, sizeof(int), 256, model_file) != 256) {
        printf("Error reading header of model file\n");
        exit(1);
    }
    if (model_header[0] != 20240326) { printf("Bad magic model file"); exit(1); }
    if (model_header[1] != 1) { printf("Bad version in model file"); exit(1); }

    // read in hyperparameters
    int maxT, V, L, NH, C;
    model->config.max_seq_len = maxT = model_header[2];
    model->config.vocab_size = V = model_header[3];
    model->config.num_layers = L = model_header[4];
    model->config.num_heads = NH = model_header[5];
    model->config.channels = C = model_header[6];
    printf("[GPT-2]\n");
    printf("max_seq_len: %d\n", maxT);
    printf("vocab_size: %d\n", V);
    printf("num_layers: %d\n", L);
    printf("num_heads: %d\n", NH);
    printf("channels: %d\n", C);

    // the values come from a file: reject anything that cannot describe a model
    // before it is used to size allocations
    if (maxT <= 0 || V <= 0 || L <= 0 || NH <= 0 || C <= 0) {
        printf("Error: invalid hyperparameters in model file\n");
        exit(1);
    }

    // allocate space for all the parameters and read them in
    // (sizes are computed in size_t so the products cannot overflow int)
    const size_t sV = (size_t)V;
    const size_t sT = (size_t)maxT;
    const size_t sL = (size_t)L;
    const size_t sC = (size_t)C;
    model->param_sizes[0] = sV * sC;            // wte
    model->param_sizes[1] = sT * sC;            // wpe
    model->param_sizes[2] = sL * sC;            // ln1w
    model->param_sizes[3] = sL * sC;            // ln1b
    model->param_sizes[4] = sL * (3 * sC) * sC; // qkvw
    model->param_sizes[5] = sL * (3 * sC);      // qkvb
    model->param_sizes[6] = sL * sC * sC;       // attprojw
    model->param_sizes[7] = sL * sC;            // attprojb
    model->param_sizes[8] = sL * sC;            // ln2w
    model->param_sizes[9] = sL * sC;            // ln2b
    model->param_sizes[10] = sL * (4 * sC) * sC; // fcw
    model->param_sizes[11] = sL * (4 * sC);      // fcb
    model->param_sizes[12] = sL * sC * (4 * sC); // fcprojw
    model->param_sizes[13] = sL * sC;            // fcprojb
    model->param_sizes[14] = sC;                 // lnfw
    model->param_sizes[15] = sC;                 // lnfb

    // count the number of parameters
    size_t num_parameters = 0;
    for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
        num_parameters += model->param_sizes[i];
    }
    printf("num_parameters: %zu\n", num_parameters);
    // the model stores this count (and derives per-layer offsets) as int
    if (num_parameters > (size_t)INT_MAX) {
        printf("Error: model with %zu parameters is too large\n", num_parameters);
        exit(1);
    }
    model->num_parameters = (int)num_parameters;

    // read in all the parameters from file
    model->params_memory = malloc_and_point_parameters(&model->params, model->param_sizes);
    if (model->params_memory == NULL) {
        printf("Error: failed to allocate memory for the parameters\n");
        exit(1);
    }
    if (fread(model->params_memory, sizeof(float), num_parameters, model_file) != num_parameters) {
        printf("Error reading parameters from model file\n");
        exit(1);
    }
    fclose(model_file);

    // other inits
    model->acts_memory = NULL;
    model->grads_memory = NULL;
    model->m_memory = NULL;
    model->v_memory = NULL;
    model->grads_acts_memory = NULL;
    model->inputs = NULL;
    model->targets = NULL;
    model->batch_size = 0;
    model->seq_len = 0;
    model->mean_loss = -1.0f; // -1.0f will designate no loss
}
// Runs the forward pass of the model on a batch of token sequences.
//   model   : an initialized model (see gpt2_build_from_checkpoint)
//   inputs  : (B,T) token indices
//   targets : (B,T) target token indices, or NULL when no loss is wanted
//   B, T    : batch size and sequence length of this call
// The first call allocates the activation buffers for exactly this B,T; later
// calls may use the same or smaller B,T, anything larger is an error.
// On return model->acts holds all activations (probs are the output
// distribution) and model->mean_loss holds the mean cross-entropy loss, or
// -1.0f when targets is NULL. Errors are printed and terminate the process.
void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) {
    // targets are optional and could be NULL

    // ensure the model was initialized or error out
    if (model->params_memory == NULL) {
        printf("Error: model was not initialized properly.\n");
        exit(1);
    }

    // convenience parameters
    int V = model->config.vocab_size;
    int L = model->config.num_layers;
    int NH = model->config.num_heads;
    int C = model->config.channels;

    // the position embedding table only has max_seq_len rows; a longer
    // sequence would make the encoder read past its end
    if (T > model->config.max_seq_len) {
        printf("Error: sequence length T=%d exceeds max_seq_len=%d\n", T, model->config.max_seq_len);
        exit(1);
    }

    // allocate space for all the activations if needed (done here, lazily)
    if(model->acts_memory == NULL) {
        // record the current B,T as well
        model->batch_size = B;
        model->seq_len = T;
        // and now allocate the space
        // (sizes are computed in size_t so the products cannot overflow int)
        const size_t sB = (size_t)B;
        const size_t sT = (size_t)T;
        const size_t sC = (size_t)C;
        const size_t sL = (size_t)L;
        const size_t sNH = (size_t)NH;
        const size_t sV = (size_t)V;
        const size_t BT = sB * sT;
        model->act_sizes[0] = BT * sC;                  // encoded
        model->act_sizes[1] = sL * BT * sC;             // ln1
        model->act_sizes[2] = sL * BT;                  // ln1_mean
        model->act_sizes[3] = sL * BT;                  // ln1_rstd
        model->act_sizes[4] = sL * BT * 3*sC;           // qkv
        model->act_sizes[5] = sL * BT * sC;             // atty
        model->act_sizes[6] = sL * sB * sNH * sT * sT;  // preatt
        model->act_sizes[7] = sL * sB * sNH * sT * sT;  // att
        model->act_sizes[8] = sL * BT * sC;             // attproj
        model->act_sizes[9] = sL * BT * sC;             // residual2
        model->act_sizes[10] = sL * BT * sC;            // ln2
        model->act_sizes[11] = sL * BT;                 // ln2_mean
        model->act_sizes[12] = sL * BT;                 // ln2_rstd
        model->act_sizes[13] = sL * BT * 4*sC;          // fch
        model->act_sizes[14] = sL * BT * 4*sC;          // fch_gelu
        model->act_sizes[15] = sL * BT * sC;            // fcproj
        model->act_sizes[16] = sL * BT * sC;            // residual3
        model->act_sizes[17] = BT * sC;                 // lnf
        model->act_sizes[18] = BT;                      // lnf_mean
        model->act_sizes[19] = BT;                      // lnf_rstd
        model->act_sizes[20] = BT * sV;                 // logits
        model->act_sizes[21] = BT * sV;                 // probs
        model->act_sizes[22] = BT;                      // losses
        size_t num_activations = 0;
        for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) {
            // the per-layer offsets below (and inside the layer functions) are
            // computed in int and never exceed the size of their tensor, so
            // every single tensor has to fit in an int
            if (model->act_sizes[i] > (size_t)INT_MAX) {
                printf("Error: activation tensor %zu with %zu elements is too large; reduce B or T\n", i, model->act_sizes[i]);
                exit(1);
            }
            num_activations += model->act_sizes[i];
        }
        printf("num_activations: %zu\n", num_activations);
        // NOTE(review): num_activations is an int field; a total above INT_MAX
        // is truncated here. Only gpt2_zero_grad reads it, so this matters for
        // training with very large B*T — widening the field would fix it.
        model->num_activations = (int)num_activations;
        model->acts_memory = malloc_and_point_activations(&model->acts, model->act_sizes);
        // also create memory for caching inputs and targets
        model->inputs = (int*)malloc(BT * sizeof(int));
        model->targets = (int*)malloc(BT * sizeof(int)); // might be unused if we never have targets but it's small
        if (model->acts_memory == NULL || model->inputs == NULL || model->targets == NULL) {
            printf("Error: failed to allocate memory for the activations\n");
            exit(1);
        }
    } else {
        // validate B,T is no larger than what was previously allocated
        // in principle, we could re-allocate a larger chunk of memory, for now we just error out
        if (B > model->batch_size || T > model->seq_len) {
            printf("Error: batch size or sequence length is inadequately large\n");
            printf("Model: B=%d T=%d, Desired: B=%d T=%d\n", model->batch_size, model->seq_len, B, T);
            exit(1);
        }
    }

    // cache the inputs/targets
    memcpy(model->inputs, inputs, B * T * sizeof(int));
    if (targets != NULL) {
        memcpy(model->targets, targets, B * T * sizeof(int));
    }

    // forward pass
    ParameterTensors params = model->params; // for brevity
    ActivationTensors acts = model->acts;
    float* residual;
    encoder_forward(acts.encoded, inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0]
    for (int l = 0; l < L; l++) {

        // input of this layer: the encoding for layer 0, else the previous layer's output
        residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;

        // get the pointers of the weights for this layer
        float* l_ln1w = params.ln1w + l * C;
        float* l_ln1b = params.ln1b + l * C;
        float* l_qkvw = params.qkvw + l * 3*C * C;
        float* l_qkvb = params.qkvb + l * 3*C;
        float* l_attprojw = params.attprojw + l * C * C;
        float* l_attprojb = params.attprojb + l * C;
        float* l_ln2w = params.ln2w + l * C;
        float* l_ln2b = params.ln2b + l * C;
        float* l_fcw = params.fcw + l * 4*C * C;
        float* l_fcb = params.fcb + l * 4*C;
        float* l_fcprojw = params.fcprojw + l * C * 4*C;
        float* l_fcprojb = params.fcprojb + l * C;

        // get the pointers of the activations for this layer
        float* l_ln1 = acts.ln1 + l * B * T * C;
        float* l_ln1_mean = acts.ln1_mean + l * B * T;
        float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
        float* l_qkv = acts.qkv + l * B * T * 3*C;
        float* l_atty = acts.atty + l * B * T * C;
        float* l_preatt = acts.preatt + l * B * NH * T * T;
        float* l_att = acts.att + l * B * NH * T * T;
        float* l_attproj = acts.attproj + l * B * T * C;
        float* l_residual2 = acts.residual2 + l * B * T * C;
        float* l_ln2 = acts.ln2 + l * B * T * C;
        float* l_ln2_mean = acts.ln2_mean + l * B * T;
        float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
        float* l_fch = acts.fch + l * B * T * 4*C;
        float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;
        float* l_fcproj = acts.fcproj + l * B * T * C;
        float* l_residual3 = acts.residual3 + l * B * T * C;

        // now do the forward pass
        // attention sub-block: layernorm -> qkv projection -> attention -> output projection -> residual add
        layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);
        matmul_forward(l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);
        attention_forward(l_atty, l_preatt, l_att, l_qkv, B, T, C, NH);
        matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);
        residual_forward(l_residual2, residual, l_attproj, B*T*C);
        // MLP sub-block: layernorm -> expand to 4C -> GELU -> project back to C -> residual add
        layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
        matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);
        gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
        matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);
        residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);
    }
    residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
    layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
    // the token embedding table doubles as the output projection (no bias)
    matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, V);
    softmax_forward(acts.probs, acts.logits, B, T, V);

    // also forward the cross-entropy loss function if we have the targets
    if (targets != NULL) {
        crossentropy_forward(model->acts.losses, model->acts.probs, targets, B, T, V);
        // for convenience also evaluate the mean loss
        float mean_loss = 0.0f;
        for (int i=0; i<B*T; i++) { mean_loss += model->acts.losses[i]; }
        mean_loss /= B*T;
        model->mean_loss = mean_loss;
    } else {
        // if we don't have targets, we don't have a loss
        model->mean_loss = -1.0f;
    }
}
// Resets the gradient buffers of the weights and of the activations to all
// zero bytes. A buffer that has not been allocated yet (NULL) is skipped.
void gpt2_zero_grad(GPT2 *model) {
    float* weight_grads = model->grads_memory;
    float* activation_grads = model->grads_acts_memory;
    if (weight_grads != NULL) {
        memset(weight_grads, 0, model->num_parameters * sizeof(float));
    }
    if (activation_grads != NULL) {
        memset(activation_grads, 0, model->num_activations * sizeof(float));
    }
}
// Runs the backward pass for the most recent gpt2_forward call, accumulating
// (+=) gradients into model->grads (weights) and model->grads_acts
// (activations). Requires that the last forward pass was given targets;
// otherwise an error is printed and the process exits.
// The gradient buffers are allocated and zeroed on the first call. Later calls
// add to whatever the buffers already hold, so the caller is expected to call
// gpt2_zero_grad between training steps.
// B and T are taken from model->batch_size / model->seq_len, i.e. the values
// recorded when the activations were allocated.
void gpt2_backward(GPT2 *model) {
// double check we forwarded previously, with targets
// (gpt2_forward sets mean_loss to exactly -1.0f when it has no targets)
if (model->mean_loss == -1.0f) {
printf("Error: must forward with targets before backward\n");
exit(1);
}
// lazily allocate the memory for gradients of the weights and activations, if needed
if (model->grads_memory == NULL) {
model->grads_memory = malloc_and_point_parameters(&model->grads, model->param_sizes);
model->grads_acts_memory = malloc_and_point_activations(&model->grads_acts, model->act_sizes);
gpt2_zero_grad(model);
}
// convenience shortcuts
int B = model->batch_size;
int T = model->seq_len;
int V = model->config.vocab_size;
int L = model->config.num_layers;
int NH = model->config.num_heads;
int C = model->config.channels;
// backward pass
ParameterTensors params = model->params; // for brevity
ParameterTensors grads = model->grads;
ActivationTensors acts = model->acts;
ActivationTensors grads_acts = model->grads_acts;
// we kick off the chain by filling in dlosses with 1.0f/(B*T), to get the mean loss
float dloss_mean = 1.0f / (B*T);
for (int i = 0; i < B*T; i++) { grads_acts.losses[i] = dloss_mean; }
// head of the network, in reverse: loss+softmax, output projection, final layernorm
crossentropy_softmax_backward(grads_acts.logits, grads_acts.losses, acts.probs, model->targets, B, T, V);
// the output projection shares its weight with the token embedding, so its
// weight gradient goes into grads.wte (and there is no bias)
matmul_backward(grads_acts.lnf, grads.wte, NULL, grads_acts.logits, acts.lnf, params.wte, B, T, C, V);
float* residual = acts.residual3 + (L-1) * B * T * C; // last layer's residual
float* dresidual = grads_acts.residual3 + (L-1) * B * T * C; // write to last layer's residual
layernorm_backward(dresidual, grads.lnfw, grads.lnfb, grads_acts.lnf, residual, params.lnfw, acts.lnf_mean, acts.lnf_rstd, B, T, C);
// walk the layers from last to first
for (int l = L-1; l >= 0; l--) {
// input of layer l in the forward pass (and its gradient): the encoding for
// layer 0, else the output of layer l-1
residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;
dresidual = l == 0 ? grads_acts.encoded : grads_acts.residual3 + (l-1) * B * T * C;
// get the pointers of the weights for this layer
float* l_ln1w = params.ln1w + l * C;
float* l_qkvw = params.qkvw + l * 3*C * C;
float* l_attprojw = params.attprojw + l * C * C;
float* l_ln2w = params.ln2w + l * C;
float* l_fcw = params.fcw + l * 4*C * C;
float* l_fcprojw = params.fcprojw + l * C * 4*C;
// get the pointers of the gradients of the weights for this layer
float* dl_ln1w = grads.ln1w + l * C;
float* dl_ln1b = grads.ln1b + l * C;
float* dl_qkvw = grads.qkvw + l * 3*C * C;
float* dl_qkvb = grads.qkvb + l * 3*C;
float* dl_attprojw = grads.attprojw + l * C * C;
float* dl_attprojb = grads.attprojb + l * C;
float* dl_ln2w = grads.ln2w + l * C;
float* dl_ln2b = grads.ln2b + l * C;
float* dl_fcw = grads.fcw + l * 4*C * C;
float* dl_fcb = grads.fcb + l * 4*C;
float* dl_fcprojw = grads.fcprojw + l * C * 4*C;
float* dl_fcprojb = grads.fcprojb + l * C;
// get the pointers of the activations for this layer
float* l_ln1 = acts.ln1 + l * B * T * C;
float* l_ln1_mean = acts.ln1_mean + l * B * T;
float* l_ln1_rstd = acts.ln1_rstd + l * B * T;
float* l_qkv = acts.qkv + l * B * T * 3*C;
float* l_atty = acts.atty + l * B * T * C;
float* l_att = acts.att + l * B * NH * T * T;
float* l_residual2 = acts.residual2 + l * B * T * C;
float* l_ln2 = acts.ln2 + l * B * T * C;
float* l_ln2_mean = acts.ln2_mean + l * B * T;
float* l_ln2_rstd = acts.ln2_rstd + l * B * T;
float* l_fch = acts.fch + l * B * T * 4*C;
float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C;
// get the pointers of the gradients of the activations for this layer
float* dl_ln1 = grads_acts.ln1 + l * B * T * C;
float* dl_qkv = grads_acts.qkv + l * B * T * 3*C;
float* dl_atty = grads_acts.atty + l * B * T * C;
float* dl_preatt = grads_acts.preatt + l * B * NH * T * T;
float* dl_att = grads_acts.att + l * B * NH * T * T;
float* dl_attproj = grads_acts.attproj + l * B * T * C;
float* dl_residual2 = grads_acts.residual2 + l * B * T * C;
float* dl_ln2 = grads_acts.ln2 + l * B * T * C;
float* dl_fch = grads_acts.fch + l * B * T * 4*C;
float* dl_fch_gelu = grads_acts.fch_gelu + l * B * T * 4*C;
float* dl_fcproj = grads_acts.fcproj + l * B * T * C;
float* dl_residual3 = grads_acts.residual3 + l * B * T * C;
// backprop this layer
// the calls below mirror the forward pass of the layer in exact reverse order
// MLP sub-block: residual add, projection, GELU, expansion, layernorm 2
residual_backward(dl_residual2, dl_fcproj, dl_residual3, B*T*C);
matmul_backward(dl_fch_gelu, dl_fcprojw, dl_fcprojb, dl_fcproj, l_fch_gelu, l_fcprojw, B, T, 4*C, C);
gelu_backward(dl_fch, l_fch, dl_fch_gelu, B*T*4*C);
matmul_backward(dl_ln2, dl_fcw, dl_fcb, dl_fch, l_ln2, l_fcw, B, T, C, 4*C);
// dl_residual2 now receives a second contribution, through layernorm 2
layernorm_backward(dl_residual2, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C);
// attention sub-block: residual add, output projection, attention, qkv projection, layernorm 1
residual_backward(dresidual, dl_attproj, dl_residual2, B*T*C);
matmul_backward(dl_atty, dl_attprojw, dl_attprojb, dl_attproj, l_atty, l_attprojw, B, T, C, C);
attention_backward(dl_qkv, dl_preatt, dl_att, dl_atty, l_qkv, l_att, B, T, C, NH);
matmul_backward(dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C);
// dresidual now receives a second contribution, through layernorm 1
layernorm_backward(dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C);
}
// finally push the gradient of the encoding into the two embedding tables
encoder_backward(grads.wte, grads.wpe, grads_acts.encoded, model->inputs, B, T, C);
}
// Applies one AdamW optimizer step to every parameter of the model, using the
// gradients currently stored in model->grads_memory.
//
//   learning_rate  step size
//   beta1, beta2   decay rates of the first / second moment estimates
//   eps            added to sqrt(v_hat) in the denominator
//   weight_decay   decoupled weight decay, applied to every parameter
//   t              step counter used for bias correction; it must be >= 1,
//                  because with t == 0 both bias-correction denominators
//                  (1 - beta^t) are zero
//
// The moment buffers (m_memory, v_memory) are allocated on the first call.
// If that allocation fails, an error is printed and the process exits.
void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {
    // reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html

    // lazily allocate the memory for m_memory and v_memory;
    // calloc zero-initializes both moment estimates
    if (model->m_memory == NULL) {
        model->m_memory = (float*)calloc(model->num_parameters, sizeof(float));
        model->v_memory = (float*)calloc(model->num_parameters, sizeof(float));
        // calloc may legitimately return NULL for a zero-sized request, so only
        // treat NULL as a failure when there is something to allocate
        if (model->num_parameters > 0 && (model->m_memory == NULL || model->v_memory == NULL)) {
            printf("Error: failed to allocate memory for the AdamW optimizer state\n");
            exit(1);
        }
    }

    // the bias-correction denominators depend only on t, not on the parameter,
    // so compute them once instead of once per parameter
    float m_correction = 1.0f - powf(beta1, t);
    float v_correction = 1.0f - powf(beta2, t);

    for (int i = 0; i < model->num_parameters; i++) {
        float param = model->params_memory[i];
        float grad = model->grads_memory[i];

        // update the first moment (momentum)
        float m = beta1 * model->m_memory[i] + (1.0f - beta1) * grad;
        // update the second moment (RMSprop)
        float v = beta2 * model->v_memory[i] + (1.0f - beta2) * grad * grad;
        // bias-correct both moments
        float m_hat = m / m_correction;
        float v_hat = v / v_correction;

        // store the new moments and apply the update; the weight decay term
        // uses the parameter value from before this step
        model->m_memory[i] = m;
        model->v_memory[i] = v;
        model->params_memory[i] -= learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
    }
}
// Releases every heap buffer referenced by the model and resets each pointer
// to NULL. The GPT2 struct itself is not freed.
//
// Resetting the pointers matters: gpt2_update() allocates the optimizer state
// only when m_memory is NULL, so a stale pointer left behind here would make a
// reused model write into freed memory, and a second gpt2_free() call would
// free the same buffers twice. free(NULL) is a no-op, so buffers that were
// never allocated are handled without extra checks.
void gpt2_free(GPT2 *model) {
    free(model->params_memory);
    model->params_memory = NULL;
    free(model->grads_memory);
    model->grads_memory = NULL;
    free(model->m_memory);
    model->m_memory = NULL;
    free(model->v_memory);
    model->v_memory = NULL;
    free(model->acts_memory);
    model->acts_memory = NULL;
    free(model->grads_acts_memory);
    model->grads_acts_memory = NULL;
    free(model->inputs);
    model->inputs = NULL;
    free(model->targets);
    model->targets = NULL;
}
#ifndef TESTING
// if we are TESTING (see test.c), we'll skip the int main below
// ----------------------------------------------------------------------------
// data loader lite
// returns consecutive batches of data from a file of integers,
// wrapping around to the start of the file when the end is reached
// State of a minimal data loader that reads batches of int tokens from a
// binary file. Set up by dataloader_init(), released by dataloader_free().
typedef struct {
// hyperparameters
// batch size
int B;
// sequence length
int T;
// input handling and its state
// token file, opened in binary read mode
FILE* tokens_file;
// size of the file in bytes (ftell after seeking to the end)
long file_size;
// byte offset in the file at which the next batch starts
long current_position;
// output memory
// buffer holding the B*T + 1 integers read for the current batch
int* batch;
// points at batch[0]
int* inputs;
// points at batch[1]: the same tokens as inputs, shifted by one
int* targets;
// convenience variables
// file_size divided by the byte size of B*T integers
int num_batches;
} DataLoader;
// Initializes a DataLoader that serves batches of B*T tokens from `filename`,
// a binary file of int tokens.
//
//   loader    loader to initialize; all of its fields are overwritten
//   filename  path of the token file
//   B, T      batch size and sequence length, both must be positive
//
// Opens the file, records its size in bytes, and allocates one buffer of
// B*T + 1 integers that backs both `inputs` and `targets`.
// On any failure (bad B/T, file cannot be opened or measured, file too small
// for a single batch, out of memory) an error is printed and the process exits.
void dataloader_init(DataLoader *loader, char* filename, int B, int T) {
    // B*T is used as a divisor below, so reject non-positive sizes up front
    if (B <= 0 || T <= 0) {
        printf("Error: batch size and sequence length must be positive\n");
        exit(1);
    }
    loader->B = B;
    loader->T = T;

    // open the input file for reading
    loader->tokens_file = fopen(filename, "rb");
    if (loader->tokens_file == NULL) {
        printf("Error opening tokens file\n");
        exit(1);
    }

    // determine the file size; ftell returns -1 on failure, which must be
    // caught before the value is compared against an unsigned byte count
    if (fseek(loader->tokens_file, 0, SEEK_END) != 0) {
        printf("Error determining the size of the tokens file\n");
        exit(1);
    }
    long file_size = ftell(loader->tokens_file);
    if (file_size < 0 || fseek(loader->tokens_file, 0, SEEK_SET) != 0) {
        printf("Error determining the size of the tokens file\n");
        exit(1);
    }
    loader->file_size = file_size;

    // do the size arithmetic in size_t so that B*T cannot overflow an int
    size_t tokens_per_batch = (size_t)B * (size_t)T;
    size_t batch_bytes = (tokens_per_batch + 1) * sizeof(int);
    if ((size_t)file_size < batch_bytes) {
        printf("Error: file size is too small for the batch size and sequence length\n");
        exit(1);
    }
    loader->current_position = 0; // start at the beginning

    // allocate space for B*T + 1 integers to store the inputs and targets
    loader->batch = (int*) malloc(batch_bytes);
    if (loader->batch == NULL) {
        printf("Error: failed to allocate memory for the batch buffer\n");
        exit(1);
    }
    loader->inputs = loader->batch;
    loader->targets = loader->batch + 1; // targets are shifted by one
    loader->num_batches = (int)((size_t)file_size / (tokens_per_batch * sizeof(int)));
}
// Rewinds the loader: the next batch is read starting at byte offset 0
// of the token file. The file handle and the batch buffer are untouched.
void dataloader_reset(DataLoader *loader) {
    const long start_of_file = 0;
    loader->current_position = start_of_file;
}
// Loads the next batch of B*T + 1 tokens from the file into loader->batch
// (and therefore into loader->inputs / loader->targets, which alias it).
//
// Consecutive batches overlap by one token: the read position advances by
// only B*T integers, so the last token read here is the first input token of
// the next batch. When fewer than B*T + 1 integers remain, reading restarts
// from the beginning of the file.
// If seeking fails or fewer tokens than requested are read, an error is
// printed and the process exits, so that training never continues on a
// partially filled batch.
void dataloader_next_batch(DataLoader *loader) {
    // do the size arithmetic in size_t so that B*T cannot overflow an int
    size_t tokens_per_batch = (size_t)loader->B * (size_t)loader->T;
    size_t tokens_to_read = tokens_per_batch + 1;

    // if we are at the end of the file, loop back to the beginning
    if ((size_t)loader->current_position + tokens_to_read * sizeof(int) > (size_t)loader->file_size) {
        loader->current_position = 0;
    }

    // read the B*T+1 integers from the file into batch
    if (fseek(loader->tokens_file, loader->current_position, SEEK_SET) != 0) {
        printf("Error seeking in tokens file\n");
        exit(1);
    }
    size_t tokens_read = fread(loader->batch, sizeof(int), tokens_to_read, loader->tokens_file);
    if (tokens_read != tokens_to_read) {
        printf("Error reading tokens file\n");
        exit(1);
    }

    // advance the current position by B*T integers
    loader->current_position += (long)(tokens_per_batch * sizeof(int));
}
// Tears down a loader set up by dataloader_init(): releases the batch buffer
// (which also backs `inputs` and `targets`) and closes the token file.
// The DataLoader struct itself is not freed.
void dataloader_free(DataLoader *loader) {
    int *batch_buffer = loader->batch;
    FILE *tokens = loader->tokens_file;
    free(batch_buffer);
    fclose(tokens);
}