Skip to content

Commit

Permalink
Merge pull request #400 from kroggen/use-kv-cache
Browse files Browse the repository at this point in the history
calculate key and value inside the kv cache
  • Loading branch information
karpathy authored Oct 9, 2023
2 parents d0237ab + 411c5bd commit 1fcdf04
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,13 @@ void malloc_run_state(RunState* s, Config* p) {
s->hb = calloc(p->hidden_dim, sizeof(float));
s->hb2 = calloc(p->hidden_dim, sizeof(float));
s->q = calloc(p->dim, sizeof(float));
s->k = calloc(kv_dim, sizeof(float));
s->v = calloc(kv_dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
// ensure all mallocs went fine
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache) {
|| !s->key_cache || !s->value_cache || !s->att || !s->logits) {
fprintf(stderr, "malloc failed!\n");
exit(EXIT_FAILURE);
}
Expand All @@ -105,8 +102,6 @@ void free_run_state(RunState* s) {
free(s->hb);
free(s->hb2);
free(s->q);
free(s->k);
free(s->v);
free(s->att);
free(s->logits);
free(s->key_cache);
Expand Down Expand Up @@ -256,6 +251,11 @@ float* forward(Transformer* transformer, int token, int pos) {
// attention rmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);

// key and value point to the kv cache
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
s->k = s->key_cache + loff + pos * kv_dim;
s->v = s->value_cache + loff + pos * kv_dim;

// qkv matmuls for this position
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim);
Expand All @@ -278,13 +278,6 @@ float* forward(Transformer* transformer, int token, int pos) {
}
}

// save key,value at this time step (pos) to our kv cache
int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience
float* key_cache_row = s->key_cache + loff + pos * kv_dim;
float* value_cache_row = s->value_cache + loff + pos * kv_dim;
memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row));

// multihead attention. iterate over all heads
int h;
#pragma omp parallel for private(h)
Expand Down

0 comments on commit 1fcdf04

Please sign in to comment.