52 changes: 32 additions & 20 deletions run.c
@@ -369,14 +369,24 @@ float* forward(Transformer* transformer, int token, int pos) {
 // ----------------------------------------------------------------------------
 // The Byte Pair Encoding (BPE) Tokenizer that translates strings <-> tokens
 
+typedef struct {
+    char *str;
+    int id;
+} TokenIndex;
+
 typedef struct {
     char** vocab;
     float* vocab_scores;
+    TokenIndex *sorted_vocab;
     int vocab_size;
     unsigned int max_token_length;
     char byte_piece[2];
 } Tokenizer;
 
+int compare_tokens(const void *a, const void *b) {
+    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
+}
+
 void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
     // i should have written the vocab_size into the tokenizer file... sigh
     t->vocab_size = vocab_size;
@@ -396,13 +406,31 @@ void build_tokenizer(Tokenizer* t, char* tokenizer_path, int vocab_size) {
         if (fread(t->vocab[i], len, 1, file) != 1) { fprintf(stderr, "failed read\n"); exit(EXIT_FAILURE); }
         t->vocab[i][len] = '\0'; // add the string terminating token
     }
+
+    t->sorted_vocab = NULL;
+
     fclose(file);
 }
 
+void sort_vocabulary(Tokenizer* t)
+{
+    if (t->sorted_vocab == NULL) {
+        // sort vocabulary
+        t->sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
+        if (t->sorted_vocab == NULL) { fprintf(stderr, "failed malloc for sorted vocabulary\n"); exit(EXIT_FAILURE); }
+        for (int i = 0; i < t->vocab_size; i++) {
+            t->sorted_vocab[i].str = t->vocab[i];
+            t->sorted_vocab[i].id = i;
+        }
+        qsort(t->sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
+    }
+}
+
 void free_tokenizer(Tokenizer* t) {
     for (int i = 0; i < t->vocab_size; i++) { free(t->vocab[i]); }
     free(t->vocab);
     free(t->vocab_scores);
+    free(t->sorted_vocab);
 }
 
 char* decode(Tokenizer* t, int prev_token, int token) {
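
Note on the lifecycle this change sets up: build_tokenizer now leaves sorted_vocab as NULL, the first encode call pays the one-time qsort cost via sort_vocabulary, and free_tokenizer releases the cached array. A minimal usage sketch (the file name, prompt, and 32000 vocab size here are illustrative, not taken from this diff):

Tokenizer tokenizer;
build_tokenizer(&tokenizer, "tokenizer.bin", 32000); // sorted_vocab == NULL until first use
int tokens[512]; // upper-bound scratch buffer, size chosen arbitrarily for the sketch
int n_tokens;
encode(&tokenizer, "hello world", tokens, &n_tokens); // triggers the lazy sort once
encode(&tokenizer, "hello again", tokens, &n_tokens); // reuses t->sorted_vocab
free_tokenizer(&tokenizer); // also frees sorted_vocab now

This amortizes the O(V log V) sort across every encode call instead of redoing it per call.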
@@ -422,15 +450,6 @@ char* decode(Tokenizer* t, int prev_token, int token) {
     return piece;
 }
 
-typedef struct {
-    char *str;
-    int id;
-} TokenIndex;
-
-int compare_tokens(const void *a, const void *b) {
-    return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
-}
-
 int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
     // efficiently find the perfect match for str in vocab, return its index or -1 if not found
     TokenIndex tok = { .str = str }; // acts as the key to search for
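
The remainder of str_lookup is collapsed in this view; given the tok key above and compare_tokens, the match is presumably resolved with a standard bsearch over the sorted array, roughly:

TokenIndex *res = bsearch(&tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens); // O(log V)
return res != NULL ? res->id : -1; // -1 signals "not in vocab"

which is why keeping sorted_vocab around, rather than rebuilding it on every encode, is the whole point of this PR.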
@@ -441,20 +460,14 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
 void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
     // encode the string text (input) into an upper-bound preallocated tokens[] array
 
-    // sort vocabulary
-    TokenIndex *sorted_vocab = malloc(t->vocab_size * sizeof(TokenIndex));
-    for (int i = 0; i < t->vocab_size; i++) {
-        sorted_vocab[i].str = t->vocab[i];
-        sorted_vocab[i].id = i;
-    }
-    qsort(sorted_vocab, t->vocab_size, sizeof(TokenIndex), compare_tokens);
+    sort_vocabulary(t);
 
     // create a temporary buffer that will store merge candidates of always two consecutive tokens
     char* str_buffer = malloc((t->max_token_length*2 +1 +2) * sizeof(char)); // *2 for concat, +1 for null terminator, +2 for UTF8 (in case max_token_length is 1)
     size_t str_len = 0;
 
     // add_dummy_prefix is true by default
-    tokens[0] = str_lookup(" ", sorted_vocab, t->vocab_size);
+    tokens[0] = str_lookup(" ", t->sorted_vocab, t->vocab_size);
     *n_tokens = 1; // the number of tokens
 
     // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
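
The collapsed loop below scans the input byte by byte; the "continuation byte" check it relies on is the standard UTF-8 test that trailing bytes match the bit pattern 10xxxxxx. A self-contained sketch of that test (the helper name is mine, not from the diff):

// a byte is a UTF-8 continuation byte iff its top two bits are 10
int is_continuation_byte(unsigned char b) {
    return (b & 0xC0) == 0x80; // 0xC0 masks the top two bits; 0x80 == binary 10000000
}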
@@ -490,7 +503,7 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
         }
 
         // ok c+1 is not a continuation byte, so we've read in a full codepoint
-        int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
+        int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
 
         if (id != -1) {
             // we found this codepoint in vocab, add it as a token
@@ -515,7 +528,7 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
         for (int i=0; i < (*n_tokens-1); i++) {
             // check if we can merge the pair (tokens[i], tokens[i+1])
             sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
-            int id = str_lookup(str_buffer, sorted_vocab, t->vocab_size);
+            int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
             if (id != -1 && t->vocab_scores[id] > best_score) {
                 // this merge pair exists in vocab! record its score and position
                 best_score = t->vocab_scores[id];
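
The tail of this merge loop is collapsed; assuming best_idx and best_id are recorded alongside best_score as the comment says, the rest presumably applies the merge and compacts the token array, along these lines (a sketch, not the verbatim file):

if (best_idx == -1) { break; } // no mergeable pair found, we're done
tokens[best_idx] = best_id; // replace the left token of the pair with the merged id
for (int i = best_idx + 1; i < (*n_tokens - 1); i++) {
    tokens[i] = tokens[i + 1]; // shift everything after the pair left by one
}
(*n_tokens)--; // token count shrinks by one per merge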
@@ -538,7 +551,6 @@ void encode(Tokenizer* t, char *text, int *tokens, int *n_tokens) {
     }
 
     free(str_buffer);
-    free(sorted_vocab);
 }
 
 // ----------------------------------------------------------------------------