diff --git a/run.c b/run.c index b8bd0b6e..4e534d03 100644 --- a/run.c +++ b/run.c @@ -507,6 +507,7 @@ void error_usage() { printf(" -p p value in top-p (nucleus) sampling. default 0.9, 0 = off\n"); printf(" -s random seed, default time(NULL)\n"); printf(" -n number of steps to run for, default 256. 0 = max_seq_len\n"); + printf(" -b number of tokens to buffer, default 1. 0 = max_seq_len\n"); printf(" -i input prompt\n"); exit(EXIT_FAILURE); } @@ -520,6 +521,7 @@ int main(int argc, char *argv[]) { rng_seed = (unsigned int)time(NULL); // seed rng with time by default int steps = 256; // number of steps to run for char *prompt = NULL; // prompt string + int buffertokens = 1; // number of tokens to buffer before flushing to screen // poor man's C argparse so we can override the defaults above from the command line if (argc >= 2) { checkpoint = argv[1]; } else { error_usage(); } @@ -533,6 +535,7 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'p') { topp = atof(argv[i + 1]); } else if (argv[i][1] == 's') { rng_seed = atoi(argv[i + 1]); } else if (argv[i][1] == 'n') { steps = atoi(argv[i + 1]); } + else if (argv[i][1] == 'b') { buffertokens = atoi(argv[i + 1]); } else if (argv[i][1] == 'i') { prompt = argv[i + 1]; } else { error_usage(); } } @@ -603,7 +606,15 @@ int main(int argc, char *argv[]) { int next; // will store the next token in the sequence int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer int pos = 0; // position in the sequence + int bufferflush = 1; // token counter for flushing buffer + static char outbuff[4096 * (6 + 2)] ; // buffersize is context length * average size of subwords + margin printf("\n"); // explicit print the initial BOS token for stylistic symmetry reasons + + // setvbuf is used to buffer output into outbuff instead of flushing to screen directly + if (setvbuf(stdout, outbuff, _IOFBF, sizeof(outbuff)) != 0) { + puts("Error: Buffer allocation!"); exit(EXIT_FAILURE); + } + while (pos < steps) { // forward the transformer to get logits for the next token @@ -635,8 +646,10 @@ int main(int argc, char *argv[]) { // following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89) char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next]; + printf("%s", token_str); - fflush(stdout); + // flush output to screen after the defined number of buffertokens have accumulated + if (bufferflush==pos) { fflush(stdout); bufferflush+=buffertokens; } // advance forward token = next;