diff --git a/run.c b/run.c index efb254f8..e1a4ec24 100644 --- a/run.c +++ b/run.c @@ -83,16 +83,13 @@ void malloc_run_state(RunState* s, Config* p) { s->hb = calloc(p->hidden_dim, sizeof(float)); s->hb2 = calloc(p->hidden_dim, sizeof(float)); s->q = calloc(p->dim, sizeof(float)); - s->k = calloc(kv_dim, sizeof(float)); - s->v = calloc(kv_dim, sizeof(float)); - s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); - s->logits = calloc(p->vocab_size, sizeof(float)); s->key_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); s->value_cache = calloc(p->n_layers * p->seq_len * kv_dim, sizeof(float)); + s->att = calloc(p->n_heads * p->seq_len, sizeof(float)); + s->logits = calloc(p->vocab_size, sizeof(float)); // ensure all mallocs went fine if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q - || !s->k || !s->v || !s->att || !s->logits || !s->key_cache - || !s->value_cache) { + || !s->key_cache || !s->value_cache || !s->att || !s->logits) { fprintf(stderr, "malloc failed!\n"); exit(EXIT_FAILURE); } @@ -105,8 +102,6 @@ void free_run_state(RunState* s) { free(s->hb); free(s->hb2); free(s->q); - free(s->k); - free(s->v); free(s->att); free(s->logits); free(s->key_cache); @@ -256,6 +251,11 @@ float* forward(Transformer* transformer, int token, int pos) { // attention rmsnorm rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim); + // key and value point to the kv cache + int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience + s->k = s->key_cache + loff + pos * kv_dim; + s->v = s->value_cache + loff + pos * kv_dim; + // qkv matmuls for this position matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim); matmul(s->k, s->xb, w->wk + l*dim*kv_dim, dim, kv_dim); @@ -278,13 +278,6 @@ float* forward(Transformer* transformer, int token, int pos) { } } - // save key,value at this time step (pos) to our kv cache - int loff = l * p->seq_len * kv_dim; // kv cache layer offset for convenience - float* key_cache_row = s->key_cache + loff + pos * kv_dim; - float* value_cache_row = s->value_cache + loff + pos * kv_dim; - memcpy(key_cache_row, s->k, kv_dim * sizeof(*key_cache_row)); - memcpy(value_cache_row, s->v, kv_dim * sizeof(*value_cache_row)); - // multihead attention. iterate over all heads int h; #pragma omp parallel for private(h)