diff --git a/run.c b/run.c index e5b12338..eb47fe5e 100644 --- a/run.c +++ b/run.c @@ -369,14 +369,24 @@ float* forward(Transformer* transformer, int token, int pos) { // ---------------------------------------------------------------------------- // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens +typedef struct { + char *str; + int id; +} TokenIndex; + typedef struct { char** vocab; float* vocab_scores; + TokenIndex *sorted_vocab; int vocab_size; unsigned int max_token_length; char byte_piece[2]; } Tokenizer; +int compare_tokens(const void *a, const void *b) { + return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); +} + void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) { // i should have written the vocab_size into the tokenizer file... sigh t->vocab_size = vocab_size; @@ -396,13 +406,31 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) { if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); } t->vocab[i][len] = '\0'; // add the string terminating token } + + t->sorted_vocab = NULL; + fclose(file); } +void sort_vocabulary(Tokenizer* t) +{ + if (t->sorted_vocab == NULL) { + // sort vocabulary + t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex)); + if (t->sorted_vocab == NULL) { fprintf(stderr, "failed malloc for sorted vocabulary\n"); exit(EXIT_FAILURE); } + for (int i = 0; i < t->vocab_size; i++) { + t->sorted_vocab[i].str = t->vocab[i]; + t->sorted_vocab[i].id = i; + } + qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens); + } +} + void free_tokenizer(Tokenizer* t) { for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); } free(t->vocab); free(t->vocab_scores); + free(t->sorted_vocab); } char* decode(Tokenizer* t, int prev_token, int token) { @@ -422,15 +450,6 @@ char* decode(Tokenizer* t, int prev_token, int token) { return piece; } -typedef struct { - char *str; - int id; -} TokenIndex; - -int compare_tokens(const void *a, const void *b) { - return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str); -} - int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { // efficiently find the perfect match for str in vocab, return its index or -1 if not found TokenIndex tok = { .str = str }; // acts as the key to search for @@ -441,20 +460,14 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { // encode the string text (input) into an upper-bound preallocated tokens[] array - // sort vocabulary - TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex)); - for (int i = 0; i < t->vocab_size; i++) { - sorted_vocab[i].str = t->vocab[i]; - sorted_vocab[i].id = i; - } - qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens); + sort_vocabulary(t); // create a temporary buffer that will store merge candidates of always two consecutive tokens char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator +2 for UTF8 (in case max_token_lenght is 1) size_t str_len = 0; // add_dummy_prefix is true by default - tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size); + tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size); *n_tokens = 1; // the number of tokens // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: @@ -490,7 +503,7 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { } // ok c+1 is not a continuation byte, so we've read in a full codepoint - int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size); + int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); if (id != -1) { // we found this codepoint in vocab, add it as a token @@ -515,7 +528,7 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { for (int i=0; i < (*n_tokens-1); i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]); - int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size); + int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); if (id != -1 && t->vocab_scores[id] > best_score) { // this merge pair exists in vocab! record its score and position best_score = t->vocab_scores[id]; @@ -538,7 +551,6 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) { } free(str_buffer); - free(sorted_vocab); } // ----------------------------------------------------------------------------