Skip to content

Commit

Permalink
Processing all the items in the batch
Browse files Browse the repository at this point in the history
Signed-off-by: Shrinath Suresh <[email protected]>
  • Loading branch information
shrinath-suresh committed Sep 15, 2023
1 parent c95a576 commit eda470f
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 56 deletions.
118 changes: 63 additions & 55 deletions cpp/src/examples/llamacpp/llamacpp_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,10 @@ LlamacppHandler::LoadModel(
if (json.find("checkpoint_path") != json.items().end()) {
checkpoint_path = json["checkpoint_path"].asString();
} else {
std::cerr
<< "Required field 'checkpoint_path' not found in JSON."
<< std::endl;
std::cerr << "Required field 'checkpoint_path' not found in JSON."
<< std::endl;
throw;
}

params.model = checkpoint_path;
params.main_gpu = 0;
params.n_gpu_layers = 35;
Expand Down Expand Up @@ -76,7 +74,7 @@ std::vector<torch::jit::IValue> LlamacppHandler::Preprocess(
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {

initialize_context();

std::vector<torch::jit::IValue> batch_ivalue;
Expand Down Expand Up @@ -163,78 +161,81 @@ torch::Tensor LlamacppHandler::Inference(
std::shared_ptr<torch::Device>& device,
std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
auto tokens_list_tensor = inputs[0].toTensor();
torch::InferenceMode guard;
std::vector<torch::Tensor> batch_output_vector;
for (const torch::jit::IValue& input : inputs) {
torch::Tensor tokens_list_tensor = input.toTensor();

int64_t num_elements = tokens_list_tensor.numel();
int64_t num_elements = tokens_list_tensor.numel();

// Convert the tensor to a vector of long values
std::vector<long> long_vector;
long_vector.reserve(num_elements);
int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();
std::vector<llama_token> tokens_list;

auto data_ptr = tokens_list_tensor.data_ptr<int64_t>();
for (int64_t i = 0; i < num_elements; ++i) {
long_vector.push_back(data_ptr[i]);
}
for (int64_t i = 0; i < num_elements; ++i) {
tokens_list.push_back(data_ptr[i]);
}
const int n_gen = std::min(32, max_context_size);

std::vector<llama_token> tokens_list;
long pos = 0;
while (pos < n_gen) {
// evaluate the transformer

for (auto id : long_vector) {
tokens_list.push_back(id);
}
const int n_gen = std::min(32, max_context_size);
if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
llama_get_kv_cache_token_count(llama_ctx),
params.n_threads)) {
std::cout << "Failed to eval\n" << __func__ << std::endl;
break;
}

while (llama_get_kv_cache_token_count(llama_ctx) < n_gen) {
// evaluate the transformer
tokens_list.clear();

if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
llama_get_kv_cache_token_count(llama_ctx),
params.n_threads)) {
std::cout << "Failed to eval\n" << __func__ << std::endl;
break;
}
// sample the next token

tokens_list.clear();
llama_token new_token_id = 0;

// sample the next token
auto logits = llama_get_logits(llama_ctx);
auto n_vocab = llama_n_vocab(llama_ctx);

llama_token new_token_id = 0;
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);

auto logits = llama_get_logits(llama_ctx);
auto n_vocab = llama_n_vocab(llama_ctx);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(
llama_token_data{token_id, logits[token_id], 0.0f});
}

std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
llama_token_data_array candidates_p = {candidates.data(),
candidates.size(), false};

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(
llama_token_data{token_id, logits[token_id], 0.0f});
}
new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);

llama_token_data_array candidates_p = {candidates.data(), candidates.size(),
false};
// is it an end of stream ?
if (new_token_id == llama_token_eos(llama_ctx)) {
std::cout << "Reached [end of text]\n";
break;
}

new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);
// print the new token :
std::cout << "New Token: "
<< llama_token_to_piece(llama_ctx, new_token_id) << std::endl;

// is it an end of stream ?
if (new_token_id == llama_token_eos(llama_ctx)) {
std::cout << "Reached [end of text]\n";
break;
// push this new token for next evaluation
tokens_list.push_back(new_token_id);
pos += 1;
}

// push this new token for next evaluation
tokens_list.push_back(new_token_id);
}
std::vector<torch::Tensor> tensor_vector;
for (auto id : tokens_list) {
torch::Tensor tensor = torch::tensor(id, torch::kLong);
tensor_vector.push_back(tensor);
}

std::vector<torch::Tensor> tensor_vector;
for (auto id : tokens_list) {
torch::Tensor tensor = torch::tensor(id, torch::kLong);
tensor_vector.push_back(tensor);
torch::Tensor stacked_tensor = torch::stack(tensor_vector);
batch_output_vector.push_back(stacked_tensor);
}

torch::Tensor stacked_tensor = torch::stack(tensor_vector);
llama_print_timings(llama_ctx);
llama_free(llama_ctx);
return stacked_tensor;
return torch::stack(batch_output_vector);
}

void LlamacppHandler::Postprocess(
Expand All @@ -254,6 +255,7 @@ void LlamacppHandler::Postprocess(
}

std::string generated_text_str = generated_text_stream.str();
std::cout << "Generated Text Str: " << generated_text_str << std::endl;

auto response = (*response_batch)[kv.second];

Expand All @@ -279,6 +281,12 @@ void LlamacppHandler::Postprocess(
}
}

LlamacppHandler::~LlamacppHandler() noexcept {
llama_free(llama_ctx);
llama_free_model(llamamodel);
llama_backend_free();
}

} // namespace llm

#if defined(__linux__) || defined(__APPLE__)
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/examples/llamacpp/llamacpp_handler.hh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class LlamacppHandler : public torchserve::torchscripted::BaseHandler {
// NOLINTBEGIN(bugprone-exception-escape)
LlamacppHandler() = default;
// NOLINTEND(bugprone-exception-escape)
~LlamacppHandler() override = default;
~LlamacppHandler() noexcept;

void initialize_context();

Expand Down

0 comments on commit eda470f

Please sign in to comment.