From ee072773004426700badb481c9610074ff0278d9 Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Thu, 17 Aug 2023 00:11:46 +0530 Subject: [PATCH 01/11] Version1 of llm inference with cpp backend Signed-off-by: Shrinath Suresh --- cpp/build.sh | 4 + cpp/src/examples/CMakeLists.txt | 23 +++ cpp/src/examples/llm/llm_handler.cc | 139 ++++++++++++++++++ cpp/src/examples/llm/llm_handler.hh | 36 +++++ .../torch_scripted_backend_test.cc | 9 ++ .../llm/llm_handler/prompt.txt | 1 + 6 files changed, 212 insertions(+) create mode 100644 cpp/src/examples/llm/llm_handler.cc create mode 100644 cpp/src/examples/llm/llm_handler.hh create mode 100644 cpp/test/resources/torchscript_model/llm/llm_handler/prompt.txt diff --git a/cpp/build.sh b/cpp/build.sh index 23df4df722..bd08f7c4a4 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -299,6 +299,10 @@ function build() { mv $DEPS_DIR/../src/examples/libmnist_handler.so $DEPS_DIR/../../test/resources/torchscript_model/mnist/mnist_handler/libmnist_handler.so fi + if [ -f "$DEPS_DIR/../src/examples/libllm_handler.so" ]; then + mv $DEPS_DIR/../src/examples/libllm_handler.so $DEPS_DIR/../../test/resources/torchscript_model/llm/llm_handler/libllm_handler.so + fi + cd $DEPS_DIR/../.. if [ -f "$DEPS_DIR/../test/torchserve_cpp_test" ]; then $DEPS_DIR/../test/torchserve_cpp_test diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt index 4c9c534097..66d48ee066 100644 --- a/cpp/src/examples/CMakeLists.txt +++ b/cpp/src/examples/CMakeLists.txt @@ -5,3 +5,26 @@ list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc) add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES}) target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR}) target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) + +set(LLM_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/llm") +set(LLAMACPP_SRC_DIR "/home/ubuntu/llama.cpp") +set(LLM_SOURCE_FILES "") +list(APPEND LLM_SOURCE_FILES ${LLM_SRC_DIR}/llm_handler.cc) +add_library(llm_handler SHARED ${LLM_SOURCE_FILES}) +target_include_directories(llm_handler PUBLIC ${LLM_SRC_DIR}) +target_include_directories(llm_handler PUBLIC ${LLAMACPP_SRC_DIR}) +target_link_libraries(llm_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) + + +set(MY_OBJECT_FILES + ${LLAMACPP_SRC_DIR}/ggml.o + ${LLAMACPP_SRC_DIR}/llama.o + ${LLAMACPP_SRC_DIR}/common.o + ${LLAMACPP_SRC_DIR}/k_quants.o + ${LLAMACPP_SRC_DIR}/ggml-alloc.o + ${LLAMACPP_SRC_DIR}/grammar-parser.o + ${LLAMACPP_SRC_DIR}/console.o + +) + +target_sources(llm_handler PRIVATE ${MY_OBJECT_FILES}) diff --git a/cpp/src/examples/llm/llm_handler.cc b/cpp/src/examples/llm/llm_handler.cc new file mode 100644 index 0000000000..ace6fa2a57 --- /dev/null +++ b/cpp/src/examples/llm/llm_handler.cc @@ -0,0 +1,139 @@ +#include "src/examples/image_classifier/llm/llm_handler.hh" + +#include +#include + +#include + +#include "examples/common.h" +#include "ggml.h" +#include "llama.h" + +namespace llm { + +std::vector LlmHandler::Preprocess( + std::shared_ptr& device, + std::pair&>& idx_to_req_id, + std::shared_ptr& request_batch, + std::shared_ptr& response_batch) { + /** + * @brief + * Ref: + * https://github.com/pytorch/serve/blob/master/ts/torch_handler/vision_handler.py#L27 + */ + + // Model Loading + gpt_params params; + params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin"; + params.prompt = "Hello my name is"; + llama_backend_init(params.numa); + + + llama_model* model; + llama_context* ctx; + + 
std::tie(model, ctx) = llama_init_from_gpt_params(params); + + if (model == NULL) { + std::cout << "<<<<<<<<<<< tokens_list; + tokens_list = ::llama_tokenize(ctx, params.prompt, true); + + // const int max_context_size = llama_n_ctx(ctx); + const int max_context_size = 64; + const int max_tokens_list_size = max_context_size - 4; + + if ((int)tokens_list.size() > max_tokens_list_size) { + std::cout << __func__ << ": error: prompt too long (" << tokens_list.size() + << " tokens, max " << max_tokens_list_size << ")\n"; + } + + // Print the tokens from the prompt : + + for (auto id : tokens_list) { + std::cout << llama_token_to_str(ctx, id) << std::endl; + } + + // Prediction loop + + std::vector generated_tokens; + + while (llama_get_kv_cache_token_count(ctx) < max_context_size) { + //--------------------------------- + // Evaluate the tokens : + //--------------------------------- + + if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), + llama_get_kv_cache_token_count(ctx), params.n_threads)) { + std::cout << "Evaluation Failed" << __func__ << std::endl; + // return 1; + // TODO: Raise exception here + } + + // tokens_list.clear(); + + //--------------------------------- + // Select the best prediction : + //--------------------------------- + + llama_token new_token_id = 0; + + auto logits = llama_get_logits(ctx); + auto n_vocab = + llama_n_vocab(ctx); // the size of the LLM vocabulary (in tokens) + + std::vector candidates; + candidates.reserve(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back( + llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array candidates_p = {candidates.data(), candidates.size(), + false}; + + // Select it using the "Greedy sampling" method : + new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + + // is it an end of stream ? 
+ if (new_token_id == llama_token_eos()) { + // fprintf(stderr, " [end of text]\n"); + break; + } + + generated_tokens.push_back(llama_token_to_str(ctx, new_token_id)); + + // Print the new token : + std::cout << llama_token_to_str(ctx, new_token_id) << std::endl; + + // Push this new token for next evaluation : + tokens_list.push_back(new_token_id); + + } // wend of main loop + + torch::Tensor tokens_tensor = + torch::from_blob(tokens_list.data(), + {static_cast(tokens_list.size())}, torch::kInt64); + +} + +} // namespace llm + +#if defined(__linux__) || defined(__APPLE__) +extern "C" { +torchserve::torchscripted::BaseHandler* allocatorLlmHandler() { + return new llm::LlmHandler(); +} + +void deleterLlmHandler(torchserve::torchscripted::BaseHandler* p) { + if (p != nullptr) { + delete static_cast(p); + } +} +} +#endif diff --git a/cpp/src/examples/llm/llm_handler.hh b/cpp/src/examples/llm/llm_handler.hh new file mode 100644 index 0000000000..282aae0fe2 --- /dev/null +++ b/cpp/src/examples/llm/llm_handler.hh @@ -0,0 +1,36 @@ +#ifndef LLM_HANDLER_HH_ +#define LLM_HANDLER_HH_ + +#include "src/backends/torch_scripted/handler/base_handler.hh" + +namespace llm { +class LlmHandler : public torchserve::torchscripted::BaseHandler { + public: + // NOLINTBEGIN(bugprone-exception-escape) + LlmHandler() = default; + // NOLINTEND(bugprone-exception-escape) + ~LlmHandler() override = default; + + std::vector Preprocess( + std::shared_ptr& device, + std::pair&>& idx_to_req_id, + std::shared_ptr& request_batch, + std::shared_ptr& response_batch) + override; + + // torch::Tensor Inference( + // std::shared_ptr model, + // std::vector& inputs, + // std::shared_ptr& device, + // std::pair&>& + // idx_to_req_id, std::shared_ptr& + // response_batch) override; + + // void Postprocess( + // const torch::Tensor& data, + // std::pair&>& + // idx_to_req_id, std::shared_ptr& + // response_batch) override; +}; +} // namespace llm +#endif // LLM_HANDLER_HH_ diff --git a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc index b3099d1a2a..131893da2d 100644 --- a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc +++ b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc @@ -78,6 +78,15 @@ TEST_F(TorchScriptedBackendTest, TestLoadPredictMnistHandler) { "mnist_ts", 200); } +TEST_F(TorchScriptedBackendTest, TestLoadPredictLlmHandler) { + this->LoadPredict(std::make_shared( + "test/resources/torchscript_model/llm/llm_handler", + "llm", -1, "", "", 1, false), + "test/resources/torchscript_model/llm/llm_handler", + "test/resources/torchscript_model/llm/llm_handler/prompt.txt", + "llm_ts", 200); +} + TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) { auto result = backend_->Initialize("test/resources/torchscript_model/mnist"); ASSERT_EQ(result, false); diff --git a/cpp/test/resources/torchscript_model/llm/llm_handler/prompt.txt b/cpp/test/resources/torchscript_model/llm/llm_handler/prompt.txt new file mode 100644 index 0000000000..6e3c30c691 --- /dev/null +++ b/cpp/test/resources/torchscript_model/llm/llm_handler/prompt.txt @@ -0,0 +1 @@ +Hello my name is \ No newline at end of file From 8addde30a4aaad4b53e169b210ba1183eb983a92 Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Fri, 18 Aug 2023 01:00:55 +0530 Subject: [PATCH 02/11] Updating llm handler - loadmodel, preprocess, inference methods Signed-off-by: Shrinath Suresh --- cpp/src/examples/llm/llm_handler.cc | 174 +++++++++++++++++++--------- 
cpp/src/examples/llm/llm_handler.hh | 25 ++-- 2 files changed, 135 insertions(+), 64 deletions(-) diff --git a/cpp/src/examples/llm/llm_handler.cc b/cpp/src/examples/llm/llm_handler.cc index ace6fa2a57..e626f11baa 100644 --- a/cpp/src/examples/llm/llm_handler.cc +++ b/cpp/src/examples/llm/llm_handler.cc @@ -11,80 +11,141 @@ namespace llm { +std::pair, + std::shared_ptr> +LlmHandler::LoadModel( + std::shared_ptr& load_model_request) { + try { + auto device = GetTorchDevice(load_model_request); + // Load dummy model + auto module = std::make_shared( + torch::jit::load(fmt::format("{}/{}", load_model_request->model_dir, + manifest_->GetModel().serialized_file), + *device)); + + // Load LLM + gpt_params params; + // TODO: Fetch the path from context + params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin"; + llama_backend_init(params.numa); + std::tie(llamamodel, llama_ctx) = llama_init_from_gpt_params(params); + + + return std::make_pair(module, device); + } catch (const c10::Error& e) { + TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", + load_model_request->model_name, load_model_request->gpu_id, + e.msg()); + throw e; + } catch (const std::runtime_error& e) { + TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}", + load_model_request->model_name, load_model_request->gpu_id, + e.what()); + throw e; + } +} + std::vector LlmHandler::Preprocess( std::shared_ptr& device, std::pair&>& idx_to_req_id, std::shared_ptr& request_batch, std::shared_ptr& response_batch) { - /** - * @brief - * Ref: - * https://github.com/pytorch/serve/blob/master/ts/torch_handler/vision_handler.py#L27 - */ + std::vector batch_ivalue; + std::vector batch_tensors; + + for (auto& request : *request_batch) { + try { + std::vector new_data = request.parameters["data"]; + std::string msg = torchserve::Converter::VectorToStr(new_data); + + // tokenization + + std::vector tokens_list; + tokens_list = ::llama_tokenize(llama_ctx, msg, true); + + // const int max_context_size = llama_n_ctx(ctx); + const int max_context_size = 64; + const int max_tokens_list_size = max_context_size - 4; + + if ((int)tokens_list.size() > max_tokens_list_size) { + std::cout << __func__ << ": error: prompt too long (" + << tokens_list.size() << " tokens, max " + << max_tokens_list_size << ")\n"; + } + + // Print the tokens from the prompt : + std::vector tensor_vector; + for (auto id : tokens_list) { + torch::Tensor tensor = torch::tensor(id, torch::kInt64); + tensor_vector.push_back(tensor); + } + + torch::Tensor stacked_tensor = torch::stack(tensor_vector); + batch_ivalue.push_back(stacked_tensor); + + } catch (const std::runtime_error& e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", + request.request_id, e.what()); + auto response = (*response_batch)[request.request_id]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "runtime_error, failed to load tensor"); + } catch (const c10::Error& e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error: {}", + request.request_id, e.msg()); + auto response = (*response_batch)[request.request_id]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "c10 error, failed to load tensor"); + } + } - // Model Loading - gpt_params params; - params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin"; - params.prompt = "Hello my name is"; - llama_backend_init(params.numa); + return batch_ivalue; +} +torch::Tensor LlmHandler::Inference( 
+ std::shared_ptr model, + std::vector& inputs, + std::shared_ptr& device, + std::pair&>& idx_to_req_id, + std::shared_ptr& response_batch) { + auto tokens_list_tensor = inputs[0].toTensor(); - llama_model* model; - llama_context* ctx; + int64_t num_elements = tokens_list_tensor.numel(); - std::tie(model, ctx) = llama_init_from_gpt_params(params); + // Convert the tensor to a vector of long values + std::vector long_vector; + long_vector.reserve(num_elements); - if (model == NULL) { - std::cout << "<<<<<<<<<<<(); + for (int64_t i = 0; i < num_elements; ++i) { + long_vector.push_back(data_ptr[i]); } - // Tokenization - std::vector tokens_list; - tokens_list = ::llama_tokenize(ctx, params.prompt, true); - // const int max_context_size = llama_n_ctx(ctx); - const int max_context_size = 64; - const int max_tokens_list_size = max_context_size - 4; - - if ((int)tokens_list.size() > max_tokens_list_size) { - std::cout << __func__ << ": error: prompt too long (" << tokens_list.size() - << " tokens, max " << max_tokens_list_size << ")\n"; - } - - // Print the tokens from the prompt : - - for (auto id : tokens_list) { - std::cout << llama_token_to_str(ctx, id) << std::endl; + for (auto id : long_vector) { + tokens_list.push_back(id); } - // Prediction loop - std::vector generated_tokens; + gpt_params params; - while (llama_get_kv_cache_token_count(ctx) < max_context_size) { - //--------------------------------- - // Evaluate the tokens : - //--------------------------------- + const int max_context_size = 64; + + while (llama_get_kv_cache_token_count(llama_ctx) < max_context_size) { - if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), - llama_get_kv_cache_token_count(ctx), params.n_threads)) { + if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()), + llama_get_kv_cache_token_count(llama_ctx), + params.n_threads)) { std::cout << "Evaluation Failed" << __func__ << std::endl; - // return 1; // TODO: Raise exception here } - // tokens_list.clear(); - - //--------------------------------- - // Select the best prediction : - //--------------------------------- - llama_token new_token_id = 0; - auto logits = llama_get_logits(ctx); - auto n_vocab = - llama_n_vocab(ctx); // the size of the LLM vocabulary (in tokens) + auto logits = llama_get_logits(llama_ctx); + auto n_vocab = llama_n_vocab(llama_ctx); std::vector candidates; candidates.reserve(n_vocab); @@ -97,31 +158,30 @@ std::vector LlmHandler::Preprocess( llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; - // Select it using the "Greedy sampling" method : - new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p); - // is it an end of stream ? 
if (new_token_id == llama_token_eos()) { - // fprintf(stderr, " [end of text]\n"); break; } - generated_tokens.push_back(llama_token_to_str(ctx, new_token_id)); + generated_tokens.push_back(llama_token_to_str(llama_ctx, new_token_id)); // Print the new token : - std::cout << llama_token_to_str(ctx, new_token_id) << std::endl; + std::cout << llama_token_to_str(llama_ctx, new_token_id) << std::endl; // Push this new token for next evaluation : tokens_list.push_back(new_token_id); } // wend of main loop - torch::Tensor tokens_tensor = + torch::Tensor inference_result = torch::from_blob(tokens_list.data(), - {static_cast(tokens_list.size())}, torch::kInt64); + {static_cast(tokens_list.size())}, torch::kInt32); + return inference_result; } + } // namespace llm #if defined(__linux__) || defined(__APPLE__) diff --git a/cpp/src/examples/llm/llm_handler.hh b/cpp/src/examples/llm/llm_handler.hh index 282aae0fe2..2bab9849c7 100644 --- a/cpp/src/examples/llm/llm_handler.hh +++ b/cpp/src/examples/llm/llm_handler.hh @@ -1,16 +1,27 @@ #ifndef LLM_HANDLER_HH_ #define LLM_HANDLER_HH_ +#include "examples/common.h" +#include "ggml.h" +#include "llama.h" #include "src/backends/torch_scripted/handler/base_handler.hh" namespace llm { class LlmHandler : public torchserve::torchscripted::BaseHandler { + private: + llama_model* llamamodel; + llama_context* llama_ctx; + public: // NOLINTBEGIN(bugprone-exception-escape) LlmHandler() = default; // NOLINTEND(bugprone-exception-escape) ~LlmHandler() override = default; + virtual std::pair, + std::shared_ptr> + LoadModel(std::shared_ptr& load_model_request); + std::vector Preprocess( std::shared_ptr& device, std::pair&>& idx_to_req_id, @@ -18,13 +29,13 @@ class LlmHandler : public torchserve::torchscripted::BaseHandler { std::shared_ptr& response_batch) override; - // torch::Tensor Inference( - // std::shared_ptr model, - // std::vector& inputs, - // std::shared_ptr& device, - // std::pair&>& - // idx_to_req_id, std::shared_ptr& - // response_batch) override; + torch::Tensor Inference( + std::shared_ptr model, + std::vector& inputs, + std::shared_ptr& device, + std::pair&>& idx_to_req_id, + std::shared_ptr& response_batch) + override; // void Postprocess( // const torch::Tensor& data, From 002e2215cd02b1785c9973269988fc52221af9d3 Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Mon, 21 Aug 2023 15:37:58 +0530 Subject: [PATCH 03/11] Fixed infinite lock by adding request ids to the preprocess method Signed-off-by: Shrinath Suresh --- cpp/src/examples/llm/llm_handler.cc | 155 ++++++++++++++++++++++++---- cpp/src/examples/llm/llm_handler.hh | 13 ++- 2 files changed, 141 insertions(+), 27 deletions(-) diff --git a/cpp/src/examples/llm/llm_handler.cc b/cpp/src/examples/llm/llm_handler.cc index e626f11baa..09270341cc 100644 --- a/cpp/src/examples/llm/llm_handler.cc +++ b/cpp/src/examples/llm/llm_handler.cc @@ -1,4 +1,4 @@ -#include "src/examples/image_classifier/llm/llm_handler.hh" +nclude "src/examples/image_classifier/llm/llm_handler.hh" #include #include @@ -11,6 +11,29 @@ namespace llm { +void LlmHandler::initialize_context() { + // gpt_params params; + params.seed = 42; + params.n_threads = 4; + params.repeat_last_n = 64; + + auto lparams = llama_context_default_params(); + lparams.n_ctx = params.n_ctx; + lparams.n_gqa = params.n_gqa; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + + llama_ctx = llama_new_context_with_model(llamamodel, lparams); + + if (llama_ctx == 
nullptr) { + std::cerr << "Failed to initialize llama context" << std::endl; + } else { + std::cout << "Context initialized successfully" << std::endl; + } +} + std::pair, std::shared_ptr> LlmHandler::LoadModel( @@ -23,13 +46,24 @@ LlmHandler::LoadModel( manifest_->GetModel().serialized_file), *device)); - // Load LLM - gpt_params params; - // TODO: Fetch the path from context params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin"; - llama_backend_init(params.numa); - std::tie(llamamodel, llama_ctx) = llama_init_from_gpt_params(params); - + auto lparams = llama_context_default_params(); + lparams.n_ctx = params.n_ctx; + lparams.n_gqa = params.n_gqa; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + llamamodel = llama_load_model_from_file(params.model.c_str(), lparams); + // llama_ctx = llama_new_context_with_model(llamamodel, lparams); + // initialize_context(); + + // // Load LLM + // gpt_params params; + // // TODO: Fetch the path from context + // params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin"; + // llama_backend_init(params.numa); + // std::tie(llamamodel, llama_ctx) = llama_init_from_gpt_params(params); return std::make_pair(module, device); } catch (const c10::Error& e) { @@ -50,13 +84,48 @@ std::vector LlmHandler::Preprocess( std::pair&>& idx_to_req_id, std::shared_ptr& request_batch, std::shared_ptr& response_batch) { + std::cout << "Initializing llama context" << std::endl; + + initialize_context(); + + std::cout << "Llama context initialized" << std::endl; + std::vector batch_ivalue; std::vector batch_tensors; - + uint8_t idx = 0; for (auto& request : *request_batch) { try { - std::vector new_data = request.parameters["data"]; - std::string msg = torchserve::Converter::VectorToStr(new_data); + (*response_batch)[request.request_id] = + std::make_shared(request.request_id); + idx_to_req_id.first += idx_to_req_id.first.empty() + ? 
request.request_id + : "," + request.request_id; + + auto data_it = request.parameters.find( + torchserve::PayloadType::kPARAMETER_NAME_DATA); + auto dtype_it = + request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE); + if (data_it == request.parameters.end()) { + data_it = request.parameters.find( + torchserve::PayloadType::kPARAMETER_NAME_BODY); + dtype_it = request.headers.find( + torchserve::PayloadType::kHEADER_NAME_BODY_TYPE); + } + + if (data_it == request.parameters.end() || + dtype_it == request.headers.end()) { + TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id); + (*response_batch)[request.request_id]->SetResponse( + 500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT, + "Empty payload"); + continue; + } + + std::cout << "Received Input: " << data_it->second << std::endl; + + // std::vector new_data = request.parameters["data"]; + // std::string msg = torchserve::Converter::VectorToStr(new_data); + std::string msg = torchserve::Converter::VectorToStr(data_it->second); // tokenization @@ -82,6 +151,7 @@ std::vector LlmHandler::Preprocess( torch::Tensor stacked_tensor = torch::stack(tensor_vector); batch_ivalue.push_back(stacked_tensor); + idx_to_req_id.second[idx++] = request.request_id; } catch (const std::runtime_error& e) { TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", @@ -128,13 +198,11 @@ torch::Tensor LlmHandler::Inference( tokens_list.push_back(id); } - std::vector generated_tokens; - gpt_params params; + // gpt_params params; const int max_context_size = 64; while (llama_get_kv_cache_token_count(llama_ctx) < max_context_size) { - if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(llama_ctx), params.n_threads)) { @@ -164,23 +232,66 @@ torch::Tensor LlmHandler::Inference( break; } - generated_tokens.push_back(llama_token_to_str(llama_ctx, new_token_id)); - - // Print the new token : - std::cout << llama_token_to_str(llama_ctx, new_token_id) << std::endl; + std::cout << "New Token: " << llama_token_to_str(llama_ctx, new_token_id); // Push this new token for next evaluation : tokens_list.push_back(new_token_id); + } - } // wend of main loop + std::vector tensor_vector; + for (auto id : tokens_list) { + torch::Tensor tensor = torch::tensor(id, torch::kLong); + tensor_vector.push_back(tensor); + } - torch::Tensor inference_result = - torch::from_blob(tokens_list.data(), - {static_cast(tokens_list.size())}, torch::kInt32); + torch::Tensor stacked_tensor = torch::stack(tensor_vector); - return inference_result; + llama_free(llama_ctx); + return stacked_tensor; } +void LlmHandler::Postprocess( + const torch::Tensor& data, + std::pair&>& idx_to_req_id, + std::shared_ptr& response_batch) { + for (const auto& kv : idx_to_req_id.second) { + try { + int64_t num_elements = data.numel(); + + // Convert the tensor to a vector of long values + std::stringstream generated_text_stream; + + auto data_ptr = data.data_ptr(); + for (int64_t i = 0; i < num_elements; ++i) { + generated_text_stream << llama_token_to_str(llama_ctx, data_ptr[i]); + } + + std::string generated_text_str = generated_text_stream.str(); + std::cout << "Generated Text Str: " << generated_text_str << std::endl; + + auto response = (*response_batch)[kv.second]; + + response->SetResponse(200, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + generated_text_str); + } catch (const std::runtime_error& e) { + TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}", + kv.second, 
e.what()); + auto response = (*response_batch)[kv.second]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "runtime_error, failed to postprocess tensor"); + } catch (const c10::Error& e) { + TS_LOGF(ERROR, + "Failed to postprocess tensor for request id: {}, error: {}", + kv.second, e.msg()); + auto response = (*response_batch)[kv.second]; + response->SetResponse(500, "data_type", + torchserve::PayloadType::kDATA_TYPE_STRING, + "c10 error, failed to postprocess tensor"); + } + } +} } // namespace llm diff --git a/cpp/src/examples/llm/llm_handler.hh b/cpp/src/examples/llm/llm_handler.hh index 2bab9849c7..288c7ef89b 100644 --- a/cpp/src/examples/llm/llm_handler.hh +++ b/cpp/src/examples/llm/llm_handler.hh @@ -9,6 +9,7 @@ namespace llm { class LlmHandler : public torchserve::torchscripted::BaseHandler { private: + gpt_params params; llama_model* llamamodel; llama_context* llama_ctx; @@ -18,6 +19,8 @@ class LlmHandler : public torchserve::torchscripted::BaseHandler { // NOLINTEND(bugprone-exception-escape) ~LlmHandler() override = default; + void initialize_context(); + virtual std::pair, std::shared_ptr> LoadModel(std::shared_ptr& load_model_request); @@ -37,11 +40,11 @@ class LlmHandler : public torchserve::torchscripted::BaseHandler { std::shared_ptr& response_batch) override; - // void Postprocess( - // const torch::Tensor& data, - // std::pair&>& - // idx_to_req_id, std::shared_ptr& - // response_batch) override; + void Postprocess( + const torch::Tensor& data, + std::pair&>& idx_to_req_id, + std::shared_ptr& response_batch) + override; }; } // namespace llm #endif // LLM_HANDLER_HH_ From f351d1d3c961ad7002446677f2e6bbd7ab9ac2b1 Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Mon, 21 Aug 2023 22:55:41 +0530 Subject: [PATCH 04/11] Adding test script for finding tokens per second llama-7b-chat and ggml version Signed-off-by: Shrinath Suresh --- .../llm/Llama-2-7b-chat-ggml-hf.py | 25 +++++++++++ .../llm/Llama-2-7b-chat-hf.py | 42 +++++++++++++++++++ 2 files changed, 67 insertions(+) create mode 100644 cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py create mode 100644 cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py diff --git a/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py b/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py new file mode 100644 index 0000000000..6e581d649a --- /dev/null +++ b/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py @@ -0,0 +1,25 @@ +import requests +import json +import time + +def send_text_file(url, file_path): + with open(file_path, 'rb') as fp: + file_bytes = fp.read() + + start_time = time.time() + response = requests.post(url, data=file_bytes) + time_taken = time.time() - start_time + generated_answer = response.text + print("Generated Anser: ", generated_answer) + number_of_tokens = len(generated_answer.split(' ')) + print("Number of tokens: ", number_of_tokens) + print("Time taken: ", time_taken) + print("Tokens per second:", number_of_tokens / int(time_taken)) + + +if __name__ == "__main__": + url = "http://localhost:8080/predictions/llm" + file_path = "llm_handler/prompt.txt" + + send_text_file(url, file_path) + diff --git a/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py b/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py new file mode 100644 index 0000000000..f2cda37961 --- /dev/null +++ b/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py @@ -0,0 +1,42 @@ +from transformers 
import AutoTokenizer +import transformers +import torch +import time + +model = "meta-llama/Llama-2-7b-chat-hf" +hf_api_key = "" + +tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=hf_api_key) +pipeline = transformers.pipeline( + "text-generation", + model=model, + torch_dtype=torch.float16, + device_map="auto", + use_auth_token=hf_api_key +) + +start_time = time.time() +sequences = pipeline( + 'Hello my name is\n', + do_sample=True, + top_k=10, + num_return_sequences=1, + eos_token_id=tokenizer.eos_token_id, + max_length=512, +) +result = "" +for seq in sequences: + result += seq['generated_text'] + print(f"Result: {seq['generated_text']}") +time_taken = time.time() - start_time + +print("Generated String:", result) +print("Total time taken:", time_taken) + +num_words = len(result.split(' ')) + +print("Total words generated: ", num_words) + +tokens_per_second = num_words / int(time_taken) + +print("Tokens per second: ", tokens_per_second) From e3a753c658b7489ca8f8a979ac7a857123f6d5ae Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Wed, 13 Sep 2023 09:52:32 +0530 Subject: [PATCH 05/11] GGUF Compatibility Signed-off-by: Shrinath Suresh --- cpp/build.sh | 4 +- cpp/src/examples/CMakeLists.txt | 14 +-- .../llamacpp_handler.cc} | 94 +++++++------------ .../llamacpp_handler.hh} | 10 +- .../torch_scripted_backend_test.cc | 13 +-- 5 files changed, 58 insertions(+), 77 deletions(-) rename cpp/src/examples/{llm/llm_handler.cc => llamacpp/llamacpp_handler.cc} (79%) rename cpp/src/examples/{llm/llm_handler.hh => llamacpp/llamacpp_handler.hh} (86%) diff --git a/cpp/build.sh b/cpp/build.sh index bd08f7c4a4..2a962d7a9e 100755 --- a/cpp/build.sh +++ b/cpp/build.sh @@ -299,8 +299,8 @@ function build() { mv $DEPS_DIR/../src/examples/libmnist_handler.so $DEPS_DIR/../../test/resources/torchscript_model/mnist/mnist_handler/libmnist_handler.so fi - if [ -f "$DEPS_DIR/../src/examples/libllm_handler.so" ]; then - mv $DEPS_DIR/../src/examples/libllm_handler.so $DEPS_DIR/../../test/resources/torchscript_model/llm/llm_handler/libllm_handler.so + if [ -f "$DEPS_DIR/../src/examples/libllamacpp_handler.so" ]; then + mv $DEPS_DIR/../src/examples/libllamacpp_handler.so $DEPS_DIR/../../test/resources/torchscript_model/llamacpp/llamacpp_handler/libllamacpp_handler.so fi cd $DEPS_DIR/../.. 
diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt index 66d48ee066..6f8441d190 100644 --- a/cpp/src/examples/CMakeLists.txt +++ b/cpp/src/examples/CMakeLists.txt @@ -6,14 +6,14 @@ add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES}) target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR}) target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) -set(LLM_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/llm") +set(LLM_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/llamacpp") set(LLAMACPP_SRC_DIR "/home/ubuntu/llama.cpp") set(LLM_SOURCE_FILES "") -list(APPEND LLM_SOURCE_FILES ${LLM_SRC_DIR}/llm_handler.cc) -add_library(llm_handler SHARED ${LLM_SOURCE_FILES}) -target_include_directories(llm_handler PUBLIC ${LLM_SRC_DIR}) -target_include_directories(llm_handler PUBLIC ${LLAMACPP_SRC_DIR}) -target_link_libraries(llm_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) +list(APPEND LLM_SOURCE_FILES ${LLM_SRC_DIR}/llamacpp_handler.cc) +add_library(llamacpp_handler SHARED ${LLM_SOURCE_FILES}) +target_include_directories(llamacpp_handler PUBLIC ${LLM_SRC_DIR}) +target_include_directories(llamacpp_handler PUBLIC ${LLAMACPP_SRC_DIR}) +target_link_libraries(llamacpp_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) set(MY_OBJECT_FILES @@ -27,4 +27,4 @@ set(MY_OBJECT_FILES ) -target_sources(llm_handler PRIVATE ${MY_OBJECT_FILES}) +target_sources(llamacpp_handler PRIVATE ${MY_OBJECT_FILES}) diff --git a/cpp/src/examples/llm/llm_handler.cc b/cpp/src/examples/llamacpp/llamacpp_handler.cc similarity index 79% rename from cpp/src/examples/llm/llm_handler.cc rename to cpp/src/examples/llamacpp/llamacpp_handler.cc index 09270341cc..fe1a2f7bf0 100644 --- a/cpp/src/examples/llm/llm_handler.cc +++ b/cpp/src/examples/llamacpp/llamacpp_handler.cc @@ -1,31 +1,14 @@ -nclude "src/examples/image_classifier/llm/llm_handler.hh" +#include "src/examples/llamacpp/llamacpp_handler.hh" #include #include #include -#include "examples/common.h" -#include "ggml.h" -#include "llama.h" - namespace llm { -void LlmHandler::initialize_context() { - // gpt_params params; - params.seed = 42; - params.n_threads = 4; - params.repeat_last_n = 64; - - auto lparams = llama_context_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_gqa = params.n_gqa; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - - llama_ctx = llama_new_context_with_model(llamamodel, lparams); +void LlamacppHandler::initialize_context() { + llama_ctx = llama_new_context_with_model(llamamodel, ctx_params); if (llama_ctx == nullptr) { std::cerr << "Failed to initialize llama context" << std::endl; @@ -36,7 +19,7 @@ void LlmHandler::initialize_context() { std::pair, std::shared_ptr> -LlmHandler::LoadModel( +LlamacppHandler::LoadModel( std::shared_ptr& load_model_request) { try { auto device = GetTorchDevice(load_model_request); @@ -46,24 +29,13 @@ LlmHandler::LoadModel( manifest_->GetModel().serialized_file), *device)); - params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin"; - auto lparams = llama_context_default_params(); - lparams.n_ctx = params.n_ctx; - lparams.n_gqa = params.n_gqa; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - llamamodel = llama_load_model_from_file(params.model.c_str(), lparams); - // llama_ctx = 
llama_new_context_with_model(llamamodel, lparams); - // initialize_context(); - - // // Load LLM - // gpt_params params; - // // TODO: Fetch the path from context - // params.model = "/home/ubuntu/serve/cpp/llama-2-7b-chat.ggmlv3.q4_0.bin"; - // llama_backend_init(params.numa); - // std::tie(llamamodel, llama_ctx) = llama_init_from_gpt_params(params); + params.model = "/home/ubuntu/gpu/llama.cpp/llama-2-7b-chat.Q4_0.gguf"; + params.main_gpu = 0; + params.n_gpu_layers = 35; + + llama_backend_init(params.numa); + ctx_params = llama_context_default_params(); + llamamodel = llama_load_model_from_file(params.model.c_str(), ctx_params); return std::make_pair(module, device); } catch (const c10::Error& e) { @@ -79,7 +51,7 @@ LlmHandler::LoadModel( } } -std::vector LlmHandler::Preprocess( +std::vector LlamacppHandler::Preprocess( std::shared_ptr& device, std::pair&>& idx_to_req_id, std::shared_ptr& request_batch, @@ -133,7 +105,6 @@ std::vector LlmHandler::Preprocess( tokens_list = ::llama_tokenize(llama_ctx, msg, true); // const int max_context_size = llama_n_ctx(ctx); - const int max_context_size = 64; const int max_tokens_list_size = max_context_size - 4; if ((int)tokens_list.size() > max_tokens_list_size) { @@ -173,7 +144,7 @@ std::vector LlmHandler::Preprocess( return batch_ivalue; } -torch::Tensor LlmHandler::Inference( +torch::Tensor LlamacppHandler::Inference( std::shared_ptr model, std::vector& inputs, std::shared_ptr& device, @@ -197,19 +168,22 @@ torch::Tensor LlmHandler::Inference( for (auto id : long_vector) { tokens_list.push_back(id); } + const int n_gen = std::min(32, max_context_size); - // gpt_params params; - - const int max_context_size = 64; + while (llama_get_kv_cache_token_count(llama_ctx) < n_gen) { + // evaluate the transformer - while (llama_get_kv_cache_token_count(llama_ctx) < max_context_size) { if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(llama_ctx), params.n_threads)) { - std::cout << "Evaluation Failed" << __func__ << std::endl; - // TODO: Raise exception here + std::cout << "Failed to eval\n" << __func__ << std::endl; + break; } + tokens_list.clear(); + + // sample the next token + llama_token new_token_id = 0; auto logits = llama_get_logits(llama_ctx); @@ -228,13 +202,17 @@ torch::Tensor LlmHandler::Inference( new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p); - if (new_token_id == llama_token_eos()) { + // is it an end of stream ? 
+ if (new_token_id == llama_token_eos(llama_ctx)) { + std::cout << "Reached [end of text]\n"; break; } - std::cout << "New Token: " << llama_token_to_str(llama_ctx, new_token_id); + // print the new token : + std::cout << "New Token: " << llama_token_to_piece(llama_ctx, new_token_id) + << std::endl; - // Push this new token for next evaluation : + // push this new token for next evaluation tokens_list.push_back(new_token_id); } @@ -245,12 +223,12 @@ torch::Tensor LlmHandler::Inference( } torch::Tensor stacked_tensor = torch::stack(tensor_vector); - + llama_print_timings(llama_ctx); llama_free(llama_ctx); return stacked_tensor; } -void LlmHandler::Postprocess( +void LlamacppHandler::Postprocess( const torch::Tensor& data, std::pair&>& idx_to_req_id, std::shared_ptr& response_batch) { @@ -263,7 +241,7 @@ void LlmHandler::Postprocess( auto data_ptr = data.data_ptr(); for (int64_t i = 0; i < num_elements; ++i) { - generated_text_stream << llama_token_to_str(llama_ctx, data_ptr[i]); + generated_text_stream << llama_token_to_piece(llama_ctx, data_ptr[i]); } std::string generated_text_str = generated_text_stream.str(); @@ -297,13 +275,13 @@ void LlmHandler::Postprocess( #if defined(__linux__) || defined(__APPLE__) extern "C" { -torchserve::torchscripted::BaseHandler* allocatorLlmHandler() { - return new llm::LlmHandler(); +torchserve::torchscripted::BaseHandler* allocatorLlamacppHandler() { + return new llm::LlamacppHandler(); } -void deleterLlmHandler(torchserve::torchscripted::BaseHandler* p) { +void deleterLlamacppHandler(torchserve::torchscripted::BaseHandler* p) { if (p != nullptr) { - delete static_cast(p); + delete static_cast(p); } } } diff --git a/cpp/src/examples/llm/llm_handler.hh b/cpp/src/examples/llamacpp/llamacpp_handler.hh similarity index 86% rename from cpp/src/examples/llm/llm_handler.hh rename to cpp/src/examples/llamacpp/llamacpp_handler.hh index 288c7ef89b..43e77826ac 100644 --- a/cpp/src/examples/llm/llm_handler.hh +++ b/cpp/src/examples/llamacpp/llamacpp_handler.hh @@ -1,23 +1,25 @@ #ifndef LLM_HANDLER_HH_ #define LLM_HANDLER_HH_ -#include "examples/common.h" +#include "common/common.h" #include "ggml.h" #include "llama.h" #include "src/backends/torch_scripted/handler/base_handler.hh" namespace llm { -class LlmHandler : public torchserve::torchscripted::BaseHandler { +class LlamacppHandler : public torchserve::torchscripted::BaseHandler { private: gpt_params params; llama_model* llamamodel; + llama_context_params ctx_params; llama_context* llama_ctx; + const int max_context_size = 32; public: // NOLINTBEGIN(bugprone-exception-escape) - LlmHandler() = default; + LlamacppHandler() = default; // NOLINTEND(bugprone-exception-escape) - ~LlmHandler() override = default; + ~LlamacppHandler() override = default; void initialize_context(); diff --git a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc index 131893da2d..b18a74fb84 100644 --- a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc +++ b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc @@ -79,12 +79,13 @@ TEST_F(TorchScriptedBackendTest, TestLoadPredictMnistHandler) { } TEST_F(TorchScriptedBackendTest, TestLoadPredictLlmHandler) { - this->LoadPredict(std::make_shared( - "test/resources/torchscript_model/llm/llm_handler", - "llm", -1, "", "", 1, false), - "test/resources/torchscript_model/llm/llm_handler", - "test/resources/torchscript_model/llm/llm_handler/prompt.txt", - "llm_ts", 200); + this->LoadPredict( + 
std::make_shared( + "test/resources/torchscript_model/llamacpp/llamacpp_handler", "llm", + -1, "", "", 1, false), + "test/resources/torchscript_model/llamacpp/llamacpp_handler", + "test/resources/torchscript_model/llamacpp/sentences.json", "llm_ts", + 200); } TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) { From c4ab8a15e058fe54c3ae608f518c0f0d9848bdd4 Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Wed, 13 Sep 2023 10:00:44 +0530 Subject: [PATCH 06/11] Fixing unit tests Signed-off-by: Shrinath Suresh --- .../torch_scripted_backend_test.cc | 2 +- .../llamacpp_handler/MAR-INF/MANIFEST.json | 11 +++++ .../llamacpp/llamacpp_handler/dummy.pt | Bin 0 -> 3520 bytes .../{llm/llm_handler => llamacpp}/prompt.txt | 0 .../llm/Llama-2-7b-chat-ggml-hf.py | 25 ----------- .../llm/Llama-2-7b-chat-hf.py | 42 ------------------ 6 files changed, 12 insertions(+), 68 deletions(-) create mode 100644 cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/MAR-INF/MANIFEST.json create mode 100644 cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/dummy.pt rename cpp/test/resources/torchscript_model/{llm/llm_handler => llamacpp}/prompt.txt (100%) delete mode 100644 cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py delete mode 100644 cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py diff --git a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc index b18a74fb84..e841c57ea1 100644 --- a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc +++ b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc @@ -84,7 +84,7 @@ TEST_F(TorchScriptedBackendTest, TestLoadPredictLlmHandler) { "test/resources/torchscript_model/llamacpp/llamacpp_handler", "llm", -1, "", "", 1, false), "test/resources/torchscript_model/llamacpp/llamacpp_handler", - "test/resources/torchscript_model/llamacpp/sentences.json", "llm_ts", + "test/resources/torchscript_model/llamacpp/prompt.txt", "llm_ts", 200); } diff --git a/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/MAR-INF/MANIFEST.json new file mode 100644 index 0000000000..79c37e08ba --- /dev/null +++ b/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/MAR-INF/MANIFEST.json @@ -0,0 +1,11 @@ +{ + "createdOn": "28/07/2020 06:32:08", + "runtime": "LSP", + "model": { + "modelName": "llamacpp", + "serializedFile": "dummy.pt", + "handler": "libllamacpp_handler:LlamacppHandler", + "modelVersion": "2.0" + }, + "archiverVersion": "0.2.0" +} \ No newline at end of file diff --git a/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/dummy.pt b/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/dummy.pt new file mode 100644 index 0000000000000000000000000000000000000000..2f4058d69ca4556bdee3f19929748f85b6edac60 GIT binary patch literal 3520 zcmbVO3se(V8lF6OO)zLcLfJ||tpp!QP3g7sM&yShG3_ueGH1i@YRo^x|2Gxz`g z```cZ{i!k*8$mohk>`dt;*V%7d3gmQjgnG|#8L?lS^3Q65%{gbAnp0&Sp4Gu_`vpJ`oT)F+8${O|E|mn1p?ByY zKipO~`(+QtLHzm~Yn9LtILF85B5%;}h55O9sU-RZ8iq8Dy;_-S45N%DwGPA3bRsWb zk0%>7xE>{uu#!T70-l93DG38%(2|Tz3V-k)Bg|NiMX$$HN_FnrN|OfD8I8FlGeyA! zUvtzXYXnnh#}*n4!aSqKqQ}ibJz>C=CN#;hkZijecOdvSPH1%$>85swU>1tdn^DY! 
zt1N_GgHgD_Y&2o(=Axv#K%npd5@xVdsl`bThX^Xw8JG?>xlCJk=FHH!!}g_7Q3a!O84P=E+S zRD{w@a&<+i;Cod=_`Z|o*b1eCBUKJ zgbwxtfeSC#b?_0M)4^)s3DK~7g!u(h33rsI-&)?ewpHxi{4k74+GAs}8$)|~lubKA zUR|;>dGf5&+*{ZG9yzO?zidPC_R>c+SLc1Qv#UusePLZwSM(HQVwAC{R`R<^SJ{L2 zD$mE-d^0*u`Car$5+8iiw{7c*ioV;ezg+&Vlz%Drp9#l!5d|?(N}FlFK$XVfFS^_$ zJ5%?p^7ix@KM(w{{@~-X07V_n{r!8db@~*)deE;g@+*Grp~#kVw|&Za6TIdGuz#BKO2*FR5F}!}D-miyd}uLy8ibF!oUudR z<*k*e2mtD(%jL<58Cd4BjHOt5hJ0Dd5-cN8lC+pa6wo;&ql{$A5NBkrB`6`5PpEVC zIO>Q^!a(T>Gl&?LAX|V_av35+%1O3;0Z(P5Te&>|j@Gi5;w*_zN}=j{DX03-wnfc7$M@g4`9)ACRT!}A7YWsAYmJLR z+jpaiUTyuF&76zU2kTCUPQON6Sb654BsI#5x1)O1f%aQAqjm5}CAX;PbJGW#S2dKY z4!k+%ntx7JZYZn0-ehe&dY9P!to@;?`_5_@Rat|_r+y9Yc=ZbtEe7#;*?Xo(8w=ie z-S$pdt2m(jVuR?`spR`N$RK^>u^(qlP5DXvx0!l3)&f=cir-X!mb`QBI$L+b+2Zqu z+Kx1peNf$Bxc!RlS-_RJ^-8XQ5h1GpUUhw><`_Koa z4V)_C+Bvp-_qQeYx+44aXlv~Q=9hU7)^zNBH&FhrG~xSzv%yE>Zam2kiuTIn*Zf#UmAu_exr(zEd7)3Dm5vd+x|byxPO5+aHi(2I}03Ag}j z_;gUe$LB*|By`w|1~a8JP-X`yPWxXn!paP@XJ($rD-)J^I32p|Gz|XH$K>J61e~bG zH5g&Yw@_x-T>PK;H@;nbgyXc!MVG@ik0ILte>>=sUX)BYWJVmL3i82^QVYl&P&s7G z9HR*mHp$Q-*@FL0w9HXN%Nmnt*-oN$8zx$Jz^-Ye2V_OYFwt%vpJ+YpMC*ksyucZg zKqbveG!$f$=tj9k;RjwpI?ggvrrE*6AG1l$MmgzicV3lOKm-E>3Y_C#=mg+gM?<&{ zH<0J|JopKB~eAAn9&Aub>tAQ{I287YpV&qS}sA{V$u zkZ@zcy@2aRgVW1Ij0^Zw&{n=A_-GqO;(%Tud|lMs0GTzUhE{=w?sV)zLPy8|O`rq7 zCDldZA&8+aksj!-zCAfM#7KaQ7(x&B3G`zilhC1aI5Q+w=I%b*xh04-XLvKm`82e8 j!wwq^IRJnI$iZrZtajpqyaeyuV3!Z78V7;<9ee)+!Vv`C literal 0 HcmV?d00001 diff --git a/cpp/test/resources/torchscript_model/llm/llm_handler/prompt.txt b/cpp/test/resources/torchscript_model/llamacpp/prompt.txt similarity index 100% rename from cpp/test/resources/torchscript_model/llm/llm_handler/prompt.txt rename to cpp/test/resources/torchscript_model/llamacpp/prompt.txt diff --git a/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py b/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py deleted file mode 100644 index 6e581d649a..0000000000 --- a/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-ggml-hf.py +++ /dev/null @@ -1,25 +0,0 @@ -import requests -import json -import time - -def send_text_file(url, file_path): - with open(file_path, 'rb') as fp: - file_bytes = fp.read() - - start_time = time.time() - response = requests.post(url, data=file_bytes) - time_taken = time.time() - start_time - generated_answer = response.text - print("Generated Anser: ", generated_answer) - number_of_tokens = len(generated_answer.split(' ')) - print("Number of tokens: ", number_of_tokens) - print("Time taken: ", time_taken) - print("Tokens per second:", number_of_tokens / int(time_taken)) - - -if __name__ == "__main__": - url = "http://localhost:8080/predictions/llm" - file_path = "llm_handler/prompt.txt" - - send_text_file(url, file_path) - diff --git a/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py b/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py deleted file mode 100644 index f2cda37961..0000000000 --- a/cpp/test/resources/torchscript_model/llm/Llama-2-7b-chat-hf.py +++ /dev/null @@ -1,42 +0,0 @@ -from transformers import AutoTokenizer -import transformers -import torch -import time - -model = "meta-llama/Llama-2-7b-chat-hf" -hf_api_key = "" - -tokenizer = AutoTokenizer.from_pretrained(model, use_auth_token=hf_api_key) -pipeline = transformers.pipeline( - "text-generation", - model=model, - torch_dtype=torch.float16, - device_map="auto", - use_auth_token=hf_api_key -) - -start_time = time.time() -sequences = pipeline( - 'Hello my name is\n', - do_sample=True, - top_k=10, - num_return_sequences=1, - 
eos_token_id=tokenizer.eos_token_id, - max_length=512, -) -result = "" -for seq in sequences: - result += seq['generated_text'] - print(f"Result: {seq['generated_text']}") -time_taken = time.time() - start_time - -print("Generated String:", result) -print("Total time taken:", time_taken) - -num_words = len(result.split(' ')) - -print("Total words generated: ", num_words) - -tokens_per_second = num_words / int(time_taken) - -print("Tokens per second: ", tokens_per_second) From 8cdda73829af5dc384365f40ada33e760773d2a0 Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Wed, 13 Sep 2023 10:03:14 +0530 Subject: [PATCH 07/11] Fix typo Signed-off-by: Shrinath Suresh --- cpp/src/examples/llamacpp/llamacpp_handler.hh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.hh b/cpp/src/examples/llamacpp/llamacpp_handler.hh index 43e77826ac..151ba170c7 100644 --- a/cpp/src/examples/llamacpp/llamacpp_handler.hh +++ b/cpp/src/examples/llamacpp/llamacpp_handler.hh @@ -1,5 +1,5 @@ -#ifndef LLM_HANDLER_HH_ -#define LLM_HANDLER_HH_ +#ifndef LLAMACPP_HANDLER_HH_ +#define LLAMACPP_HANDLER_HH_ #include "common/common.h" #include "ggml.h" @@ -49,4 +49,4 @@ class LlamacppHandler : public torchserve::torchscripted::BaseHandler { override; }; } // namespace llm -#endif // LLM_HANDLER_HH_ +#endif // LLAMACPP_HANDLER_HH_ \ No newline at end of file From 026b836099ac4fa70efa4887db8cf8b2fb0dbe6f Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Wed, 13 Sep 2023 11:01:02 +0530 Subject: [PATCH 08/11] Using folly to read config path Signed-off-by: Shrinath Suresh --- cpp/src/examples/llamacpp/config.json | 5 +++++ cpp/src/examples/llamacpp/llamacpp_handler.cc | 22 ++++++++++++++++++- cpp/src/examples/llamacpp/llamacpp_handler.hh | 3 +++ 3 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 cpp/src/examples/llamacpp/config.json diff --git a/cpp/src/examples/llamacpp/config.json b/cpp/src/examples/llamacpp/config.json new file mode 100644 index 0000000000..6fb87978f3 --- /dev/null +++ b/cpp/src/examples/llamacpp/config.json @@ -0,0 +1,5 @@ +{ +"checkpoint_path" : "/home/ubuntu/llama-2-7b-chat.Q4_0.gguf" +} + + diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.cc b/cpp/src/examples/llamacpp/llamacpp_handler.cc index fe1a2f7bf0..d4c5831d64 100644 --- a/cpp/src/examples/llamacpp/llamacpp_handler.cc +++ b/cpp/src/examples/llamacpp/llamacpp_handler.cc @@ -29,7 +29,27 @@ LlamacppHandler::LoadModel( manifest_->GetModel().serialized_file), *device)); - params.model = "/home/ubuntu/gpu/llama.cpp/llama-2-7b-chat.Q4_0.gguf"; + const std::string configFilePath = + fmt::format("{}/{}", load_model_request->model_dir, "config.json"); + std::string jsonContent; + if (!folly::readFile(configFilePath.c_str(), jsonContent)) { + std::cerr << "config.json not found at: " << configFilePath << std::endl; + throw; + } + folly::dynamic json; + json = folly::parseJson(jsonContent); + + std::string checkpoint_path; + if (json.find("checkpoint_path") != json.items().end()) { + checkpoint_path = json["checkpoint_path"].asString(); + } else { + std::cerr + << "Required field 'checkpoint_path' not found in JSON." 
+ << std::endl; + throw; + } + + params.model = checkpoint_path; params.main_gpu = 0; params.n_gpu_layers = 35; diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.hh b/cpp/src/examples/llamacpp/llamacpp_handler.hh index 151ba170c7..305e1fc062 100644 --- a/cpp/src/examples/llamacpp/llamacpp_handler.hh +++ b/cpp/src/examples/llamacpp/llamacpp_handler.hh @@ -1,6 +1,9 @@ #ifndef LLAMACPP_HANDLER_HH_ #define LLAMACPP_HANDLER_HH_ +#include +#include + #include "common/common.h" #include "ggml.h" #include "llama.h" From c95a5768447f0b7e1dee8338cc42a991d3f012c4 Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Wed, 13 Sep 2023 11:04:55 +0530 Subject: [PATCH 09/11] Removing debug couts Signed-off-by: Shrinath Suresh --- cpp/src/examples/llamacpp/llamacpp_handler.cc | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.cc b/cpp/src/examples/llamacpp/llamacpp_handler.cc index d4c5831d64..75d43963f6 100644 --- a/cpp/src/examples/llamacpp/llamacpp_handler.cc +++ b/cpp/src/examples/llamacpp/llamacpp_handler.cc @@ -76,12 +76,9 @@ std::vector LlamacppHandler::Preprocess( std::pair&>& idx_to_req_id, std::shared_ptr& request_batch, std::shared_ptr& response_batch) { - std::cout << "Initializing llama context" << std::endl; - + initialize_context(); - std::cout << "Llama context initialized" << std::endl; - std::vector batch_ivalue; std::vector batch_tensors; uint8_t idx = 0; @@ -113,10 +110,6 @@ std::vector LlamacppHandler::Preprocess( continue; } - std::cout << "Received Input: " << data_it->second << std::endl; - - // std::vector new_data = request.parameters["data"]; - // std::string msg = torchserve::Converter::VectorToStr(new_data); std::string msg = torchserve::Converter::VectorToStr(data_it->second); // tokenization @@ -228,10 +221,6 @@ torch::Tensor LlamacppHandler::Inference( break; } - // print the new token : - std::cout << "New Token: " << llama_token_to_piece(llama_ctx, new_token_id) - << std::endl; - // push this new token for next evaluation tokens_list.push_back(new_token_id); } @@ -265,7 +254,6 @@ void LlamacppHandler::Postprocess( } std::string generated_text_str = generated_text_stream.str(); - std::cout << "Generated Text Str: " << generated_text_str << std::endl; auto response = (*response_batch)[kv.second]; From eda470f09ae243b6ccccccb94cf0b8ebd159c84b Mon Sep 17 00:00:00 2001 From: Shrinath Suresh Date: Fri, 15 Sep 2023 10:23:19 +0530 Subject: [PATCH 10/11] Processing all the items in the batch Signed-off-by: Shrinath Suresh --- cpp/src/examples/llamacpp/llamacpp_handler.cc | 118 ++++++++++-------- cpp/src/examples/llamacpp/llamacpp_handler.hh | 2 +- 2 files changed, 64 insertions(+), 56 deletions(-) diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.cc b/cpp/src/examples/llamacpp/llamacpp_handler.cc index 75d43963f6..1ad72e05cc 100644 --- a/cpp/src/examples/llamacpp/llamacpp_handler.cc +++ b/cpp/src/examples/llamacpp/llamacpp_handler.cc @@ -43,12 +43,10 @@ LlamacppHandler::LoadModel( if (json.find("checkpoint_path") != json.items().end()) { checkpoint_path = json["checkpoint_path"].asString(); } else { - std::cerr - << "Required field 'checkpoint_path' not found in JSON." - << std::endl; + std::cerr << "Required field 'checkpoint_path' not found in JSON." 
+ << std::endl; throw; } - params.model = checkpoint_path; params.main_gpu = 0; params.n_gpu_layers = 35; @@ -76,7 +74,7 @@ std::vector LlamacppHandler::Preprocess( std::pair&>& idx_to_req_id, std::shared_ptr& request_batch, std::shared_ptr& response_batch) { - + initialize_context(); std::vector batch_ivalue; @@ -163,78 +161,81 @@ torch::Tensor LlamacppHandler::Inference( std::shared_ptr& device, std::pair&>& idx_to_req_id, std::shared_ptr& response_batch) { - auto tokens_list_tensor = inputs[0].toTensor(); + torch::InferenceMode guard; + std::vector batch_output_vector; + for (const torch::jit::IValue& input : inputs) { + torch::Tensor tokens_list_tensor = input.toTensor(); - int64_t num_elements = tokens_list_tensor.numel(); + int64_t num_elements = tokens_list_tensor.numel(); - // Convert the tensor to a vector of long values - std::vector long_vector; - long_vector.reserve(num_elements); + int64_t* data_ptr = tokens_list_tensor.data_ptr(); + std::vector tokens_list; - auto data_ptr = tokens_list_tensor.data_ptr(); - for (int64_t i = 0; i < num_elements; ++i) { - long_vector.push_back(data_ptr[i]); - } + for (int64_t i = 0; i < num_elements; ++i) { + tokens_list.push_back(data_ptr[i]); + } + const int n_gen = std::min(32, max_context_size); - std::vector tokens_list; + long pos = 0; + while (pos < n_gen) { + // evaluate the transformer - for (auto id : long_vector) { - tokens_list.push_back(id); - } - const int n_gen = std::min(32, max_context_size); + if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()), + llama_get_kv_cache_token_count(llama_ctx), + params.n_threads)) { + std::cout << "Failed to eval\n" << __func__ << std::endl; + break; + } - while (llama_get_kv_cache_token_count(llama_ctx) < n_gen) { - // evaluate the transformer + tokens_list.clear(); - if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()), - llama_get_kv_cache_token_count(llama_ctx), - params.n_threads)) { - std::cout << "Failed to eval\n" << __func__ << std::endl; - break; - } + // sample the next token - tokens_list.clear(); + llama_token new_token_id = 0; - // sample the next token + auto logits = llama_get_logits(llama_ctx); + auto n_vocab = llama_n_vocab(llama_ctx); - llama_token new_token_id = 0; + std::vector candidates; + candidates.reserve(n_vocab); - auto logits = llama_get_logits(llama_ctx); - auto n_vocab = llama_n_vocab(llama_ctx); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back( + llama_token_data{token_id, logits[token_id], 0.0f}); + } - std::vector candidates; - candidates.reserve(n_vocab); + llama_token_data_array candidates_p = {candidates.data(), + candidates.size(), false}; - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back( - llama_token_data{token_id, logits[token_id], 0.0f}); - } + new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p); - llama_token_data_array candidates_p = {candidates.data(), candidates.size(), - false}; + // is it an end of stream ? + if (new_token_id == llama_token_eos(llama_ctx)) { + std::cout << "Reached [end of text]\n"; + break; + } - new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p); + // print the new token : + std::cout << "New Token: " + << llama_token_to_piece(llama_ctx, new_token_id) << std::endl; - // is it an end of stream ? 
- if (new_token_id == llama_token_eos(llama_ctx)) { - std::cout << "Reached [end of text]\n"; - break; + // push this new token for next evaluation + tokens_list.push_back(new_token_id); + pos += 1; } - // push this new token for next evaluation - tokens_list.push_back(new_token_id); - } + std::vector tensor_vector; + for (auto id : tokens_list) { + torch::Tensor tensor = torch::tensor(id, torch::kLong); + tensor_vector.push_back(tensor); + } - std::vector tensor_vector; - for (auto id : tokens_list) { - torch::Tensor tensor = torch::tensor(id, torch::kLong); - tensor_vector.push_back(tensor); + torch::Tensor stacked_tensor = torch::stack(tensor_vector); + batch_output_vector.push_back(stacked_tensor); } - torch::Tensor stacked_tensor = torch::stack(tensor_vector); llama_print_timings(llama_ctx); - llama_free(llama_ctx); - return stacked_tensor; + return torch::stack(batch_output_vector); } void LlamacppHandler::Postprocess( @@ -254,6 +255,7 @@ void LlamacppHandler::Postprocess( } std::string generated_text_str = generated_text_stream.str(); + std::cout << "Generated Text Str: " << generated_text_str << std::endl; auto response = (*response_batch)[kv.second]; @@ -279,6 +281,12 @@ void LlamacppHandler::Postprocess( } } +LlamacppHandler::~LlamacppHandler() noexcept { + llama_free(llama_ctx); + llama_free_model(llamamodel); + llama_backend_free(); +} + } // namespace llm #if defined(__linux__) || defined(__APPLE__) diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.hh b/cpp/src/examples/llamacpp/llamacpp_handler.hh index 305e1fc062..54de782fad 100644 --- a/cpp/src/examples/llamacpp/llamacpp_handler.hh +++ b/cpp/src/examples/llamacpp/llamacpp_handler.hh @@ -22,7 +22,7 @@ class LlamacppHandler : public torchserve::torchscripted::BaseHandler { // NOLINTBEGIN(bugprone-exception-escape) LlamacppHandler() = default; // NOLINTEND(bugprone-exception-escape) - ~LlamacppHandler() override = default; + ~LlamacppHandler() noexcept; void initialize_context(); From 28c9b026ed1c6e0b768291b3137dd250b03d4889 Mon Sep 17 00:00:00 2001 From: Matthias Reso <13337103+mreso@users.noreply.github.com> Date: Fri, 12 Jan 2024 23:00:32 +0000 Subject: [PATCH 11/11] Adopted llama.cpp api changes --- .pre-commit-config.yaml | 2 +- cpp/src/examples/CMakeLists.txt | 6 ++++-- cpp/src/examples/llamacpp/llamacpp_handler.cc | 11 +++++------ cpp/src/examples/llamacpp/llamacpp_handler.hh | 3 ++- .../torch_scripted/torch_scripted_backend_test.cc | 3 +-- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a7ee04a103..306b2005e5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,7 +34,7 @@ repos: - id: black additional_dependencies: ['click==8.0.4'] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort args: ["--profile", "black"] diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt index 6f8441d190..b7c97691df 100644 --- a/cpp/src/examples/CMakeLists.txt +++ b/cpp/src/examples/CMakeLists.txt @@ -4,7 +4,7 @@ set(MNIST_SOURCE_FILES "") list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc) add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES}) target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR}) -target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) +target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES}) set(LLM_SRC_DIR 
"${torchserve_cpp_SOURCE_DIR}/src/examples/llamacpp") set(LLAMACPP_SRC_DIR "/home/ubuntu/llama.cpp") @@ -20,10 +20,12 @@ set(MY_OBJECT_FILES ${LLAMACPP_SRC_DIR}/ggml.o ${LLAMACPP_SRC_DIR}/llama.o ${LLAMACPP_SRC_DIR}/common.o - ${LLAMACPP_SRC_DIR}/k_quants.o + ${LLAMACPP_SRC_DIR}/ggml-quants.o ${LLAMACPP_SRC_DIR}/ggml-alloc.o ${LLAMACPP_SRC_DIR}/grammar-parser.o ${LLAMACPP_SRC_DIR}/console.o + ${LLAMACPP_SRC_DIR}/build-info.o + ${LLAMACPP_SRC_DIR}/ggml-backend.o ) diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.cc b/cpp/src/examples/llamacpp/llamacpp_handler.cc index 1ad72e05cc..fc01858bec 100644 --- a/cpp/src/examples/llamacpp/llamacpp_handler.cc +++ b/cpp/src/examples/llamacpp/llamacpp_handler.cc @@ -53,7 +53,8 @@ LlamacppHandler::LoadModel( llama_backend_init(params.numa); ctx_params = llama_context_default_params(); - llamamodel = llama_load_model_from_file(params.model.c_str(), ctx_params); + model_params = llama_model_default_params(); + llamamodel = llama_load_model_from_file(params.model.c_str(), model_params); return std::make_pair(module, device); } catch (const c10::Error& e) { @@ -74,7 +75,6 @@ std::vector LlamacppHandler::Preprocess( std::pair&>& idx_to_req_id, std::shared_ptr& request_batch, std::shared_ptr& response_batch) { - initialize_context(); std::vector batch_ivalue; @@ -181,8 +181,7 @@ torch::Tensor LlamacppHandler::Inference( // evaluate the transformer if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()), - llama_get_kv_cache_token_count(llama_ctx), - params.n_threads)) { + llama_get_kv_cache_token_count(llama_ctx))) { std::cout << "Failed to eval\n" << __func__ << std::endl; break; } @@ -194,7 +193,7 @@ torch::Tensor LlamacppHandler::Inference( llama_token new_token_id = 0; auto logits = llama_get_logits(llama_ctx); - auto n_vocab = llama_n_vocab(llama_ctx); + auto n_vocab = llama_n_vocab(llamamodel); std::vector candidates; candidates.reserve(n_vocab); @@ -210,7 +209,7 @@ torch::Tensor LlamacppHandler::Inference( new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p); // is it an end of stream ? 
- if (new_token_id == llama_token_eos(llama_ctx)) { + if (new_token_id == llama_token_eos(llamamodel)) { std::cout << "Reached [end of text]\n"; break; } diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.hh b/cpp/src/examples/llamacpp/llamacpp_handler.hh index 54de782fad..520099f2d6 100644 --- a/cpp/src/examples/llamacpp/llamacpp_handler.hh +++ b/cpp/src/examples/llamacpp/llamacpp_handler.hh @@ -13,6 +13,7 @@ namespace llm { class LlamacppHandler : public torchserve::torchscripted::BaseHandler { private: gpt_params params; + llama_model_params model_params; llama_model* llamamodel; llama_context_params ctx_params; llama_context* llama_ctx; @@ -52,4 +53,4 @@ class LlamacppHandler : public torchserve::torchscripted::BaseHandler { override; }; } // namespace llm -#endif // LLAMACPP_HANDLER_HH_ \ No newline at end of file +#endif // LLAMACPP_HANDLER_HH_ diff --git a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc index e841c57ea1..16fedc660a 100644 --- a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc +++ b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc @@ -84,8 +84,7 @@ TEST_F(TorchScriptedBackendTest, TestLoadPredictLlmHandler) { "test/resources/torchscript_model/llamacpp/llamacpp_handler", "llm", -1, "", "", 1, false), "test/resources/torchscript_model/llamacpp/llamacpp_handler", - "test/resources/torchscript_model/llamacpp/prompt.txt", "llm_ts", - 200); + "test/resources/torchscript_model/llamacpp/prompt.txt", "llm_ts", 200); } TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) {