
Commit 5a396f1

update quantize; refactor struct defs
1 parent 0634382 commit 5a396f1

File tree

8 files changed: +345 -396 lines changed

bert.cpp

Lines changed: 84 additions & 122 deletions
@@ -76,105 +76,6 @@ static void tensor_stats(ggml_tensor * t) {
     );
 }
 
-//
-// data structures
-//
-
-// default hparams (all-MiniLM-L6-v2)
-struct bert_hparams {
-    int32_t n_vocab = 30522;
-    int32_t n_max_tokens = 512;
-    int32_t n_embd = 256;
-    int32_t n_intermediate = 1536;
-    int32_t n_head = 12;
-    int32_t n_layer = 6;
-    float_t layer_norm_eps = 1e-12;
-};
-
-struct bert_layer {
-    // normalization
-    struct ggml_tensor *ln_att_w;
-    struct ggml_tensor *ln_att_b;
-
-    struct ggml_tensor *ln_out_w;
-    struct ggml_tensor *ln_out_b;
-
-    // attention
-    struct ggml_tensor *q_w;
-    struct ggml_tensor *q_b;
-    struct ggml_tensor *k_w;
-    struct ggml_tensor *k_b;
-    struct ggml_tensor *v_w;
-    struct ggml_tensor *v_b;
-
-    struct ggml_tensor *o_w;
-    struct ggml_tensor *o_b;
-
-    // ff
-    struct ggml_tensor *ff_i_w;
-    struct ggml_tensor *ff_i_b;
-
-    struct ggml_tensor *ff_o_w;
-    struct ggml_tensor *ff_o_b;
-};
-
-struct bert_vocab {
-    std::map<std::string, bert_token> token_to_id;
-    std::map<std::string, bert_token> subword_token_to_id;
-
-    std::map<bert_token, std::string> _id_to_token;
-    std::map<bert_token, std::string> _id_to_subword_token;
-};
-
-struct bert_model {
-    bert_hparams hparams;
-
-    // embeddings weights
-    struct ggml_tensor *word_embeddings;
-    struct ggml_tensor *token_type_embeddings;
-    struct ggml_tensor *position_embeddings;
-    struct ggml_tensor *ln_e_w;
-    struct ggml_tensor *ln_e_b;
-
-    std::vector<bert_layer> layers;
-};
-
-struct bert_ctx {
-    bert_model model;
-    bert_vocab vocab;
-
-    struct ggml_context * ctx_data;
-
-    std::vector<uint8_t> buf_compute_meta;
-
-    // memory buffers to evaluate the model
-    ggml_backend_t backend = NULL;
-    ggml_backend_buffer_t weights_buffer = NULL;
-    ggml_backend_buffer_t compute_buffer = NULL;
-    ggml_allocr * compute_alloc = NULL;
-};
-
-int32_t bert_n_embd(bert_ctx * ctx) {
-    return ctx->model.hparams.n_embd;
-}
-
-int32_t bert_n_max_tokens(bert_ctx * ctx) {
-    return ctx->model.hparams.n_max_tokens;
-}
-
-const char* bert_vocab_id_to_token(bert_ctx * ctx, bert_token id) {
-    bert_vocab & vocab = ctx->vocab;
-    auto it = vocab._id_to_token.find(id);
-    if (it != vocab._id_to_token.end()) {
-        return it->second.c_str();
-    }
-    it = vocab._id_to_subword_token.find(id);
-    if (it != vocab._id_to_subword_token.end()) {
-        return it->second.c_str();
-    }
-    return "[UNK TOKEN from bert_vocab]";
-}
-
 //
 // tokenizing
 //
@@ -275,6 +176,19 @@ bool is_chinese_char(const std::string& str) {
     return false;
 }
 
+const char* bert_vocab_id_to_token(bert_ctx * ctx, bert_token id) {
+    bert_vocab & vocab = ctx->vocab;
+    auto it = vocab._id_to_token.find(id);
+    if (it != vocab._id_to_token.end()) {
+        return it->second.c_str();
+    }
+    it = vocab._id_to_subword_token.find(id);
+    if (it != vocab._id_to_subword_token.end()) {
+        return it->second.c_str();
+    }
+    return "[UNK TOKEN from bert_vocab]";
+}
+
 bert_tokens bert_tokenize(struct bert_ctx * ctx, bert_string text, int32_t n_max_tokens) {
     int cls_tok_id = 101;
     int sep_tok_id = 102;
@@ -392,11 +306,23 @@ void bert_tokenize_c(struct bert_ctx * ctx, const char * text, int32_t * output,
     }
 }
 
+//
+// bert model
+//
+
+int32_t bert_n_embd(bert_ctx * ctx) {
+    return ctx->model.hparams.n_embd;
+}
+
+int32_t bert_n_max_tokens(bert_ctx * ctx) {
+    return ctx->model.hparams.n_max_tokens;
+}
+
 //
 // loading and setup
 //
 
-struct bert_ctx * bert_load_from_file(const char *fname, int32_t batch_size, bool use_cpu) {
+struct bert_ctx * bert_load_from_file(const char *fname, bool use_cpu) {
     printf("%s: loading model from '%s (use_cpu = %b)' - please wait ...\n", __func__, fname, use_cpu);
 
     struct ggml_context * ctx_ggml = NULL;
@@ -476,6 +402,7 @@ struct bert_ctx * bert_load_from_file(const char *fname, int32_t batch_size, boo
 
     for (int i = 0; i < n_vocab; i++) {
         std::string word = gguf_get_arr_str(ctx_gguf, token_idx, i);
+        vocab.tokens.push_back(word);
 
         if (word[0] == '#' && word[1] == '#') {
            vocab.subword_token_to_id[word.substr(2)] = i;
@@ -651,36 +578,71 @@ struct bert_ctx * bert_load_from_file(const char *fname, int32_t batch_size, boo
     ggml_free(ctx_ggml);
     gguf_free(ctx_gguf);
 
-    // measure mem requirement and allocate
-    {
-        // get measuring allocr for backend
-        new_bert->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
-        new_bert->compute_alloc = ggml_allocr_new_measure_from_backend(new_bert->backend);
-
-        // construct batch and compute graph
-        bert_tokens tokens(hparams.n_max_tokens);
-        bert_batch batch;
-        for (int i = 0; i < batch_size; ++i) {
-            batch.push_back(tokens);
-        }
-        ggml_cgraph * gf = bert_build_graph(new_bert, batch);
+    // return context
+    return new_bert;
+}
 
-        // do computing graph measurement
-        size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(new_bert->compute_alloc, gf);
-        ggml_allocr_free(new_bert->compute_alloc);
+// measure and allocate compute buffers
+void bert_allocate_buffers(bert_ctx * ctx, int32_t n_max_tokens, int32_t batch_size) {
+    // deallocate if already allocated
+    bert_deallocate_buffers(ctx);
 
-        // now that we know the compute size, create a buffer and allocr
-        new_bert->compute_buffer = ggml_backend_alloc_buffer(new_bert->backend, compute_memory_buffer_size);
-        new_bert->compute_alloc = ggml_allocr_new_from_buffer(new_bert->compute_buffer);
+    // get measuring allocr for backend
+    ctx->buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
+    ctx->compute_alloc = ggml_allocr_new_measure_from_backend(ctx->backend);
 
-        printf("%s: compute allocated memory: %.2f MB\n\n", __func__, compute_memory_buffer_size / 1024.0 / 1024.0);
+    // construct batch and compute graph
+    bert_tokens tokens(n_max_tokens);
+    bert_batch batch;
+    for (int i = 0; i < batch_size; ++i) {
+        batch.push_back(tokens);
     }
+    ggml_cgraph * gf = bert_build_graph(ctx, batch);
 
-    return new_bert;
+    // do computing graph measurement
+    size_t compute_memory_buffer_size = ggml_allocr_alloc_graph(ctx->compute_alloc, gf);
+    ggml_allocr_free(ctx->compute_alloc);
+
+    // now that we know the compute size, create a buffer and allocr
+    ctx->compute_buffer = ggml_backend_alloc_buffer(ctx->backend, compute_memory_buffer_size);
+    ctx->compute_alloc = ggml_allocr_new_from_buffer(ctx->compute_buffer);
+
+    printf("%s: compute allocated memory: %.2f MB\n\n", __func__, compute_memory_buffer_size / 1024.0 / 1024.0);
+}
+
+void bert_deallocate_buffers(bert_ctx * ctx) {
+    if (ctx->compute_buffer) {
+        ggml_backend_buffer_free(ctx->compute_buffer);
+        ctx->compute_buffer = NULL;
+    }
+    if (ctx->compute_alloc) {
+        ggml_allocr_free(ctx->compute_alloc);
+        ctx->compute_alloc = NULL;
+    }
 }
 
 void bert_free(bert_ctx * ctx) {
-    ggml_free(ctx->ctx_data);
+    // free compute buffers
+    bert_deallocate_buffers(ctx);
+
+    // free weights buffer
+    if (ctx->weights_buffer) {
+        ggml_backend_buffer_free(ctx->weights_buffer);
+        ctx->weights_buffer = NULL;
+    }
+
+    // free tensor context
+    if (ctx->ctx_data) {
+        ggml_free(ctx->ctx_data);
+        ctx->ctx_data = NULL;
+    }
+
+    // free backend
+    if (ctx->backend) {
+        ggml_backend_free(ctx->backend);
+        ctx->backend = NULL;
+    }
+
     delete ctx;
 }
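The upshot of this last hunk is that compute-buffer sizing moves out of bert_load_from_file (which loses its batch_size parameter) into an explicit bert_allocate_buffers/bert_deallocate_buffers pair, and bert_free now tears down every backend resource. A minimal caller sketch under that reading, assuming a bert.h header declaring these functions and a hypothetical model.gguf path (neither is shown in this diff):

// Caller sketch (not part of this commit): load once, then size compute
// buffers explicitly for the intended workload.
#include "bert.h"  // assumed header declaring the functions used below

int main() {
    // loads weights and vocab only; no compute buffers are created here
    bert_ctx * ctx = bert_load_from_file("model.gguf", /*use_cpu=*/true);

    // size compute buffers for a given sequence length and batch size;
    // calling this again re-measures after freeing the old buffers
    bert_allocate_buffers(ctx, bert_n_max_tokens(ctx), /*batch_size=*/8);

    // ... tokenize with bert_tokenize and evaluate the graph here ...

    // frees compute buffers, weights buffer, tensor context, and backend
    bert_free(ctx);
    return 0;
}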
