@@ -155,79 +155,83 @@ c10::IValue LlamaCppHandler::Inference(
     std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
     std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
   torch::InferenceMode guard;
-  std::vector<torch::Tensor> batch_output_vector;
-  for (const auto input : inputs.toTensorList()) {
-    torch::Tensor tokens_list_tensor = input.get().toTensor();
+  auto batch_output_vector = c10::impl::GenericList(torch::TensorType::get());
+  try {
+    for (const auto input : inputs.toTensorList()) {
+      torch::Tensor tokens_list_tensor = input.get().toTensor();

-    int64_t num_elements = tokens_list_tensor.numel();
+      int64_t num_elements = tokens_list_tensor.numel();

-    int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();
-    std::vector<llama_token> tokens_list;
+      int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();
+      std::vector<llama_token> tokens_list;

-    for (int64_t i = 0; i < num_elements; ++i) {
-      tokens_list.push_back(data_ptr[i]);
-    }
-    const int n_gen = std::min(32, max_context_size);
+      for (int64_t i = 0; i < num_elements; ++i) {
+        tokens_list.push_back(data_ptr[i]);
+      }
+      const int n_gen = std::min(32, max_context_size);

-    long pos = 0;
-    while (pos < n_gen) {
-      // evaluate the transformer
+      std::vector<torch::Tensor> tensor_vector;

-      if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
-                     llama_get_kv_cache_token_count(llama_ctx))) {
-        std::cout << "Failed to eval\n" << __func__ << std::endl;
-        break;
-      }
+      long pos = 0;
+      while (pos < n_gen) {
+        // evaluate the transformer

-      tokens_list.clear();
+        int n_past = pos == 0 ? 0 : llama_get_kv_cache_token_count(llama_ctx);

-      // sample the next token
+        if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
+                       n_past)) {
+          std::cout << "Failed to eval\n" << __func__ << std::endl;
+          break;
+        }

-      llama_token new_token_id = 0;
+        tokens_list.clear();

-      auto logits = llama_get_logits(llama_ctx);
-      auto n_vocab = llama_n_vocab(llamamodel);
+        // sample the next token

-      std::vector<llama_token_data> candidates;
-      candidates.reserve(n_vocab);
+        llama_token new_token_id = 0;

-      for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(
-            llama_token_data{token_id, logits[token_id], 0.0f});
-      }
+        auto logits = llama_get_logits(llama_ctx);
+        auto n_vocab = llama_n_vocab(llamamodel);

-      llama_token_data_array candidates_p = {candidates.data(),
-                                             candidates.size(), false};
+        std::vector<llama_token_data> candidates;
+        candidates.reserve(n_vocab);

-      new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+          candidates.emplace_back(
+              llama_token_data{token_id, logits[token_id], 0.0f});
+        }

-      // is it an end of stream ?
-      if (new_token_id == llama_token_eos(llamamodel)) {
-        std::cout << "Reached [end of text]\n";
-        break;
-      }
+        llama_token_data_array candidates_p = {candidates.data(),
+                                               candidates.size(), false};

-      // print the new token :
-      std::cout << "New Token: "
-                << llama_token_to_piece(llama_ctx, new_token_id) << std::endl;
+        new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);

-      // push this new token for next evaluation
-      tokens_list.push_back(new_token_id);
-      pos += 1;
-    }
+        // is it an end of stream ?
+        if (new_token_id == llama_token_eos(llamamodel)) {
+          std::cout << "Reached [end of text]\n";
+          break;
+        }

-    std::vector<torch::Tensor> tensor_vector;
-    for (auto id : tokens_list) {
-      torch::Tensor tensor = torch::tensor(id, torch::kLong);
-      tensor_vector.push_back(tensor);
+        // print the new token :
+        std::cout << "New Token: "
+                  << llama_token_to_piece(llama_ctx, new_token_id) << std::endl;
+
+        // push this new token for next evaluation
+        tokens_list.push_back(new_token_id);
+        tensor_vector.push_back(torch::tensor(new_token_id, torch::kLong));
+        pos += 1;
+      }
+
+      batch_output_vector.push_back(torch::stack(tensor_vector));
     }

-    torch::Tensor stacked_tensor = torch::stack(tensor_vector);
-    batch_output_vector.push_back(stacked_tensor);
+    llama_print_timings(llama_ctx);
+  } catch (std::runtime_error& e) {
+    TS_LOG(ERROR, e.what());
+  } catch (const c10::Error& e) {
+    TS_LOGF(ERROR, "Failed to apply inference on input, c10 error:{}", e.msg());
   }
-
-  llama_print_timings(llama_ctx);
-  return torch::stack(batch_output_vector);
+  return batch_output_vector;
 }

 void LlamaCppHandler::Postprocess(
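Note on the change above: Inference now returns a c10 generic list holding one stacked token tensor per request, instead of torch::stack-ing the whole batch, so requests that generate different numbers of tokens no longer need to share a shape. Below is a minimal consumption sketch showing how a Postprocess-style caller could unpack that list back into text; the DecodeBatch name and the include set are illustrative assumptions, not part of this commit.

// Sketch only (not in this commit): unpacking the generic list returned by
// Inference. Assumes llama.cpp's common.h, which provides the std::string
// overload of llama_token_to_piece(ctx, token) used in the handler above.
#include <cstdint>
#include <string>
#include <vector>
#include <torch/torch.h>
#include "common.h"

std::vector<std::string> DecodeBatch(const c10::IValue& outputs,
                                     llama_context* llama_ctx) {
  std::vector<std::string> texts;
  const c10::List<c10::IValue> batch = outputs.toList();
  for (size_t i = 0; i < batch.size(); ++i) {
    // Each element is the 1-D int64 tensor of generated token ids that
    // Inference built with torch::stack(tensor_vector).
    at::Tensor tokens = batch.get(i).toTensor().reshape({-1}).contiguous();
    const int64_t* ids = tokens.data_ptr<int64_t>();
    std::string text;
    for (int64_t j = 0; j < tokens.numel(); ++j) {
      text += llama_token_to_piece(llama_ctx, static_cast<llama_token>(ids[j]));
    }
    texts.push_back(text);
  }
  return texts;
}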