Skip to content

Commit

Permalink
Merge branch 'karpathy:master' into patch-openBLAS
Browse files Browse the repository at this point in the history
  • Loading branch information
rdentato authored Aug 6, 2023
2 parents 22be84b + a7a3aa0 commit cddb5d1
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 47 deletions.
49 changes: 30 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,25 +208,36 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg

## notable forks

- [llama2.rs](https://github.com/gaxler/llama2.rs) by @gaxler: a Rust port of this project
- [go-llama2](https://github.com/tmc/go-llama2) by @tmc: a Go port of this project
- [llama2.go](https://github.com/nikolaydubina/llama2.go) by @nikolaydubina: a Go port of this project
- [llama2.go](https://github.com/haormj/llama2.go) by @haormj: a Go port of this project
- [llama2.go](https://github.com/saracen/llama2.go) by @saracen: a Go port of this project
- [llama2.c-android](https://github.com/Manuel030/llama2.c-android): by @Manuel030: adds Android binaries of this project
- [llama2.c-android-wrapper](https://github.com/celikin/llama2.c-android-wrapper): by @celikin: added JNI wrapper, PoC
- [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @leloykun: a C++ port of this project
- [llama2.js](https://github.com/epicure/llama2.js) by @epicure: a JavaScript port of this project
- [llama2.zig](https://github.com/cgbur/llama2.zig) by @cgbur: A Zig port of this project
- [llama2.zig](https://github.com/vodkaslime/llama2.zig) by @vodkaslime: a Zig port of this project
- [llama2.jl](https://github.com/juvi21/llama2.jl) by @juvi21: a Julia port of this project
- [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @trholding: Standalone, Bootable & Portable Binary Llama 2
- [llama2.rs](https://github.com/leo-du/llama2.rs) by @leo-du: A Rust port of this project
- [llama2.scala](https://github.com/jrudolph/llama2.scala) by @jrudolph: a Scala port of this project
- [llama2.c-emscripten](https://github.com/gohai/llama2.c-emscripten) by @gohai: Emscripten (JavaScript) port, based on @ggerganov's initial prototype
- [llama2.java](https://github.com/mukel/llama2.java) by @mukel: a Java port of this project
- [llama2.kt](https://github.com/madroidmaq/llama2.kt) by @madroidmaq: a Kotlin port of this project
- [llama2.zig](https://github.com/clebert/llama2.zig) by @clebert: a Zig port of this project
- Rust
- [llama2.rs](https://github.com/gaxler/llama2.rs) by @[gaxler](https://github.com/gaxler): a Rust port of this project
- [llama2.rs](https://github.com/leo-du/llama2.rs) by @[leo-du](https://github.com/leo-du): A Rust port of this project
- Go
- [go-llama2](https://github.com/tmc/go-llama2) by @[tmc](https://github.com/tmc): a Go port of this project
- [llama2.go](https://github.com/nikolaydubina/llama2.go) by @[nikolaydubina](https://github.com/nikolaydubina): a Go port of this project
- [llama2.go](https://github.com/haormj/llama2.go) by @[haormj](https://github.com/haormj): a Go port of this project
- [llama2.go](https://github.com/saracen/llama2.go) by @[saracen](https://github.com/saracen): a Go port of this project
- Android
- [llama2.c-android](https://github.com/Manuel030/llama2.c-android): by @[Manuel030](https://github.com/Manuel030): adds Android binaries of this project
- [llama2.c-android-wrapper](https://github.com/celikin/llama2.c-android-wrapper): by @[celikin](https://github.com/celikin): added JNI wrapper, PoC
- C++
- [llama2.cpp](https://github.com/leloykun/llama2.cpp) by @[leloykun](https://github.com/leloykun): a C++ port of this project
- JavaScript
- [llama2.js](https://github.com/epicure/llama2.js) by @[epicure](https://github.com/epicure): a JavaScript port of this project
- [llama2.c-emscripten](https://github.com/gohai/llama2.c-emscripten) by @[gohai](https://github.com/gohai): Emscripten (JavaScript) port, based on @ggerganov's initial prototype
- Zig
- [llama2.zig](https://github.com/cgbur/llama2.zig) by @[cgbur](https://github.com/cgbur): A Zig port of this project
- [llama2.zig](https://github.com/vodkaslime/llama2.zig) by @[vodkaslime](https://github.com/vodkaslime): a Zig port of this project
- [llama2.zig](https://github.com/clebert/llama2.zig) by @[clebert](https://github.com/clebert): a Zig port of this project
- Julia
- [llama2.jl](https://github.com/juvi21/llama2.jl) by @[juvi21](https://github.com/juvi21): a Julia port of this project
- Scala
- [llama2.scala](https://github.com/jrudolph/llama2.scala) by @[jrudolph](https://github.com/jrudolph): a Scala port of this project
- Java
- [llama2.java](https://github.com/mukel/llama2.java) by @[mukel](https://github.com/mukel): a Java port of this project
- Kotlin
- [llama2.kt](https://github.com/madroidmaq/llama2.kt) by @[madroidmaq](https://github.com/madroidmaq): a Kotlin port of this project
- [llama2.c - Llama 2 Everywhere](https://github.com/trholding/llama2.c) by @[trholding](https://github.com/trholding): Standalone, Bootable & Portable Binary Llama 2


## unsorted todos

Expand Down
61 changes: 33 additions & 28 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void malloc_run_state(RunState* s, Config* p) {
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache || !s->probindex) {
printf("malloc failed!\n");
fprintf(stderr, "malloc failed!\n");
exit(EXIT_FAILURE);
}
}
Expand Down Expand Up @@ -379,7 +379,7 @@ void bpe_encode(char *text, char **vocab, float *vocab_scores, int vocab_size, u
for (char *c = text; *c != '\0'; c++) {
sprintf(str_buffer, "%c", *c);
int id = str_lookup(str_buffer, vocab, vocab_size);
if (id == -1) { printf("not good\n"); exit(EXIT_FAILURE); }
if (id == -1) { fprintf(stderr, "not good\n"); exit(EXIT_FAILURE); }
tokens[*n_tokens] = id;
(*n_tokens)++;
}
Expand Down Expand Up @@ -517,14 +517,14 @@ int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex) {
// int main

void error_usage() {
printf("Usage: run <checkpoint> [options]\n");
printf("Example: run model.bin -n 256 -i \"Once upon a time\"\n");
printf("Options:\n");
printf(" -t <float> temperature, default 1.0\n");
printf(" -p <float> p value in top-p (nucleus) sampling. default 0.9, 0 = off\n");
printf(" -s <int> random seed, default time(NULL)\n");
printf(" -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
printf(" -i <string> input prompt\n");
fprintf(stderr, "Usage: run <checkpoint> [options]\n");
fprintf(stderr, "Example: run model.bin -n 256 -i \"Once upon a time\"\n");
fprintf(stderr, "Options:\n");
fprintf(stderr, " -t <float> temperature, default 1.0\n");
fprintf(stderr, " -p <float> p value in top-p (nucleus) sampling. default 0.9, 0 = off\n");
fprintf(stderr, " -s <int> random seed, default time(NULL)\n");
fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
fprintf(stderr, " -i <string> input prompt\n");
exit(EXIT_FAILURE);
}

Expand Down Expand Up @@ -553,7 +553,7 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
else { error_usage(); }
}
if(rng_seed == 0) { printf("Cannot use seed=0 because of the rng alg used\n"); return 1; }
if(rng_seed == 0) { fprintf(stderr, "Cannot use seed=0 because of the rng alg used\n"); return 1; }

// read in the model.bin file
Config config;
Expand All @@ -563,7 +563,7 @@ int main(int argc, char *argv[]) {
ssize_t file_size; // size of the checkpoint file in bytes
{
FILE *file = fopen(checkpoint, "rb");
if (!file) { printf("Couldn't open file %s\n", checkpoint); return 1; }
if (!file) { fprintf(stderr, "Couldn't open file %s\n", checkpoint); return 1; }
// read in the config header
if (fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
Expand All @@ -575,9 +575,9 @@ int main(int argc, char *argv[]) {
fclose(file);
// memory map the Transformer weights into the data pointer
fd = open(checkpoint, O_RDONLY); // open in read only mode
if (fd == -1) { printf("open failed!\n"); return 1; }
if (fd == -1) { fprintf(stderr, "open failed!\n"); return 1; }
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
if (data == MAP_FAILED) { fprintf(stderr, "mmap failed!\n"); return 1; }
float* weights_ptr = data + sizeof(Config)/sizeof(float);
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
}
Expand All @@ -590,14 +590,14 @@ int main(int argc, char *argv[]) {
unsigned int max_token_length;
{
FILE *file = fopen("tokenizer.bin", "rb");
if (!file) { printf("couldn't load tokenizer.bin\n"); return 1; }
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
if (!file) { fprintf(stderr, "couldn't load tokenizer.bin\n"); return 1; }
if (fread(&max_token_length, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
int len;
for (int i = 0; i < config.vocab_size; i++) {
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { printf("failed read\n"); return 1;}
if (fread(&len, sizeof(int), 1, file) != 1) { printf("failed read\n"); return 1; }
if (fread(vocab_scores + i, sizeof(float), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1;}
if (fread(&len, sizeof(int), 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
vocab[i] = (char *)malloc(len + 1);
if (fread(vocab[i], len, 1, file) != 1) { printf("failed read\n"); return 1; }
if (fread(vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); return 1; }
vocab[i][len] = '\0'; // add the string terminating token
}
fclose(file);
Expand All @@ -620,12 +620,12 @@ int main(int argc, char *argv[]) {
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
while (pos < steps) {

// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);

// advance the state state machine
if(pos < num_prompt_tokens) {
// if we are still processing the input prompt, force the next prompt token
next = prompt_tokens[pos];
Expand All @@ -649,22 +649,27 @@ int main(int argc, char *argv[]) {
}
}
}
pos++;

// data-dependent terminating condition: the BOS (1) token delimits sequences
if (next == 1) { break; }

// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
printf("%s", token_str);
fflush(stdout);

// advance forward
token = next;
pos++;
// init our timer here because the first iteration is slow due to memmap

// init the timer here because the first iteration can be slower
if (start == 0) { start = time_in_ms(); }
}
printf("\n");

// report achieved tok/s
long end = time_in_ms();
printf("\nachieved tok/s: %f\n", (steps-1) / (double)(end-start)*1000);
// report achieved tok/s (pos-1 because the timer starts after first iteration)
if (pos > 1) {
long end = time_in_ms();
fprintf(stderr, "achieved tok/s: %f\n", (pos-1) / (double)(end-start)*1000);
}

// memory and file handles cleanup
free_run_state(&state);
Expand Down

0 comments on commit cddb5d1

Please sign in to comment.