This repository was archived by the owner on Aug 7, 2025. It is now read-only.

Commit 8e33517

Reset context for multiple batch entries
1 parent 641386c commit 8e33517

File tree

1 file changed: +57 -53 lines changed


cpp/src/examples/llamacpp/llamacpp_handler.cc

Lines changed: 57 additions & 53 deletions
@@ -155,79 +155,83 @@ c10::IValue LlamaCppHandler::Inference(
     std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
     std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
   torch::InferenceMode guard;
-  std::vector<torch::Tensor> batch_output_vector;
-  for (const auto input : inputs.toTensorList()) {
-    torch::Tensor tokens_list_tensor = input.get().toTensor();
+  auto batch_output_vector = c10::impl::GenericList(torch::TensorType::get());
+  try {
+    for (const auto input : inputs.toTensorList()) {
+      torch::Tensor tokens_list_tensor = input.get().toTensor();

-    int64_t num_elements = tokens_list_tensor.numel();
+      int64_t num_elements = tokens_list_tensor.numel();

-    int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();
-    std::vector<llama_token> tokens_list;
+      int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();
+      std::vector<llama_token> tokens_list;

-    for (int64_t i = 0; i < num_elements; ++i) {
-      tokens_list.push_back(data_ptr[i]);
-    }
-    const int n_gen = std::min(32, max_context_size);
+      for (int64_t i = 0; i < num_elements; ++i) {
+        tokens_list.push_back(data_ptr[i]);
+      }
+      const int n_gen = std::min(32, max_context_size);

-    long pos = 0;
-    while (pos < n_gen) {
-      // evaluate the transformer
+      std::vector<torch::Tensor> tensor_vector;

-      if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
-                     llama_get_kv_cache_token_count(llama_ctx))) {
-        std::cout << "Failed to eval\n" << __func__ << std::endl;
-        break;
-      }
+      long pos = 0;
+      while (pos < n_gen) {
+        // evaluate the transformer

-      tokens_list.clear();
+        int n_past = pos == 0 ? 0 : llama_get_kv_cache_token_count(llama_ctx);

-      // sample the next token
+        if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
+                       n_past)) {
+          std::cout << "Failed to eval\n" << __func__ << std::endl;
+          break;
+        }

-      llama_token new_token_id = 0;
+        tokens_list.clear();

-      auto logits = llama_get_logits(llama_ctx);
-      auto n_vocab = llama_n_vocab(llamamodel);
+        // sample the next token

-      std::vector<llama_token_data> candidates;
-      candidates.reserve(n_vocab);
+        llama_token new_token_id = 0;

-      for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(
-            llama_token_data{token_id, logits[token_id], 0.0f});
-      }
+        auto logits = llama_get_logits(llama_ctx);
+        auto n_vocab = llama_n_vocab(llamamodel);

-      llama_token_data_array candidates_p = {candidates.data(),
-                                             candidates.size(), false};
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);

-      new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+          candidates.emplace_back(
+              llama_token_data{token_id, logits[token_id], 0.0f});
+        }

-      // is it an end of stream ?
-      if (new_token_id == llama_token_eos(llamamodel)) {
-        std::cout << "Reached [end of text]\n";
-        break;
-      }
+        llama_token_data_array candidates_p = {candidates.data(),
+                                               candidates.size(), false};

-      // print the new token :
-      std::cout << "New Token: "
-                << llama_token_to_piece(llama_ctx, new_token_id) << std::endl;
+        new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);

-      // push this new token for next evaluation
-      tokens_list.push_back(new_token_id);
-      pos += 1;
-    }
+        // is it an end of stream ?
+        if (new_token_id == llama_token_eos(llamamodel)) {
+          std::cout << "Reached [end of text]\n";
+          break;
+        }

-    std::vector<torch::Tensor> tensor_vector;
-    for (auto id : tokens_list) {
-      torch::Tensor tensor = torch::tensor(id, torch::kLong);
-      tensor_vector.push_back(tensor);
+        // print the new token :
+        std::cout << "New Token: "
+                  << llama_token_to_piece(llama_ctx, new_token_id) << std::endl;
+
+        // push this new token for next evaluation
+        tokens_list.push_back(new_token_id);
+        tensor_vector.push_back(torch::tensor(new_token_id, torch::kLong));
+        pos += 1;
+      }
+
+      batch_output_vector.push_back(torch::stack(tensor_vector));
     }

-    torch::Tensor stacked_tensor = torch::stack(tensor_vector);
-    batch_output_vector.push_back(stacked_tensor);
+    llama_print_timings(llama_ctx);
+  } catch (std::runtime_error& e) {
+    TS_LOG(ERROR, e.what());
+  } catch (const c10::Error& e) {
+    TS_LOGF(ERROR, "Failed to apply inference on input, c10 error:{}", e.msg());
   }
-
-  llama_print_timings(llama_ctx);
-  return torch::stack(batch_output_vector);
+  return batch_output_vector;
 }
 
 void LlamaCppHandler::Postprocess(
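The core of this change is the per-entry reset of the KV-cache position: on the first llama_eval call of each batch entry, n_past is forced to 0 so that one request's prompt is not appended to the previous request's cached context, while later iterations continue from llama_get_kv_cache_token_count. Below is a minimal standalone sketch of that decode pattern; it uses only the llama.cpp calls that appear in this handler (the legacy four-argument llama_eval API), and the generate_entry helper and its signature are illustrative assumptions, not part of the handler.

#include <vector>

#include "llama.h"

// Sketch only: the per-entry greedy decode pattern introduced in this commit.
// generate_entry is a hypothetical helper for illustration, not handler code.
std::vector<llama_token> generate_entry(llama_context* llama_ctx,
                                        llama_model* llamamodel,
                                        std::vector<llama_token> tokens_list,
                                        int n_gen) {
  std::vector<llama_token> generated;
  for (long pos = 0; pos < n_gen; ++pos) {
    // Reset the cache position on the first eval of this entry so the
    // previous batch entry's context does not leak into this one.
    int n_past = pos == 0 ? 0 : llama_get_kv_cache_token_count(llama_ctx);

    if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
                   n_past)) {
      break;  // evaluation failed
    }

    // Greedy sampling over the logits of the last evaluated token.
    auto logits = llama_get_logits(llama_ctx);
    auto n_vocab = llama_n_vocab(llamamodel);
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
      candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
    }
    llama_token_data_array candidates_p = {candidates.data(), candidates.size(),
                                           false};
    llama_token new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);

    if (new_token_id == llama_token_eos(llamamodel)) {
      break;  // end of stream
    }

    // Feed only the newly sampled token into the next eval and record it.
    tokens_list.clear();
    tokens_list.push_back(new_token_id);
    generated.push_back(new_token_id);
  }
  return generated;
}

The pre-commit code always passed llama_get_kv_cache_token_count(llama_ctx) as n_past, so the second and later batch entries continued from the first entry's cache; the pos == 0 reset keeps each entry independent. The handler also now collects the per-entry outputs in a c10::impl::GenericList and wraps the loop in try/catch so a failing entry is logged instead of crashing the worker.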
