
Commit 2148496

Merge pull request karpathy#343 from karpathy/feature/chat
Add interactive loop to enable nice chat with a Llama 2 Chat model
2 parents 0a6525b + 90b41e6 commit 2148496

File tree

README.md
run.c

2 files changed: +137 -11 lines

README.md

Lines changed: 15 additions & 6 deletions
@@ -83,6 +83,18 @@ This ran at about 4 tokens/s compiled with [OpenMP](#OpenMP) on 96 threads on my
 
 base models... ¯\\_(ツ)_/¯. Since we can inference the base model, it should be possible to also inference the chat model quite easily, and have a conversation with it. And if we can find a way to run 7B more efficiently, we can start adding LoRA to our training script, and going wild with finetunes all within the repo!
 
+You can also chat with the Llama Chat models. Export the chat model exactly as above:
+
+```bash
+python export.py llama2_7b_chat.bin --meta-llama /path/to/7B-chat
+```
+
+Then chat with it by specifying the chat mode using the `-m` flag, e.g.:
+
+```bash
+./run llama2_7b_chat.bin -m chat
+```
+
 ## hugginface models
 
 We can load any huggingface models that use the Llama 2 architecture. See the script [export.py](export.py) and the `--hf` flag to export the model .bin file.
@@ -207,8 +219,7 @@ You can also experiment with replacing `gcc` with `clang`.
 
 If compiling with gcc, try experimenting with `-funroll-all-loops`, see PR [#183](https://github.com/karpathy/llama2.c/pull/183)
 
-### OpenMP
-Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors.
+**OpenMP**. Big improvements can also be achieved by compiling with OpenMP, which "activates" the `#pragma omp parallel for` inside the matmul and attention, allowing the work in the loops to be split up over multiple processors.
 You'll need to install the OpenMP library and the clang compiler first (e.g. `apt install clang libomp-dev` on ubuntu). Then you can compile with `make runomp`, which does:
 
 ```bash
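For context on the OpenMP note above: the pragma sits on the outer loop of the matrix-vector multiply, so each output row can be computed by a different thread; without OpenMP the pragma is simply ignored and the loop runs serially. A minimal sketch of that pattern (not necessarily the verbatim run.c function):

```c
// sketch of the loop the OpenMP note refers to; details in run.c may differ
void matmul(float* xout, float* x, float* w, int n, int d) {
    // W (d,n) @ x (n,) -> xout (d,)
    int i;
    #pragma omp parallel for private(i)
    for (i = 0; i < d; i++) {
        float val = 0.0f;
        for (int j = 0; j < n; j++) {
            val += w[i * n + j] * x[j];
        }
        xout[i] = val;
    }
}
```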
@@ -324,13 +335,11 @@ If your candidate PRs have elements of these it doesn't mean they won't get merg
 
 ## unsorted todos
 
-- support Llama 2 7B Chat models with a Chat UI/UX in run.c, very similar to llama.cpp
-- ability to calculate perplexity in run.c, exactly as done in llama.cpp
 - add support in run.c of reading version 1+ files from export, later deprecate "version 0"
-- add more tests in [test.c](test.c)
 - runq.c (int8 quantization) add
 - run.cu (CUDA) investigate and merge
-- add an Engine class that serves the model ~efficiently but in PyTorch (see [Issue 346](https://github.com/karpathy/llama2.c/issues/346))
+- add more tests inside [test.c](test.c)
+- add Engine class for use in sample.py that does efficient inference in PyTorch, e.g. KV cache keeping
 - make it easier to add a new dataset with not too much pain
 - (LoRA) finetuning and export of Llama 2 models
 
run.c

Lines changed: 122 additions & 5 deletions
@@ -732,6 +732,8 @@ long time_in_ms() {
 // generation loop
 
 void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, char *prompt, int steps) {
+    char *empty_prompt = "";
+    if (prompt == NULL) { prompt = empty_prompt; }
 
     // encode the (string) prompt into tokens sequence
     int num_prompt_tokens = 0;
@@ -785,6 +787,108 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
     free(prompt_tokens);
 }
 
+void read_stdin(const char* guide, char* buffer, size_t bufsize) {
+    // read a line from stdin, up to but not including \n
+    printf("%s", guide);
+    if (fgets(buffer, bufsize, stdin) != NULL) {
+        size_t len = strlen(buffer);
+        if (len > 0 && buffer[len - 1] == '\n') {
+            buffer[len - 1] = '\0'; // strip newline
+        }
+    }
+}
+
+// ----------------------------------------------------------------------------
+// chat loop
+// I manually inspected the tokens for a few chat conversations compared to
+// python reference and that seemed ok, but this was not thoroughly tested and
+// is not safely implemented, it's more a proof of concept atm.
+
+void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
+          char *cli_user_prompt, char *cli_system_prompt, int steps) {
+
+    // buffers for reading the system prompt and user prompt from stdin
+    // you'll notice they are soomewhat haphazardly and unsafely set atm
+    char system_prompt[512];
+    char user_prompt[512];
+    char rendered_prompt[1152];
+    int num_prompt_tokens = 0;
+    int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
+    int user_idx;
+
+    // start the main loop
+    int8_t user_turn = 1; // user starts
+    int next;  // will store the next token in the sequence
+    int token; // stores the current token to feed into the transformer
+    int prev_token;
+    int pos = 0; // position in the sequence
+    while (pos < steps) {
+
+        // when it is the user's turn to contribute tokens to the dialog...
+        if (user_turn) {
+            // get the (optional) system prompt at position 0
+            if (pos == 0) {
+                // at position 0, the user can also contribute a system prompt
+                if (cli_system_prompt == NULL) {
+                    // system prompt was not passed in, attempt to get it from stdin
+                    read_stdin("Enter system prompt (optional): ", system_prompt, sizeof(system_prompt));
+                } else {
+                    // system prompt was passed in, use it
+                    strcpy(system_prompt, cli_system_prompt);
+                }
+            }
+            // get the user prompt
+            if (pos == 0 && cli_user_prompt != NULL) {
+                // user prompt for position 0 was passed in, use it
+                strcpy(user_prompt, cli_user_prompt);
+            } else {
+                // otherwise get user prompt from stdin
+                read_stdin("User: ", user_prompt, sizeof(user_prompt));
+            }
+            // render user/system prompts into the Llama 2 Chat schema
+            if (pos == 0 && system_prompt[0] != '\0') {
+                char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
+                sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
+            } else {
+                char user_template[] = "[INST] %s [/INST]";
+                sprintf(rendered_prompt, user_template, user_prompt);
+            }
+            // encode the rendered prompt into tokens
+            encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
+            user_idx = 0; // reset the user index
+            user_turn = 0;
+            printf("Assistant: ");
+        }
+
+        // determine the token to pass into the transformer next
+        if (user_idx < num_prompt_tokens) {
+            // if we are still processing the input prompt, force the next prompt token
+            token = prompt_tokens[user_idx++];
+        } else {
+            // otherwise use the next token sampled from previous turn
+            token = next;
+        }
+        // EOS (=2) token ends the Assistant turn
+        if (token == 2) { user_turn = 1; }
+
+        // forward the transformer to get logits for the next token
+        float* logits = forward(transformer, token, pos);
+        next = sample(sampler, logits);
+        pos++;
+
+        if (user_idx >= num_prompt_tokens && next != 2) {
+            // the Assistant is responding, so print its output
+            char* piece = decode(tokenizer, token, next);
+            safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
+            fflush(stdout);
+        }
+        if (next == 2) { printf("\n"); }
+    }
+    printf("\n");
+    free(prompt_tokens);
+}
+
+
 // ----------------------------------------------------------------------------
 // CLI, include only if not testing
 #ifndef TESTING
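A note on the rendered schema: for a first turn with system prompt "You are a helpful assistant" and user prompt "Hello", the templates above produce `[INST] <<SYS>>\nYou are a helpful assistant\n<</SYS>>\n\nHello [/INST]`; the Assistant turn ends when the EOS token (id 2) comes up, at which point user_turn flips back and the loop reads the next user line from stdin. As the comment in chat() admits, the sprintf/strcpy calls into the fixed-size buffers are not bounds-checked; one possible bounded variant of the render step (a hypothetical helper, not part of this commit) is sketched below:

```c
// hypothetical bounds-checked render step; not part of this commit
#include <stdio.h>

// write the Llama 2 Chat prompt for one turn into out.
// returns 0 on success, -1 if the rendered prompt would not fit in out_size bytes.
int render_prompt(char *out, size_t out_size,
                  const char *system_prompt, const char *user_prompt) {
    int n;
    if (system_prompt != NULL && system_prompt[0] != '\0') {
        n = snprintf(out, out_size, "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]",
                     system_prompt, user_prompt);
    } else {
        n = snprintf(out, out_size, "[INST] %s [/INST]", user_prompt);
    }
    return (n < 0 || (size_t)n >= out_size) ? -1 : 0;
}
```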
@@ -799,6 +903,8 @@ void error_usage() {
     fprintf(stderr, "  -n <int>    number of steps to run for, default 256. 0 = max_seq_len\n");
     fprintf(stderr, "  -i <string> input prompt\n");
     fprintf(stderr, "  -z <string> optional path to custom tokenizer\n");
+    fprintf(stderr, "  -m <string> mode: generate|chat, default: generate\n");
+    fprintf(stderr, "  -y <string> (optional) system prompt in chat mode\n");
     exit(EXIT_FAILURE);
 }
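With the new flags, chat mode with an explicit system prompt can be invoked along the lines of `./run llama2_7b_chat.bin -m chat -y "You are a helpful assistant"` (checkpoint name as exported per the README above). In chat mode `-i` supplies the first user turn, and any prompt not given on the command line is read from stdin by the chat() loop.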

@@ -807,11 +913,13 @@ int main(int argc, char *argv[]) {
     // default parameters
     char *checkpoint_path = NULL; // e.g. out/model.bin
     char *tokenizer_path = "tokenizer.bin";
-    float temperature = 1.0f; // 0.0 = greedy deterministic. 1.0 = original. don't set higher
-    float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
-    int steps = 256; // number of steps to run for
-    char *prompt = ""; // prompt string
+    float temperature = 1.0f;   // 0.0 = greedy deterministic. 1.0 = original. don't set higher
+    float topp = 0.9f;          // top-p in nucleus sampling. 1.0 = off. 0.9 works well, but slower
+    int steps = 256;            // number of steps to run for
+    char *prompt = NULL;        // prompt string
     unsigned long long rng_seed = 0; // seed rng with time by default
+    char *mode = "generate";    // generate|chat
+    char *system_prompt = NULL; // the (optional) system prompt to use in chat mode
 
     // poor man's C argparse so we can override the defaults above from the command line
     if (argc >= 2) { checkpoint_path = argv[1]; } else { error_usage(); }
@@ -827,6 +935,8 @@ int main(int argc, char *argv[]) {
         else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
         else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
         else if (argv[i][1] == 'z') { tokenizer_path = argv[i + 1]; }
+        else if (argv[i][1] == 'm') { mode = argv[i + 1]; }
+        else if (argv[i][1] == 'y') { system_prompt = argv[i + 1]; }
         else { error_usage(); }
     }

@@ -850,7 +960,14 @@ int main(int argc, char *argv[]) {
     build_sampler(&sampler, transformer.config.vocab_size, temperature, topp, rng_seed);
 
     // run!
-    generate(&transformer, &tokenizer, &sampler, prompt, steps);
+    if (strcmp(mode, "generate") == 0) {
+        generate(&transformer, &tokenizer, &sampler, prompt, steps);
+    } else if (strcmp(mode, "chat") == 0) {
+        chat(&transformer, &tokenizer, &sampler, prompt, system_prompt, steps);
+    } else {
+        fprintf(stderr, "unknown mode: %s\n", mode);
+        error_usage();
+    }
 
     // memory and file handles cleanup
     free_sampler(&sampler);
