Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

run.c - Speed up with output buffering #193

Closed
wants to merge 11 commits into from
12 changes: 10 additions & 2 deletions run.c
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,11 @@ int main(int argc, char *argv[]) {
float temperature = 0.9f; // e.g. 1.0, or 0.0
int steps = 256; // max number of steps to run for, 0: use seq_len
char *prompt = NULL; // prompt string
int buffertokens = 1; // output token buffer size

// 'checkpoint' is necessary arg
if (argc < 2) {
printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt]\n", argv[0]);
printf("Usage: %s <checkpoint_file> [temperature] [steps] [prompt] [buffer_tokens]\n", argv[0]);
return 1;
}
if (argc >= 2) {
Expand All @@ -474,6 +475,9 @@ int main(int argc, char *argv[]) {
if (argc >= 5) {
prompt = argv[4];
}
if (argc >= 6) {
buffertokens = atoi(argv[5]);
}

// seed rng with time. if you want deterministic behavior use temperature 0.0
rng_seed = (unsigned int)time(NULL);
Expand Down Expand Up @@ -543,7 +547,11 @@ int main(int argc, char *argv[]) {
int next; // will store the next token in the sequence
int token = 1; // init with token 1 (=BOS), as done in Llama-2 sentencepiece tokenizer
int pos = 0; // position in the sequence
int bufferflush = 1; // buffer flush after token counter
char outbuff[4096]; // used for output buffering
trholding marked this conversation as resolved.
Show resolved Hide resolved
memset( outbuff, '\0', sizeof( outbuff )); // clear buffer area
printf("<s>\n"); // explicit print the initial BOS token for stylistic symmetry reasons
setvbuf(stdout, outbuff, _IOFBF, 4096); // setup output buffering
while (pos < steps) {

// forward the transformer to get logits for the next token
Expand All @@ -570,7 +578,7 @@ int main(int argc, char *argv[]) {
// following BOS token (1), sentencepiece decoder strips any leading whitespace (see PR #89)
char *token_str = (token == 1 && vocab[next][0] == ' ') ? vocab[next]+1 : vocab[next];
printf("%s", token_str);
fflush(stdout);
if (bufferflush==pos) { fflush(stdout); bufferflush+=buffertokens; } // flush after every n tokens

// advance forward
token = next;
Expand Down