refactor the Transformer (Config, Weights, RunState) into a single object, with build and free too
karpathy committed Aug 21, 2023
1 parent ae2e4f8 commit 8a377a1
Showing 1 changed file with 54 additions and 41 deletions.
95 changes: 54 additions & 41 deletions run.c
@@ -14,7 +14,7 @@
#include <sys/mman.h>
#endif
// ----------------------------------------------------------------------------
// Transformer and RunState structs, and related memory management
// Transformer model

typedef struct {
int dim; // transformer dimension
@@ -64,6 +64,16 @@ typedef struct {
float* value_cache; // (layer, seq_len, dim)
} RunState;

typedef struct {
Config config; // the hyperparameters of the architecture (the blueprint)
TransformerWeights weights; // the weights of the model
RunState state; // buffers for the "wave" of activations in the forward pass
// some more state needed to properly clean up the memory mapping (sigh)
int fd; // file descriptor for memory mapping
float* data; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;

void malloc_run_state(RunState* s, Config* p) {
// we calloc instead of malloc to keep valgrind happy
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
@@ -103,10 +113,7 @@ void free_run_state(RunState* s) {
free(s->value_cache);
}

// ----------------------------------------------------------------------------
// initialization: read from checkpoint

void checkpoint_init_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
void memory_map_weights(TransformerWeights *w, Config* p, float* ptr, int shared_weights) {
int head_size = p->dim / p->n_heads;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
@@ -154,11 +161,26 @@ void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weigh
*data = mmap(NULL, *file_size, PROT_READ, MAP_PRIVATE, *fd, 0);
if (*data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); exit(EXIT_FAILURE); }
float* weights_ptr = *data + sizeof(Config)/sizeof(float);
checkpoint_init_weights(weights, config, weights_ptr, shared_weights);
memory_map_weights(weights, config, weights_ptr, shared_weights);
}

void build_transformer(char* checkpoint_path, Transformer *t) {
// read in the Config and the Weights from the checkpoint
read_checkpoint(checkpoint_path, &t->config, &t->weights, &t->fd, &t->data, &t->file_size);
// allocate the RunState buffers
malloc_run_state(&t->state, &t->config);
}

void free_transformer(Transformer* t) {
// close the memory mapping
if (t->data != MAP_FAILED) { munmap(t->data, t->file_size); }
if (t->fd != -1) { close(t->fd); }
// free the RunState buffers
free_run_state(&t->state);
}

// ----------------------------------------------------------------------------
// neural net blocks
// neural net blocks; the dynamics of the Transformer

void rmsnorm(float* o, float* x, float* weight, int size) {
// calculate sum of squares
@@ -209,9 +231,12 @@ void matmul(float* xout, float* x, float* w, int n, int d) {
}
}

void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
float* forward(Transformer* transformer, int token, int pos) {

// a few convenience variables
Config* p = &transformer->config;
TransformerWeights* w = &transformer->weights;
RunState* s = &transformer->state;
float *x = s->x;
int dim = p->dim;
int kv_dim = (p->dim * p->n_kv_heads) / p->n_heads;
@@ -338,6 +363,7 @@ void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights*

// classifier into logits
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
return s->logits;
}

// ----------------------------------------------------------------------------
@@ -351,16 +377,16 @@ typedef struct {
char byte_piece[2];
} Tokenizer;

void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
void build_tokenizer(char* tokenizer_path, Tokenizer* t, int vocab_size) {
// i should have written the vocab_size into the tokenizer file... sigh
t->vocab_size = vocab_size;
// malloc space to hold the scores and the strings
t->vocab = (char**)malloc(vocab_size * sizeof(char*));
t->vocab_scores = (float*)malloc(vocab_size * sizeof(float));
t->byte_piece[1] = '\0'; // null terminate the byte_piece string
// read in the file
FILE *file = fopen(tokenizer, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer); exit(EXIT_FAILURE); }
FILE *file = fopen(tokenizer_path, "rb");
if (!file) { fprintf(stderr, "couldn't load %s\n", tokenizer_path); exit(EXIT_FAILURE); }
if (fread(&t->max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
int len;
for (int i = 0; i < vocab_size; i++) {
@@ -374,9 +400,7 @@ void build_tokenizer(char* tokenizer, Tokenizer* t, int vocab_size) {
}

void free_tokenizer(Tokenizer* t) {
for (int i = 0; i < t->vocab_size; i++) {
free(t->vocab[i]);
}
for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
free(t->vocab);
free(t->vocab_scores);
}
@@ -667,28 +691,19 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
else { error_usage(); }
}
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}

// read in the model.bin file
Config config;
TransformerWeights weights;
int fd = 0; // file descriptor for memory mapping
float* data = NULL; // memory mapped data pointer
ssize_t file_size; // size of the checkpoint file in bytes
read_checkpoint(checkpoint_path, &config, &weights, &fd, &data, &file_size);
if(rng_seed == 0) { rng_seed = (unsigned int)time(NULL);}

// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
// build the Transformer via the model .bin file
Transformer transformer;
build_transformer(checkpoint_path, &transformer);
int vocab_size = transformer.config.vocab_size; // convenience copy

// read in the tokenizer .bin file
// build the Tokenizer via the tokenizer .bin file
Tokenizer tokenizer;
build_tokenizer(tokenizer_path, &tokenizer, config.vocab_size);
build_tokenizer(tokenizer_path, &tokenizer, vocab_size);

// create and init the application RunState
RunState state;
malloc_run_state(&state, &config);
ProbIndex *probindex = malloc(config.vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling

ProbIndex *probindex = malloc(vocab_size * sizeof(ProbIndex)); // buffer used in top-p sampling
// process the prompt, if any
int *prompt_tokens = NULL;
int num_prompt_tokens = 0;
@@ -705,7 +720,7 @@ int main(int argc, char *argv[]) {
while (pos < steps) {

// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);
float* logits = forward(&transformer, token, pos);

// advance the state machine
if(pos < num_prompt_tokens) {
@@ -715,19 +730,19 @@
// sample the next token
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = argmax(state.logits, config.vocab_size);
next = argmax(logits, vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
for (int q=0; q<vocab_size; q++) { logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size);
softmax(logits, vocab_size);
// we sample from this distribution to get the next token
if (topp <= 0 || topp >= 1) {
// simply sample from the predicted probability distribution
next = sample(state.logits, config.vocab_size);
next = sample(logits, vocab_size);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(state.logits, config.vocab_size, topp, probindex);
next = sample_topp(logits, vocab_size, topp, probindex);
}
}
}
@@ -754,11 +769,9 @@ int main(int argc, char *argv[]) {
}

// memory and file handles cleanup
free_run_state(&state);
free(probindex);
if (prompt_tokens != NULL) { free(prompt_tokens); }
free_tokenizer(&tokenizer);
if (prompt_tokens != NULL) free(prompt_tokens);
if (data != MAP_FAILED) munmap(data, file_size);
if (fd != -1) close(fd);
free_transformer(&transformer);
return 0;
}
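
Taken together, the diff replaces the loose Config/TransformerWeights/RunState variables (plus the fd/data/file_size mmap bookkeeping) in main() with a single Transformer object. As a rough orientation, here is a minimal sketch of the new lifecycle; it assumes the declarations from run.c above are in scope, uses a hypothetical "model.bin" path, and substitutes a plain greedy argmax loop for the full prompt-processing and sampling logic in main():

// hypothetical driver, for illustration only; not part of the commit
void example_generate(void) {
    // build: mmap the checkpoint (Config + weights) and allocate the RunState buffers
    Transformer transformer;
    build_transformer("model.bin", &transformer); // "model.bin" is a placeholder path

    // generate: forward() now returns the logits pointer directly
    int token = 1; // BOS token, as in run.c
    for (int pos = 0; pos < transformer.config.seq_len; pos++) {
        float* logits = forward(&transformer, token, pos);
        token = argmax(logits, transformer.config.vocab_size); // greedy decode via run.c's argmax
    }

    // free: munmap the weights, close the fd, and free the RunState, all in one call
    free_transformer(&transformer);
}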
