diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a7ee04a103..306b2005e5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -34,7 +34,7 @@ repos:
       - id: black
         additional_dependencies: ['click==8.0.4']
   - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         args: ["--profile", "black"]
diff --git a/cpp/build.sh b/cpp/build.sh
index 23df4df722..2a962d7a9e 100755
--- a/cpp/build.sh
+++ b/cpp/build.sh
@@ -299,6 +299,10 @@ function build() {
     mv $DEPS_DIR/../src/examples/libmnist_handler.so $DEPS_DIR/../../test/resources/torchscript_model/mnist/mnist_handler/libmnist_handler.so
   fi
 
+  if [ -f "$DEPS_DIR/../src/examples/libllamacpp_handler.so" ]; then
+    mv $DEPS_DIR/../src/examples/libllamacpp_handler.so $DEPS_DIR/../../test/resources/torchscript_model/llamacpp/llamacpp_handler/libllamacpp_handler.so
+  fi
+
   cd $DEPS_DIR/../..
   if [ -f "$DEPS_DIR/../test/torchserve_cpp_test" ]; then
     $DEPS_DIR/../test/torchserve_cpp_test
diff --git a/cpp/src/examples/CMakeLists.txt b/cpp/src/examples/CMakeLists.txt
index 4c9c534097..b7c97691df 100644
--- a/cpp/src/examples/CMakeLists.txt
+++ b/cpp/src/examples/CMakeLists.txt
@@ -4,4 +4,29 @@ set(MNIST_SOURCE_FILES "")
 list(APPEND MNIST_SOURCE_FILES ${MNIST_SRC_DIR}/mnist_handler.cc)
 add_library(mnist_handler SHARED ${MNIST_SOURCE_FILES})
 target_include_directories(mnist_handler PUBLIC ${MNIST_SRC_DIR})
-target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
+target_link_libraries(mnist_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
+
+set(LLM_SRC_DIR "${torchserve_cpp_SOURCE_DIR}/src/examples/llamacpp")
+set(LLAMACPP_SRC_DIR "/home/ubuntu/llama.cpp")
+set(LLM_SOURCE_FILES "")
+list(APPEND LLM_SOURCE_FILES ${LLM_SRC_DIR}/llamacpp_handler.cc)
+add_library(llamacpp_handler SHARED ${LLM_SOURCE_FILES})
+target_include_directories(llamacpp_handler PUBLIC ${LLM_SRC_DIR})
+target_include_directories(llamacpp_handler PUBLIC ${LLAMACPP_SRC_DIR})
+target_link_libraries(llamacpp_handler PRIVATE ts_backends_torch_scripted ts_utils ${TORCH_LIBRARIES})
+
+
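+# NOTE: the object files listed below are assumed to already exist in
+# LLAMACPP_SRC_DIR, i.e. llama.cpp is expected to have been built beforehand
+# (e.g. with `make` in that checkout); they are linked directly into the
+# handler library rather than built by this project.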
+set(MY_OBJECT_FILES
+  ${LLAMACPP_SRC_DIR}/ggml.o
+  ${LLAMACPP_SRC_DIR}/llama.o
+  ${LLAMACPP_SRC_DIR}/common.o
+  ${LLAMACPP_SRC_DIR}/ggml-quants.o
+  ${LLAMACPP_SRC_DIR}/ggml-alloc.o
+  ${LLAMACPP_SRC_DIR}/grammar-parser.o
+  ${LLAMACPP_SRC_DIR}/console.o
+  ${LLAMACPP_SRC_DIR}/build-info.o
+  ${LLAMACPP_SRC_DIR}/ggml-backend.o
+
+)
+
+target_sources(llamacpp_handler PRIVATE ${MY_OBJECT_FILES})
diff --git a/cpp/src/examples/llamacpp/config.json b/cpp/src/examples/llamacpp/config.json
new file mode 100644
index 0000000000..6fb87978f3
--- /dev/null
+++ b/cpp/src/examples/llamacpp/config.json
@@ -0,0 +1,5 @@
+{
+"checkpoint_path" : "/home/ubuntu/llama-2-7b-chat.Q4_0.gguf"
+}
+
+
diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.cc b/cpp/src/examples/llamacpp/llamacpp_handler.cc
new file mode 100644
index 0000000000..fc01858bec
--- /dev/null
+++ b/cpp/src/examples/llamacpp/llamacpp_handler.cc
@@ -0,0 +1,303 @@
+#include "src/examples/llamacpp/llamacpp_handler.hh"
+
+#include <folly/FileUtil.h>
+#include <folly/json.h>
+
+#include <typeinfo>
+
+namespace llm {
+
+void LlamacppHandler::initialize_context() {
+  llama_ctx = llama_new_context_with_model(llamamodel, ctx_params);
+
+  if (llama_ctx == nullptr) {
+    std::cerr << "Failed to initialize llama context" << std::endl;
+  } else {
+    std::cout << "Context initialized successfully" << std::endl;
+  }
+}
+
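+// LoadModel loads a small dummy TorchScript module (required by the
+// torch-scripted backend interface) and reads the GGUF checkpoint path from
+// config.json; the actual weights are loaded via llama_load_model_from_file,
+// not through TorchScript.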
+std::pair<std::shared_ptr<torch::jit::script::Module>,
+          std::shared_ptr<torch::Device>>
+LlamacppHandler::LoadModel(
+    std::shared_ptr<torchserve::LoadModelRequest>& load_model_request) {
+  try {
+    auto device = GetTorchDevice(load_model_request);
+    // Load dummy model
+    auto module = std::make_shared<torch::jit::script::Module>(
+        torch::jit::load(fmt::format("{}/{}", load_model_request->model_dir,
+                                     manifest_->GetModel().serialized_file),
+                         *device));
+
+    const std::string configFilePath =
+        fmt::format("{}/{}", load_model_request->model_dir, "config.json");
+    std::string jsonContent;
+    if (!folly::readFile(configFilePath.c_str(), jsonContent)) {
+      std::cerr << "config.json not found at: " << configFilePath << std::endl;
+      throw std::runtime_error("config.json not found at: " + configFilePath);
+    }
+    folly::dynamic json;
+    json = folly::parseJson(jsonContent);
+
+    std::string checkpoint_path;
+    if (json.find("checkpoint_path") != json.items().end()) {
+      checkpoint_path = json["checkpoint_path"].asString();
+    } else {
+      std::cerr << "Required field 'checkpoint_path' not found in JSON."
+                << std::endl;
+      throw std::runtime_error("'checkpoint_path' not found in config.json");
+    }
+    params.model = checkpoint_path;
+    params.main_gpu = 0;
+    params.n_gpu_layers = 35;
+
+    llama_backend_init(params.numa);
+    ctx_params = llama_context_default_params();
+    model_params = llama_model_default_params();
+    llamamodel = llama_load_model_from_file(params.model.c_str(), model_params);
+
+    return std::make_pair(module, device);
+  } catch (const c10::Error& e) {
+    TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}",
+            load_model_request->model_name, load_model_request->gpu_id,
+            e.msg());
+    throw e;
+  } catch (const std::runtime_error& e) {
+    TS_LOGF(ERROR, "loading the model: {}, device id: {}, error: {}",
+            load_model_request->model_name, load_model_request->gpu_id,
+            e.what());
+    throw e;
+  }
+}
+
+std::vector<torch::jit::IValue> LlamacppHandler::Preprocess(
+    std::shared_ptr<torch::Device>& device,
+    std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
+    std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
+  initialize_context();
+
+  std::vector<torch::jit::IValue> batch_ivalue;
+  std::vector<torch::Tensor> batch_tensors;
+  uint8_t idx = 0;
+  for (auto& request : *request_batch) {
+    try {
+      (*response_batch)[request.request_id] =
+          std::make_shared<torchserve::InferenceResponse>(request.request_id);
+      idx_to_req_id.first += idx_to_req_id.first.empty()
+                                 ? request.request_id
+                                 : "," + request.request_id;
+
+      auto data_it = request.parameters.find(
+          torchserve::PayloadType::kPARAMETER_NAME_DATA);
+      auto dtype_it =
+          request.headers.find(torchserve::PayloadType::kHEADER_NAME_DATA_TYPE);
+      if (data_it == request.parameters.end()) {
+        data_it = request.parameters.find(
+            torchserve::PayloadType::kPARAMETER_NAME_BODY);
+        dtype_it = request.headers.find(
+            torchserve::PayloadType::kHEADER_NAME_BODY_TYPE);
+      }
+
+      if (data_it == request.parameters.end() ||
+          dtype_it == request.headers.end()) {
+        TS_LOGF(ERROR, "Empty payload for request id: {}", request.request_id);
+        (*response_batch)[request.request_id]->SetResponse(
+            500, "data_type", torchserve::PayloadType::kCONTENT_TYPE_TEXT,
+            "Empty payload");
+        continue;
+      }
+
+      std::string msg = torchserve::Converter::VectorToStr(data_it->second);
+
+      // tokenization
+
+      std::vector<llama_token> tokens_list;
+      tokens_list = ::llama_tokenize(llama_ctx, msg, true);
+
+      // const int max_context_size = llama_n_ctx(ctx);
+      const int max_tokens_list_size = max_context_size - 4;
+
+      if ((int)tokens_list.size() > max_tokens_list_size) {
+        std::cout << __func__ << ": error: prompt too long ("
+                  << tokens_list.size() << " tokens, max "
+                  << max_tokens_list_size << ")\n";
+      }
+
+      // Convert the prompt tokens to a stacked tensor
+      std::vector<torch::Tensor> tensor_vector;
+      for (auto id : tokens_list) {
+        torch::Tensor tensor = torch::tensor(id, torch::kInt64);
+        tensor_vector.push_back(tensor);
+      }
+
+      torch::Tensor stacked_tensor = torch::stack(tensor_vector);
+      batch_ivalue.push_back(stacked_tensor);
+      idx_to_req_id.second[idx++] = request.request_id;
+
+    } catch (const std::runtime_error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
+              request.request_id, e.what());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "runtime_error, failed to load tensor");
+    } catch (const c10::Error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, c10 error: {}",
+              request.request_id, e.msg());
+      auto response = (*response_batch)[request.request_id];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "c10 error, failed to load tensor");
+    }
+  }
+
+  return batch_ivalue;
+}
+
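+// Inference runs a simple greedy decoding loop with llama.cpp: the prompt is
+// evaluated with llama_eval, the next token is chosen with
+// llama_sample_token_greedy, and generation stops at EOS or after n_gen tokens.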
+torch::Tensor LlamacppHandler::Inference(
+    std::shared_ptr<torch::jit::script::Module> model,
+    std::vector<torch::jit::IValue>& inputs,
+    std::shared_ptr<torch::Device>& device,
+    std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
+  torch::InferenceMode guard;
+  std::vector<torch::Tensor> batch_output_vector;
+  for (const torch::jit::IValue& input : inputs) {
+    torch::Tensor tokens_list_tensor = input.toTensor();
+
+    int64_t num_elements = tokens_list_tensor.numel();
+
+    int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();
+    std::vector<llama_token> tokens_list;
+
+    for (int64_t i = 0; i < num_elements; ++i) {
+      tokens_list.push_back(data_ptr[i]);
+    }
+    const int n_gen = std::min(32, max_context_size);
+
+    long pos = 0;
+    while (pos < n_gen) {
+      // evaluate the transformer
+
+      if (llama_eval(llama_ctx, tokens_list.data(), int(tokens_list.size()),
+                     llama_get_kv_cache_token_count(llama_ctx))) {
+        std::cout << "Failed to eval\n" << __func__ << std::endl;
+        break;
+      }
+
+      tokens_list.clear();
+
+      // sample the next token
+
+      llama_token new_token_id = 0;
+
+      auto logits = llama_get_logits(llama_ctx);
+      auto n_vocab = llama_n_vocab(llamamodel);
+
+      std::vector<llama_token_data> candidates;
+      candidates.reserve(n_vocab);
+
+      for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        candidates.emplace_back(
+            llama_token_data{token_id, logits[token_id], 0.0f});
+      }
+
+      llama_token_data_array candidates_p = {candidates.data(),
+                                             candidates.size(), false};
+
+      new_token_id = llama_sample_token_greedy(llama_ctx, &candidates_p);
+
+      // is it an end of stream ?
+      if (new_token_id == llama_token_eos(llamamodel)) {
+        std::cout << "Reached [end of text]\n";
+        break;
+      }
+
+      // print the new token :
+      std::cout << "New Token: "
+                << llama_token_to_piece(llama_ctx, new_token_id) << std::endl;
+
+      // push this new token for next evaluation
+      tokens_list.push_back(new_token_id);
+      pos += 1;
+    }
+
+    std::vector<torch::Tensor> tensor_vector;
+    for (auto id : tokens_list) {
+      torch::Tensor tensor = torch::tensor(id, torch::kLong);
+      tensor_vector.push_back(tensor);
+    }
+
+    torch::Tensor stacked_tensor = torch::stack(tensor_vector);
+    batch_output_vector.push_back(stacked_tensor);
+  }
+
+  llama_print_timings(llama_ctx);
+  return torch::stack(batch_output_vector);
+}
+
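+// Postprocess decodes the generated token ids back into text with
+// llama_token_to_piece and writes the resulting string into each request's
+// response.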
+void LlamacppHandler::Postprocess(
+    const torch::Tensor& data,
+    std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+    std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
+  for (const auto& kv : idx_to_req_id.second) {
+    try {
+      int64_t num_elements = data.numel();
+
+      // Convert the tensor to a vector of long values
+      std::stringstream generated_text_stream;
+
+      auto data_ptr = data.data_ptr<int64_t>();
+      for (int64_t i = 0; i < num_elements; ++i) {
+        generated_text_stream << llama_token_to_piece(llama_ctx, data_ptr[i]);
+      }
+
+      std::string generated_text_str = generated_text_stream.str();
+      std::cout << "Generated Text Str: " << generated_text_str << std::endl;
+
+      auto response = (*response_batch)[kv.second];
+
+      response->SetResponse(200, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            generated_text_str);
+    } catch (const std::runtime_error& e) {
+      TS_LOGF(ERROR, "Failed to load tensor for request id: {}, error: {}",
+              kv.second, e.what());
+      auto response = (*response_batch)[kv.second];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "runtime_error, failed to postprocess tensor");
+    } catch (const c10::Error& e) {
+      TS_LOGF(ERROR,
+              "Failed to postprocess tensor for request id: {}, error: {}",
+              kv.second, e.msg());
+      auto response = (*response_batch)[kv.second];
+      response->SetResponse(500, "data_type",
+                            torchserve::PayloadType::kDATA_TYPE_STRING,
+                            "c10 error, failed to postprocess tensor");
+    }
+  }
+}
+
+LlamacppHandler::~LlamacppHandler() noexcept {
+  llama_free(llama_ctx);
+  llama_free_model(llamamodel);
+  llama_backend_free();
+}
+
+}  // namespace llm
+
+#if defined(__linux__) || defined(__APPLE__)
+extern "C" {
+torchserve::torchscripted::BaseHandler* allocatorLlamacppHandler() {
+  return new llm::LlamacppHandler();
+}
+
+void deleterLlamacppHandler(torchserve::torchscripted::BaseHandler* p) {
+  if (p != nullptr) {
+    delete static_cast<llm::LlamacppHandler*>(p);
+  }
+}
+}
+#endif
diff --git a/cpp/src/examples/llamacpp/llamacpp_handler.hh b/cpp/src/examples/llamacpp/llamacpp_handler.hh
new file mode 100644
index 0000000000..520099f2d6
--- /dev/null
+++ b/cpp/src/examples/llamacpp/llamacpp_handler.hh
@@ -0,0 +1,56 @@
+#ifndef LLAMACPP_HANDLER_HH_
+#define LLAMACPP_HANDLER_HH_
+
+#include <folly/FileUtil.h>
+#include <folly/json.h>
+
+#include "common/common.h"
+#include "ggml.h"
+#include "llama.h"
+#include "src/backends/torch_scripted/handler/base_handler.hh"
+
+namespace llm {
+class LlamacppHandler : public torchserve::torchscripted::BaseHandler {
+ private:
+  gpt_params params;
+  llama_model_params model_params;
+  llama_model* llamamodel;
+  llama_context_params ctx_params;
+  llama_context* llama_ctx;
+  const int max_context_size = 32;
+
+ public:
+  // NOLINTBEGIN(bugprone-exception-escape)
+  LlamacppHandler() = default;
+  // NOLINTEND(bugprone-exception-escape)
+  ~LlamacppHandler() noexcept;
+
+  void initialize_context();
+
+  virtual std::pair<std::shared_ptr<torch::jit::script::Module>,
+                    std::shared_ptr<torch::Device>>
+  LoadModel(std::shared_ptr<torchserve::LoadModelRequest>& load_model_request);
+
+  std::vector<torch::jit::IValue> Preprocess(
+      std::shared_ptr<torch::Device>& device,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceRequestBatch>& request_batch,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+  torch::Tensor Inference(
+      std::shared_ptr<torch::jit::script::Module> model,
+      std::vector<torch::jit::IValue>& inputs,
+      std::shared_ptr<torch::Device>& device,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+
+  void Postprocess(
+      const torch::Tensor& data,
+      std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
+      std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch)
+      override;
+};
+}  // namespace llm
+#endif  // LLAMACPP_HANDLER_HH_
diff --git a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
index b3099d1a2a..16fedc660a 100644
--- a/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
+++ b/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc
@@ -78,6 +78,15 @@ TEST_F(TorchScriptedBackendTest, TestLoadPredictMnistHandler) {
       "mnist_ts", 200);
 }
 
+TEST_F(TorchScriptedBackendTest, TestLoadPredictLlmHandler) {
+  this->LoadPredict(
+      std::make_shared<torchserve::LoadModelRequest>(
+          "test/resources/torchscript_model/llamacpp/llamacpp_handler", "llm",
+          -1, "", "", 1, false),
+      "test/resources/torchscript_model/llamacpp/llamacpp_handler",
+      "test/resources/torchscript_model/llamacpp/prompt.txt", "llm_ts", 200);
+}
+
 TEST_F(TorchScriptedBackendTest, TestBackendInitWrongModelDir) {
   auto result = backend_->Initialize("test/resources/torchscript_model/mnist");
   ASSERT_EQ(result, false);
diff --git a/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/MAR-INF/MANIFEST.json b/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/MAR-INF/MANIFEST.json
new file mode 100644
index 0000000000..79c37e08ba
--- /dev/null
+++ b/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/MAR-INF/MANIFEST.json
@@ -0,0 +1,11 @@
+{
+  "createdOn": "28/07/2020 06:32:08",
+  "runtime": "LSP",
+  "model": {
+    "modelName": "llamacpp",
+    "serializedFile": "dummy.pt",
+    "handler": "libllamacpp_handler:LlamacppHandler",
+    "modelVersion": "2.0"
+  },
+  "archiverVersion": "0.2.0"
+}
\ No newline at end of file
diff --git a/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/dummy.pt b/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/dummy.pt
new file mode 100644
index 0000000000..2f4058d69c
Binary files /dev/null and b/cpp/test/resources/torchscript_model/llamacpp/llamacpp_handler/dummy.pt differ
diff --git a/cpp/test/resources/torchscript_model/llamacpp/prompt.txt b/cpp/test/resources/torchscript_model/llamacpp/prompt.txt
new file mode 100644
index 0000000000..6e3c30c691
--- /dev/null
+++ b/cpp/test/resources/torchscript_model/llamacpp/prompt.txt
@@ -0,0 +1 @@
+Hello my name is
\ No newline at end of file