Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

run.c - Speed up with output buffering #193

Closed
wants to merge 11 commits into from
15 changes: 14 additions & 1 deletion run.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -507,6 +507,7 @@ void error_usage() {
printf(" -p <float> p value in top-p (nucleus) sampling. default 0.9, 0 = off\n");
printf(" -s <int> random seed, default time(NULL)\n");
printf(" -n <int> number of steps to run for, default 256. 0 = max_seq_len\n");
printf(" -b <int> number of tokens to buffer, default 1. 0 = max_seq_len\n");
printf(" -i <string> input prompt\n");
exit(EXIT_FAILURE);
}
Expand All @@ -520,6 +521,7 @@ int main(int argc, char *argv[]) {
rng_seed = (unsigned int)time(NULL); // seed rng with time by default
int steps = 256; // number of steps to run for
char *prompt = NULL; // prompt string
int buffertokens = 1; // number of tokens to buffer before flushing to screen

// poor man's C argparse so we can override the defaults above from the command line
if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); }
Expand All @@ -533,6 +535,7 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); }
else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); }
else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); }
else if (argv[i][1] == 'b') { buffertokens = atoi(argv[i + 1]); }
else if (argv[i][1] == 'i') { prompt = argv[i + 1]; }
else { error_usage(); }
}
Expand Down Expand Up @@ -603,7 +606,15 @@ int main(int argc, char *argv[]) {
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
int bufferflush = 1; // token counter for flushing buffer
static char outbuff[4096 * (6 + 2)] ; // buffersize is context length * average size of subwords + margin
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons

// setvbuf is used to buffer output into outbuff instead of flushing to screen directly
if (setvbuf(stdout, outbuff, _IOFBF, sizeof(outbuff)) != 0) {
puts("Error: Buffer allocation!"); exit(EXIT_FAILURE);
}

while (pos < steps) {

// forward the transformer to get logits for the next token
Expand Down Expand Up @@ -635,8 +646,10 @@ int main(int argc, char *argv[]) {

// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];

printf("%s", token_str);
fflush(stdout);
// flush output to screen after the defined number of buffertokens have accumulated
if (bufferflush==pos) { fflush(stdout); bufferflush+=buffertokens; }

// advance forward
token = next;
Expand Down