21 changes: 15 additions & 6 deletions README.md
@@ -83,6 +83,18 @@ This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my

base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo!

You can also chat with the Llama Chat models. Export the chat model exactly as above:

```bash
python export.py llama2_7b_chat.bin --meta-llama /path/to/7B-chat
```

Then chat with it by specifying the chat mode using the `-m` flag, e.g.:

```bash
./run llama2_7b_chat.bin -m chat
```
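Chat mode also accepts an optional system prompt via the `-y` flag added in this PR; the prompt string below is only an illustrative example:

```bash
./run llama2_7b_chat.bin -m chat -y "You are a helpful assistant"
```

If `-y` is omitted, run.c will ask for a system prompt on stdin (it can be left empty).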

## huggingface models

We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
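For example, a sketch of exporting a HuggingFace-hosted Llama 2 model (the repo id below is an assumption; substitute the model you actually want):

```bash
python export.py llama2_7b_hf.bin --hf meta-llama/Llama-2-7b-hf
```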
@@ -207,8 +219,7 @@ You can also experiment with replacing `gcc` with `clang`.

If compiling with gcc, try experimenting with `-funroll-all-loops`, see PR [#183](https://github.com/karpathy/llama2.c/pull/183)

**OpenMP**. Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors.
You'll need to install the OpenMP library and the clang compiler first (e.g. `apt install clang libomp-dev` on ubuntu). Then you can compile with `make runomp`, which does:

```bash
clang -Ofast -fopenmp -march=native run.c -lm -o run
```

@@ -324,13 +335,11 @@ If your candidate PRs have elements of these it doesn't mean they won't get merged

## unsorted todos

- ability to calculate perplexity in run.c, exactly as done in llama.cpp
- add support in run.c of reading version 1+ files from export, later deprecate "version 0"
- runq.c (int8 quantization) add
- run.cu (CUDA) investigate and merge
- add more tests inside [test.c](test.c)
- add Engine class for use in sample.py that does efficient inference in PyTorch, e.g. KV cache keeping
- make it easier to add a new dataset with not too much pain
- (LoRA) finetuning and export of Llama 2 models

127 changes: 122 additions & 5 deletions run.c
@@ -732,6 +732,8 @@ long time_in_ms() {
// generation loop

void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
    char *empty_prompt = "";
    if (prompt == NULL) { prompt = empty_prompt; }

    // encode the (string) prompt into tokens sequence
    int num_prompt_tokens = 0;
@@ -785,6 +787,108 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
    free(prompt_tokens);
}

void read_stdin(const char* guide, char* buffer, size_t bufsize) {
    // read a line from stdin, up to but not including \n
    printf("%s", guide);
    if (fgets(buffer, bufsize, stdin) != NULL) {
        size_t len = strlen(buffer);
        if (len > 0 && buffer[len - 1] == '\n') {
            buffer[len - 1] = '\0'; // strip newline
        }
    }
}

// ----------------------------------------------------------------------------
// chat loop
// I manually inspected the tokens for a few chat conversations compared to
// python reference and that seemed ok, but this was not thoroughly tested and
// is not safely implemented, it's more a proof of concept atm.

void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
          char *cli_user_prompt, char *cli_system_prompt, int steps) {

    // buffers for reading the system prompt and user prompt from stdin
    // you'll notice they are somewhat haphazardly and unsafely set atm
    char system_prompt[512];
    char user_prompt[512];
    char rendered_prompt[1152];
    int num_prompt_tokens = 0;
    int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
    int user_idx;

    // start the main loop
    int8_t user_turn = 1; // user starts
    int next;       // will store the next token in the sequence
    int token;      // stores the current token to feed into the transformer
    int prev_token;
    int pos = 0;    // position in the sequence
    while (pos < steps) {

        // when it is the user's turn to contribute tokens to the dialog...
        if (user_turn) {
            // get the (optional) system prompt at position 0
            if (pos == 0) {
                // at position 0, the user can also contribute a system prompt
                if (cli_system_prompt == NULL) {
                    // system prompt was not passed in, attempt to get it from stdin
                    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
                } else {
                    // system prompt was passed in, use it
                    strcpy(system_prompt, cli_system_prompt);
                }
            }
            // get the user prompt
            if (pos == 0 && cli_user_prompt != NULL) {
                // user prompt for position 0 was passed in, use it
                strcpy(user_prompt, cli_user_prompt);
            } else {
                // otherwise get user prompt from stdin
                read_stdin("User: ", user_prompt, sizeof(user_prompt));
            }
            // render user/system prompts into the Llama 2 Chat schema
            if (pos == 0 && system_prompt[0] != '\0') {
                char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
                sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
            } else {
                char user_template[] = "[INST] %s [/INST]";
                sprintf(rendered_prompt, user_template, user_prompt);
            }
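            // for illustration (hypothetical prompts): with system prompt "You are a helpful assistant"
            // and user prompt "What is 2+2?", rendered_prompt would hold:
            //   [INST] <<SYS>>
            //   You are a helpful assistant
            //   <</SYS>>
            //
            //   What is 2+2? [/INST]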
            // encode the rendered prompt into tokens
            encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
            user_idx = 0; // reset the user index
            user_turn = 0;
            printf("Assistant: ");
        }

        // determine the token to pass into the transformer next
        if (user_idx < num_prompt_tokens) {
            // if we are still processing the input prompt, force the next prompt token
            token = prompt_tokens[user_idx++];
        } else {
            // otherwise use the next token sampled from previous turn
            token = next;
        }
        // EOS (=2) token ends the Assistant turn
        if (token == 2) { user_turn = 1; }

        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);
@nikolaydubina (Contributor) commented on Aug 26, 2023:

Just for my knowledge: aren't there any problems with not forwarding the model through the tokens of the user's prompt (e.g. things like the KV cache and position/relative offsets)?


For example, for "generate", I see we forward the model through the prompt; the sampling of logits happens after "forward":

        if (pos < num_prompt_tokens - 1) {
            // if we are still processing the input prompt, force the next prompt token
            next = prompt_tokens[pos + 1];
        } else {
            // otherwise sample the next token from the logits
            next = sample(sampler, logits);
        }

But for chat mode here we encode a whole new user prompt ("pos" is not incremented), then "forward" and immediately "sample" for that token:

        // ... add whole new user prompt

        // forward the transformer to get logits for the next token
        float* logits = forward(transformer, token, pos);
        next = sample(sampler, logits);
        pos++;

Shouldn't we "forward" through the new prompt first and only then sample the logits?

        next = sample(sampler, logits);
        pos++;

        if (user_idx >= num_prompt_tokens && next != 2) {
            // the Assistant is responding, so print its output
            char* piece = decode(tokenizer, token, next);
            safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
            fflush(stdout);
        }
        if (next == 2) { printf("\n"); }
    }
    printf("\n");
    free(prompt_tokens);
}


// ----------------------------------------------------------------------------
// CLI, include only if not testing
#ifndef TESTING
@@ -799,6 +903,8 @@ void error_usage() {
    fprintf(stderr, " -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
    fprintf(stderr, " -i <string> input prompt\n");
    fprintf(stderr, " -z <string> optional path to custom tokenizer\n");
    fprintf(stderr, " -m <string> mode: generate|chat, default: generate\n");
    fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
    exit(EXIT_FAILURE);
}

@@ -807,11 +913,13 @@ int main(int argc, char *argv[]) {
    // default parameters
    char *checkpoint_path = NULL;  // e.g. out/model.bin
    char *tokenizer_path = "tokenizer.bin";
    float temperature = 1.0f;   // 0.0 = greedy deterministic. 1.0 = original. don't set higher
    float topp = 0.9f;          // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
    int steps = 256;            // number of steps to run for
    char *prompt = NULL;        // prompt string
    unsigned long long rng_seed = 0; // seed rng with time by default
    char *mode = "generate";    // generate|chat
    char *system_prompt = NULL; // the (optional) system prompt to use in chat mode

    // poor man's C argparse so we can override the defaults above from the command line
    if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@@ -827,6 +935,8 @@ int main(int argc, char *argv[]) {
        else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
        else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
        else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
        else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
        else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
        else { error_usage(); }
    }

@@ -850,7 +960,14 @@ int main(int argc, char *argv[]) {
    build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);

    // run!
    if (strcmp(mode, "generate") == 0) {
        generate(&transformer, &tokenizer, &sampler, prompt, steps);
    } else if (strcmp(mode, "chat") == 0) {
        chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
    } else {
        fprintf(stderr, "unknown mode: %s\n", mode);
        error_usage();
    }

    // memory and file handles cleanup
    free_sampler(&sampler);