diff --git a/.gitmodules b/.gitmodules index 98c9a2142a21..1ac9606d6dd1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,5 +1,5 @@ [submodule "llama.cpp-mainline"] - path = gpt4all-backend/llama.cpp-mainline + path = gpt4all-backend/llama.cpp url = https://github.com/nomic-ai/llama.cpp.git branch = master [submodule "gpt4all-chat/usearch"] diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt index e6210d74ae69..f10d5d9444ab 100644 --- a/gpt4all-backend/CMakeLists.txt +++ b/gpt4all-backend/CMakeLists.txt @@ -47,7 +47,7 @@ else() message(STATUS "Interprocedural optimization support detected") endif() -set(DIRECTORY llama.cpp-mainline) +set(DIRECTORY llama.cpp) include(llama.cpp.cmake) set(BUILD_VARIANTS) @@ -108,7 +108,7 @@ foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS) endif() # Include GGML - include_ggml(-mainline-${BUILD_VARIANT}) + include_ggml(-${BUILD_VARIANT}) # Function for preparing individual implementations function(prepare_target TARGET_NAME BASE_LIB) @@ -127,11 +127,10 @@ foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS) endfunction() # Add each individual implementations - add_library(llamamodel-mainline-${BUILD_VARIANT} SHARED - llamamodel.cpp llmodel_shared.cpp) - target_compile_definitions(llamamodel-mainline-${BUILD_VARIANT} PRIVATE + add_library(llamacpp-${BUILD_VARIANT} SHARED llamacpp_backend_impl.cpp) + target_compile_definitions(llamacpp-${BUILD_VARIANT} PRIVATE LLAMA_VERSIONS=>=3 LLAMA_DATE=999999) - prepare_target(llamamodel-mainline llama-mainline) + prepare_target(llamacpp llama) if (NOT PROJECT_IS_TOP_LEVEL AND BUILD_VARIANT STREQUAL cuda) set(CUDAToolkit_BIN_DIR ${CUDAToolkit_BIN_DIR} PARENT_SCOPE) @@ -139,7 +138,9 @@ foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS) endforeach() add_library(llmodel - llmodel.h llmodel.cpp llmodel_shared.cpp + model_backend.h + llamacpp_backend.h llamacpp_backend.cpp + llamacpp_backend_manager.h llamacpp_backend_manager.cpp llmodel_c.h llmodel_c.cpp dlhandle.cpp ) diff --git a/gpt4all-backend/llama.cpp-mainline b/gpt4all-backend/llama.cpp similarity index 100% rename from gpt4all-backend/llama.cpp-mainline rename to gpt4all-backend/llama.cpp diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llamacpp_backend.cpp similarity index 84% rename from gpt4all-backend/llmodel_shared.cpp rename to gpt4all-backend/llamacpp_backend.cpp index 7477254a74ff..489c276db1c7 100644 --- a/gpt4all-backend/llmodel_shared.cpp +++ b/gpt4all-backend/llamacpp_backend.cpp @@ -1,4 +1,6 @@ -#include "llmodel.h" +#include "llamacpp_backend.h" + +#include "llamacpp_backend_manager.h" #include #include @@ -15,6 +17,7 @@ namespace ranges = std::ranges; + static bool parsePromptTemplate(const std::string &tmpl, std::vector &placeholders, std::string &err) { static const std::regex placeholderRegex(R"(%[1-2](?![0-9]))"); @@ -38,24 +41,25 @@ static bool parsePromptTemplate(const std::string &tmpl, std::vector promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx, - bool special, - std::string *fakeReply) -{ +void LlamaCppBackend::prompt( + const std::string &prompt, + const std::string &promptTemplate, + std::function promptCallback, + std::function responseCallback, + bool allowContextShift, + PromptContext &promptCtx, + bool special, + std::string *fakeReply +) { if (!isModelLoaded()) { - std::cerr << implementation().modelType() << " ERROR: prompt won't work with an unloaded model!\n"; + std::cerr << manager().modelType() << " ERROR: prompt won't work with an unloaded model!\n"; 
return; } if (!supportsCompletion()) { std::string errorMessage = "ERROR: this model does not support text completion or chat!"; responseCallback(-1, errorMessage); - std::cerr << implementation().modelType() << " " << errorMessage << "\n"; + std::cerr << manager().modelType() << " " << errorMessage << "\n"; return; } @@ -152,15 +156,22 @@ void LLModel::prompt(const std::string &prompt, } } +const LlamaCppBackendManager &LlamaCppBackend::manager() const +{ + return *m_manager; +} + // returns false on error -bool LLModel::decodePrompt(std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx, - std::vector embd_inp) { +bool LlamaCppBackend::decodePrompt( + std::function promptCallback, + std::function responseCallback, + bool allowContextShift, + PromptContext &promptCtx, + std::vector embd_inp +) { if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); - std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() << + std::cerr << manager().modelType() << " ERROR: The prompt is " << embd_inp.size() << " tokens and the context window is " << promptCtx.n_ctx << "!\n"; return false; } @@ -188,7 +199,7 @@ bool LLModel::decodePrompt(std::function promptCallback, } if (!evalTokens(promptCtx, batch)) { - std::cerr << implementation().modelType() << " ERROR: Failed to process prompt\n"; + std::cerr << manager().modelType() << " ERROR: Failed to process prompt\n"; return false; } @@ -224,9 +235,11 @@ static std::string::size_type stringsOverlap(const std::string &s, const std::st return std::string::npos; } -void LLModel::generateResponse(std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx) { +void LlamaCppBackend::generateResponse( + std::function responseCallback, + bool allowContextShift, + PromptContext &promptCtx +) { static const char *stopSequences[] { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context", }; @@ -265,7 +278,7 @@ void LLModel::generateResponse(std::function Token tok = std::exchange(new_tok, std::nullopt).value(); if (!evalTokens(promptCtx, { tok })) { // TODO(jared): raise an exception - std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n"; + std::cerr << manager().modelType() << " ERROR: Failed to predict next token\n"; return false; } @@ -370,32 +383,3 @@ void LLModel::generateResponse(std::function promptCtx.n_past -= cachedTokens.size(); } - -void LLModel::embed( - const std::vector &texts, float *embeddings, std::optional prefix, int dimensionality, - size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb -) { - (void)texts; - (void)embeddings; - (void)prefix; - (void)dimensionality; - (void)tokenCount; - (void)doMean; - (void)atlas; - (void)cancelCb; - throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings"); -} - -void LLModel::embed( - const std::vector &texts, float *embeddings, bool isRetrieval, int dimensionality, size_t *tokenCount, - bool doMean, bool atlas -) { - (void)texts; - (void)embeddings; - (void)isRetrieval; - (void)dimensionality; - (void)tokenCount; - (void)doMean; - (void)atlas; - throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings"); -} diff --git a/gpt4all-backend/llamacpp_backend.h b/gpt4all-backend/llamacpp_backend.h new file mode 100644 index 
000000000000..2c924a02644c --- /dev/null +++ b/gpt4all-backend/llamacpp_backend.h @@ -0,0 +1,145 @@ +#pragma once + +#include "model_backend.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std::string_literals; + +class LlamaCppBackendManager; + + +class LlamaCppBackend : public EmbCapableBackend { +public: + struct GPUDevice { + const char *backend; + int index; + int type; + size_t heapSize; + std::string name; + std::string vendor; + + GPUDevice(const char *backend, int index, int type, size_t heapSize, std::string name, std::string vendor): + backend(backend), index(index), type(type), heapSize(heapSize), name(std::move(name)), + vendor(std::move(vendor)) {} + + std::string selectionName() const + { + assert(backend == "cuda"s || backend == "kompute"s); + return backendName() + ": " + name; + } + + std::string backendName() const { return backendIdToName(backend); } + + static std::string backendIdToName(const std::string &backend) { return s_backendNames.at(backend); } + + static std::string updateSelectionName(const std::string &name) { + if (name == "Auto" || name == "CPU" || name == "Metal") + return name; + auto it = std::find_if(s_backendNames.begin(), s_backendNames.end(), [&name](const auto &entry) { + return name.starts_with(entry.second + ": "); + }); + if (it != s_backendNames.end()) + return name; + return "Vulkan: " + name; // previously, there were only Vulkan devices + } + + private: + static inline const std::unordered_map s_backendNames { + {"cpu", "CPU"}, {"metal", "Metal"}, {"cuda", "CUDA"}, {"kompute", "Vulkan"}, + }; + }; + + using ProgressCallback = std::function; + + virtual bool isModelBlacklisted(const std::string &modelPath) const = 0; + virtual bool isEmbeddingModel(const std::string &modelPath) const = 0; + virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0; + + void prompt(const std::string &prompt, + const std::string &promptTemplate, + std::function promptCallback, + std::function responseCallback, + bool allowContextShift, + PromptContext &ctx, + bool special = false, + std::string *fakeReply = nullptr) override; + + virtual void setThreadCount(int32_t n_threads) { (void)n_threads; } + virtual int32_t threadCount() const { return 1; } + + const LlamaCppBackendManager &manager() const; + + virtual std::vector availableGPUDevices(size_t memoryRequired) const + { + (void)memoryRequired; + return {}; + } + + virtual bool initializeGPUDevice(size_t memoryRequired, const std::string &name) const + { + (void)memoryRequired; + (void)name; + return false; + } + + virtual bool initializeGPUDevice(int device, std::string *unavail_reason = nullptr) const + { + (void)device; + if (unavail_reason) { + *unavail_reason = "model has no GPU support"; + } + return false; + } + + virtual bool usingGPUDevice() const { return false; } + virtual const char *backendName() const { return "cpu"; } + virtual const char *gpuDeviceName() const { return nullptr; } + + void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; } + +protected: + virtual std::vector tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0; + virtual bool isSpecialToken(Token id) const = 0; + virtual std::string tokenToString(Token id) const = 0; + virtual Token sampleToken(PromptContext &ctx) const = 0; + virtual bool evalTokens(PromptContext &ctx, const std::vector &tokens) const = 0; + virtual void shiftContext(PromptContext &promptCtx) = 0; + virtual int32_t 
contextLength() const = 0; + virtual const std::vector &endTokens() const = 0; + virtual bool shouldAddBOS() const = 0; + + virtual int32_t maxContextLength(std::string const &modelPath) const = 0; + virtual int32_t layerCount(std::string const &modelPath) const = 0; + + static bool staticProgressCallback(float progress, void* ctx) + { + LlamaCppBackend *model = static_cast(ctx); + if (model && model->m_progressCallback) + return model->m_progressCallback(progress); + return true; + } + + bool decodePrompt(std::function promptCallback, + std::function responseCallback, + bool allowContextShift, + PromptContext &promptCtx, + std::vector embd_inp); + void generateResponse(std::function responseCallback, + bool allowContextShift, + PromptContext &promptCtx); + + const LlamaCppBackendManager *m_manager = nullptr; + ProgressCallback m_progressCallback; + Token m_tokenize_last_token = -1; + + friend class LlamaCppBackendManager; +}; diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamacpp_backend_impl.cpp similarity index 92% rename from gpt4all-backend/llamamodel.cpp rename to gpt4all-backend/llamacpp_backend_impl.cpp index f07a05e88ab8..d666f3e9791f 100644 --- a/gpt4all-backend/llamamodel.cpp +++ b/gpt4all-backend/llamacpp_backend_impl.cpp @@ -1,7 +1,7 @@ -#define LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE -#include "llamamodel_impl.h" +#define LLAMACPP_BACKEND_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#include "llamacpp_backend_impl.h" -#include "llmodel.h" +#include "model_backend.h" #include #include @@ -232,7 +232,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const return value; } -struct LLamaPrivate { +struct LlamaPrivate { const std::string modelPath; bool modelLoaded = false; int device = -1; @@ -242,12 +242,12 @@ struct LLamaPrivate { llama_model_params model_params; llama_context_params ctx_params; int64_t n_threads = 0; - std::vector end_tokens; + std::vector end_tokens; const char *backend_name = nullptr; }; -LLamaModel::LLamaModel() - : d_ptr(new LLamaPrivate) {} +LlamaCppBackendImpl::LlamaCppBackendImpl() + : d_ptr(new LlamaPrivate) {} // default hparams (LLaMA 7B) struct llama_file_hparams { @@ -260,7 +260,7 @@ struct llama_file_hparams { enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; }; -size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl) +size_t LlamaCppBackendImpl::requiredMem(const std::string &modelPath, int n_ctx, int ngl) { // TODO(cebtenzzre): update to GGUF (void)ngl; // FIXME(cetenzzre): use this value @@ -285,7 +285,7 @@ size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl) return filesize + est_kvcache_size; } -bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const +bool LlamaCppBackendImpl::isModelBlacklisted(const std::string &modelPath) const { auto * ctx = load_gguf(modelPath.c_str()); if (!ctx) { @@ -322,7 +322,7 @@ bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const return res; } -bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const +bool LlamaCppBackendImpl::isEmbeddingModel(const std::string &modelPath) const { bool result = false; std::string arch; @@ -346,7 +346,7 @@ bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const return result; } -bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl) +bool LlamaCppBackendImpl::loadModel(const std::string &modelPath, int n_ctx, int ngl) { d_ptr->modelLoaded = false; @@ -378,7 +378,7 @@ bool 
LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl) d_ptr->model_params.use_mlock = params.use_mlock; #endif - d_ptr->model_params.progress_callback = &LLModel::staticProgressCallback; + d_ptr->model_params.progress_callback = &LlamaCppBackend::staticProgressCallback; d_ptr->model_params.progress_callback_user_data = this; d_ptr->backend_name = "cpu"; // default @@ -488,18 +488,18 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl) return true; } -void LLamaModel::setThreadCount(int32_t n_threads) +void LlamaCppBackendImpl::setThreadCount(int32_t n_threads) { d_ptr->n_threads = n_threads; llama_set_n_threads(d_ptr->ctx, n_threads, n_threads); } -int32_t LLamaModel::threadCount() const +int32_t LlamaCppBackendImpl::threadCount() const { return d_ptr->n_threads; } -LLamaModel::~LLamaModel() +LlamaCppBackendImpl::~LlamaCppBackendImpl() { if (d_ptr->ctx) { llama_free(d_ptr->ctx); @@ -507,32 +507,32 @@ LLamaModel::~LLamaModel() llama_free_model(d_ptr->model); } -bool LLamaModel::isModelLoaded() const +bool LlamaCppBackendImpl::isModelLoaded() const { return d_ptr->modelLoaded; } -size_t LLamaModel::stateSize() const +size_t LlamaCppBackendImpl::stateSize() const { return llama_get_state_size(d_ptr->ctx); } -size_t LLamaModel::saveState(uint8_t *dest) const +size_t LlamaCppBackendImpl::saveState(uint8_t *dest) const { return llama_copy_state_data(d_ptr->ctx, dest); } -size_t LLamaModel::restoreState(const uint8_t *src) +size_t LlamaCppBackendImpl::restoreState(const uint8_t *src) { // const_cast is required, see: https://github.com/ggerganov/llama.cpp/pull/1540 return llama_set_state_data(d_ptr->ctx, const_cast(src)); } -std::vector LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) +std::vector LlamaCppBackendImpl::tokenize(PromptContext &ctx, const std::string &str, bool special) { bool atStart = m_tokenize_last_token == -1; bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token); - std::vector fres(str.length() + 4); + std::vector fres(str.length() + 4); int32_t fres_len = llama_tokenize_gpt4all( d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart, /*parse_special*/ special, /*insert_space*/ insertSpace @@ -543,13 +543,13 @@ std::vector LLamaModel::tokenize(PromptContext &ctx, const std:: return fres; } -bool LLamaModel::isSpecialToken(Token id) const +bool LlamaCppBackendImpl::isSpecialToken(Token id) const { return llama_token_get_attr(d_ptr->model, id) & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN); } -std::string LLamaModel::tokenToString(Token id) const +std::string LlamaCppBackendImpl::tokenToString(Token id) const { std::vector result(8, 0); const int n_tokens = llama_token_to_piece(d_ptr->model, id, result.data(), result.size(), 0, true); @@ -565,7 +565,7 @@ std::string LLamaModel::tokenToString(Token id) const return std::string(result.data(), result.size()); } -LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const +ModelBackend::Token LlamaCppBackendImpl::sampleToken(PromptContext &promptCtx) const { const size_t n_prev_toks = std::min((size_t) promptCtx.repeat_last_n, promptCtx.tokens.size()); return llama_sample_top_p_top_k(d_ptr->ctx, @@ -574,7 +574,7 @@ LLModel::Token LLamaModel::sampleToken(PromptContext &promptCtx) const promptCtx.repeat_penalty); } -bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector &tokens) const +bool LlamaCppBackendImpl::evalTokens(PromptContext &ctx, const 
std::vector &tokens) const { llama_kv_cache_seq_rm(d_ptr->ctx, 0, ctx.n_past, -1); @@ -598,7 +598,7 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector &toke return res == 0; } -void LLamaModel::shiftContext(PromptContext &promptCtx) +void LlamaCppBackendImpl::shiftContext(PromptContext &promptCtx) { // infinite text generation via context shifting @@ -622,27 +622,27 @@ void LLamaModel::shiftContext(PromptContext &promptCtx) promptCtx.n_past = promptCtx.tokens.size(); } -int32_t LLamaModel::contextLength() const +int32_t LlamaCppBackendImpl::contextLength() const { return llama_n_ctx(d_ptr->ctx); } -const std::vector &LLamaModel::endTokens() const +const std::vector &LlamaCppBackendImpl::endTokens() const { return d_ptr->end_tokens; } -bool LLamaModel::shouldAddBOS() const +bool LlamaCppBackendImpl::shouldAddBOS() const { return llama_add_bos_token(d_ptr->model); } -int32_t LLamaModel::maxContextLength(std::string const &modelPath) const +int32_t LlamaCppBackendImpl::maxContextLength(std::string const &modelPath) const { return get_arch_key_u32(modelPath, "context_length"); } -int32_t LLamaModel::layerCount(std::string const &modelPath) const +int32_t LlamaCppBackendImpl::layerCount(std::string const &modelPath) const { return get_arch_key_u32(modelPath, "block_count"); } @@ -659,7 +659,7 @@ static const char *getVulkanVendorName(uint32_t vendorID) } #endif -std::vector LLamaModel::availableGPUDevices(size_t memoryRequired) const +std::vector LlamaCppBackendImpl::availableGPUDevices(size_t memoryRequired) const { #if defined(GGML_USE_KOMPUTE) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CUDA) size_t count = 0; @@ -675,7 +675,7 @@ std::vector LLamaModel::availableGPUDevices(size_t memoryReq #endif if (lcppDevices) { - std::vector devices; + std::vector devices; devices.reserve(count); for (size_t i = 0; i < count; ++i) { @@ -724,7 +724,7 @@ std::vector LLamaModel::availableGPUDevices(size_t memoryReq return {}; } -bool LLamaModel::initializeGPUDevice(size_t memoryRequired, const std::string &name) const +bool LlamaCppBackendImpl::initializeGPUDevice(size_t memoryRequired, const std::string &name) const { #if defined(GGML_USE_VULKAN) || defined(GGML_USE_CUDA) auto devices = availableGPUDevices(memoryRequired); @@ -761,7 +761,7 @@ bool LLamaModel::initializeGPUDevice(size_t memoryRequired, const std::string &n return false; } -bool LLamaModel::initializeGPUDevice(int device, std::string *unavail_reason) const +bool LlamaCppBackendImpl::initializeGPUDevice(int device, std::string *unavail_reason) const { #if defined(GGML_USE_KOMPUTE) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CUDA) (void)unavail_reason; @@ -779,7 +779,7 @@ bool LLamaModel::initializeGPUDevice(int device, std::string *unavail_reason) co #endif } -bool LLamaModel::usingGPUDevice() const +bool LlamaCppBackendImpl::usingGPUDevice() const { if (!d_ptr->model) return false; @@ -791,12 +791,12 @@ bool LLamaModel::usingGPUDevice() const return usingGPU; } -const char *LLamaModel::backendName() const +const char *LlamaCppBackendImpl::backendName() const { return d_ptr->backend_name; } -const char *LLamaModel::gpuDeviceName() const +const char *LlamaCppBackendImpl::gpuDeviceName() const { if (usingGPUDevice()) { #if defined(GGML_USE_KOMPUTE) || defined(GGML_USE_VULKAN) || defined(GGML_USE_CUDA) @@ -825,14 +825,14 @@ void llama_batch_add( batch.n_tokens++; } -static void batch_add_seq(llama_batch &batch, const std::vector &tokens, int seq_id) +static void batch_add_seq(llama_batch &batch, const std::vector 
&tokens, int seq_id) { for (unsigned i = 0; i < tokens.size(); i++) { llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); } } -size_t LLamaModel::embeddingSize() const +size_t LlamaCppBackendImpl::embeddingSize() const { return llama_n_embd(d_ptr->model); } @@ -884,7 +884,8 @@ static const EmbModelGroup EMBEDDING_MODEL_SPECS[] { "multilingual-e5-large-instruct"}}, }; -static const EmbModelSpec *getEmbedSpec(const std::string &modelName) { +static const EmbModelSpec *getEmbedSpec(const std::string &modelName) +{ static const auto &specs = EMBEDDING_MODEL_SPECS; auto it = std::find_if(specs, std::end(specs), [&modelName](auto &spec) { @@ -895,7 +896,7 @@ static const EmbModelSpec *getEmbedSpec(const std::string &modelName) { return it < std::end(specs) ? &it->spec : nullptr; } -void LLamaModel::embed( +void LlamaCppBackendImpl::embed( const std::vector &texts, float *embeddings, bool isRetrieval, int dimensionality, size_t *tokenCount, bool doMean, bool atlas ) { @@ -907,9 +908,9 @@ void LLamaModel::embed( embed(texts, embeddings, prefix, dimensionality, tokenCount, doMean, atlas); } -void LLamaModel::embed( +void LlamaCppBackendImpl::embed( const std::vector &texts, float *embeddings, std::optional prefix, int dimensionality, - size_t *tokenCount, bool doMean, bool atlas, LLModel::EmbedCancelCallback *cancelCb + size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb ) { if (!d_ptr->model) throw std::logic_error("no model is loaded"); @@ -965,11 +966,11 @@ double getL2NormScale(T *start, T *end) return 1.0 / std::max(magnitude, 1e-12); } -void LLamaModel::embedInternal( +void LlamaCppBackendImpl::embedInternal( const std::vector &texts, float *embeddings, std::string prefix, int dimensionality, - size_t *tokenCount, bool doMean, bool atlas, LLModel::EmbedCancelCallback *cancelCb, const EmbModelSpec *spec + size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb, const EmbModelSpec *spec ) { - typedef std::vector TokenString; + typedef std::vector TokenString; static constexpr int32_t atlasMaxLength = 8192; static constexpr int chunkOverlap = 8; // Atlas overlaps chunks of input by 8 tokens @@ -1217,12 +1218,12 @@ DLL_EXPORT bool is_arch_supported(const char *arch) return std::find(KNOWN_ARCHES.begin(), KNOWN_ARCHES.end(), std::string(arch)) < KNOWN_ARCHES.end(); } -DLL_EXPORT LLModel *construct() +DLL_EXPORT LlamaCppBackend *construct() { llama_log_set(llama_log_callback, nullptr); #ifdef GGML_USE_CUDA ggml_backend_cuda_log_set_callback(cuda_log_callback, nullptr); #endif - return new LLamaModel; + return new LlamaCppBackendImpl; } } diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamacpp_backend_impl.h similarity index 84% rename from gpt4all-backend/llamamodel_impl.h rename to gpt4all-backend/llamacpp_backend_impl.h index 7c698ffa366b..7ed73c579d42 100644 --- a/gpt4all-backend/llamamodel_impl.h +++ b/gpt4all-backend/llamacpp_backend_impl.h @@ -1,22 +1,22 @@ -#ifndef LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE -#error This file is NOT meant to be included outside of llamamodel.cpp. Doing so is DANGEROUS. Be sure to know what you are doing before proceeding to #define LLAMAMODEL_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#pragma once + +#ifndef LLAMACPP_BACKEND_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE +#error This file is NOT meant to be included outside of llamacpp_backend_impl.cpp. Doing so is DANGEROUS. 
Be sure to know what you are doing before proceeding to #define LLAMACPP_BACKEND_H_I_KNOW_WHAT_I_AM_DOING_WHEN_INCLUDING_THIS_FILE #endif -#ifndef LLAMAMODEL_H -#define LLAMAMODEL_H -#include "llmodel.h" +#include "llamacpp_backend.h" #include #include #include -struct LLamaPrivate; +struct LlamaPrivate; struct EmbModelSpec; -class LLamaModel : public LLModel { +class LlamaCppBackendImpl : public LlamaCppBackend { public: - LLamaModel(); - ~LLamaModel(); + LlamaCppBackendImpl(); + ~LlamaCppBackendImpl(); bool supportsEmbedding() const override { return m_supportsEmbedding; } bool supportsCompletion() const override { return m_supportsCompletion; } @@ -47,7 +47,7 @@ class LLamaModel : public LLModel { size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override; private: - std::unique_ptr d_ptr; + std::unique_ptr d_ptr; bool m_supportsEmbedding = false; bool m_supportsCompletion = false; @@ -68,5 +68,3 @@ class LLamaModel : public LLModel { size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb, const EmbModelSpec *spec); }; - -#endif // LLAMAMODEL_H diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llamacpp_backend_manager.cpp similarity index 80% rename from gpt4all-backend/llmodel.cpp rename to gpt4all-backend/llamacpp_backend_manager.cpp index 1acf0642ef2a..5368d5c45ae2 100644 --- a/gpt4all-backend/llmodel.cpp +++ b/gpt4all-backend/llamacpp_backend_manager.cpp @@ -1,19 +1,21 @@ -#include "llmodel.h" +#include "llamacpp_backend_manager.h" #include "dlhandle.h" #include +#include #include #include -#include #include #include #include #include #include #include +#include #include #include +#include #include #ifdef _WIN32 @@ -34,6 +36,7 @@ namespace fs = std::filesystem; + #ifndef __APPLE__ static const std::string DEFAULT_BACKENDS[] = {"kompute", "cpu"}; #elif defined(__aarch64__) @@ -66,7 +69,7 @@ std::string s_implementations_search_path = "."; #define cpu_supports_avx2() !!__builtin_cpu_supports("avx2") #endif -LLModel::Implementation::Implementation(Dlhandle &&dlhandle_) +LlamaCppBackendManager::LlamaCppBackendManager(Dlhandle &&dlhandle_) : m_dlhandle(new Dlhandle(std::move(dlhandle_))) { auto get_model_type = m_dlhandle->get("get_model_type"); assert(get_model_type); @@ -78,11 +81,11 @@ LLModel::Implementation::Implementation(Dlhandle &&dlhandle_) assert(m_getFileArch); m_isArchSupported = m_dlhandle->get("is_arch_supported"); assert(m_isArchSupported); - m_construct = m_dlhandle->get("construct"); + m_construct = m_dlhandle->get("construct"); assert(m_construct); } -LLModel::Implementation::Implementation(Implementation &&o) +LlamaCppBackendManager::LlamaCppBackendManager(LlamaCppBackendManager &&o) : m_getFileArch(o.m_getFileArch) , m_isArchSupported(o.m_isArchSupported) , m_construct(o.m_construct) @@ -92,7 +95,7 @@ LLModel::Implementation::Implementation(Implementation &&o) o.m_dlhandle = nullptr; } -LLModel::Implementation::~Implementation() +LlamaCppBackendManager::~LlamaCppBackendManager() { delete m_dlhandle; } @@ -117,7 +120,7 @@ static void addCudaSearchPath() #endif } -const std::vector &LLModel::Implementation::implementationList() +const std::vector &LlamaCppBackendManager::implementationList() { if (cpu_supports_avx() == 0) { throw std::runtime_error("CPU does not support AVX"); @@ -125,12 +128,12 @@ const std::vector &LLModel::Implementation::implementat // NOTE: allocated on heap so we leak intentionally on exit so we have a chance to clean up the // individual models without the cleanup of the static list interfering - 
static auto* libs = new std::vector([] () { - std::vector fres; + static auto* libs = new std::vector([] () { + std::vector fres; addCudaSearchPath(); - std::string impl_name_re = "llamamodel-mainline-(cpu|metal|kompute|vulkan|cuda)"; + std::string impl_name_re = "llamacpp-(cpu|metal|kompute|vulkan|cuda)"; if (cpu_supports_avx2() == 0) { impl_name_re += "-avxonly"; } @@ -146,7 +149,10 @@ const std::vector &LLModel::Implementation::implementat const fs::path &p = f.path(); if (p.extension() != LIB_FILE_EXT) continue; - if (!std::regex_search(p.stem().string(), re)) continue; + if (!std::regex_search(p.stem().string(), re)) { + std::cerr << "did not match regex: " << p.stem().string() << "\n"; + continue; + } // Add to list if model implementation Dlhandle dl; @@ -160,7 +166,7 @@ const std::vector &LLModel::Implementation::implementat std::cerr << "Not an implementation: " << p.filename().string() << "\n"; continue; } - fres.emplace_back(Implementation(std::move(dl))); + fres.emplace_back(LlamaCppBackendManager(std::move(dl))); } } }; @@ -181,8 +187,10 @@ static std::string applyCPUVariant(const std::string &buildVariant) return buildVariant; } -const LLModel::Implementation* LLModel::Implementation::implementation(const char *fname, const std::string& buildVariant) -{ +const LlamaCppBackendManager* LlamaCppBackendManager::implementation( + const char *fname, + const std::string& buildVariant +) { bool buildVariantMatched = false; std::optional archName; for (const auto& i : implementationList()) { @@ -206,8 +214,11 @@ const LLModel::Implementation* LLModel::Implementation::implementation(const cha throw BadArchError(std::move(*archName)); } -LLModel *LLModel::Implementation::construct(const std::string &modelPath, const std::string &backend, int n_ctx) -{ +LlamaCppBackend *LlamaCppBackendManager::construct( + const std::string &modelPath, + const std::string &backend, + int n_ctx +) { std::vector desiredBackends; if (backend != "auto") { desiredBackends.push_back(backend); @@ -221,7 +232,7 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, const if (impl) { // Construct llmodel implementation auto *fres = impl->m_construct(); - fres->m_implementation = impl; + fres->m_manager = impl; #if defined(__APPLE__) && defined(__aarch64__) // FIXME: See if metal works for intel macs /* TODO(cebtenzzre): after we fix requiredMem, we should change this to happen at @@ -247,11 +258,11 @@ LLModel *LLModel::Implementation::construct(const std::string &modelPath, const throw MissingImplementationError("Could not find any implementations for backend: " + backend); } -LLModel *LLModel::Implementation::constructGlobalLlama(const std::optional &backend) +LlamaCppBackend *LlamaCppBackendManager::constructGlobalLlama(const std::optional &backend) { - static std::unordered_map> implCache; + static std::unordered_map> implCache; - const std::vector *impls; + const std::vector *impls; try { impls = &implementationList(); } catch (const std::runtime_error &e) { @@ -266,7 +277,7 @@ LLModel *LLModel::Implementation::constructGlobalLlama(const std::optionalm_construct(); - fres->m_implementation = impl; - implCache[desiredBackend] = std::unique_ptr(fres); + fres->m_manager = impl; + implCache[desiredBackend] = std::unique_ptr(fres); return fres; } } - std::cerr << __func__ << ": could not find Llama implementation for backend: " << backend.value_or("default") << "\n"; + std::cerr << __func__ << ": could not find Llama implementation for backend: " << backend.value_or("default") + << "\n"; 
return nullptr; } -std::vector LLModel::Implementation::availableGPUDevices(size_t memoryRequired) +std::vector LlamaCppBackendManager::availableGPUDevices(size_t memoryRequired) { - std::vector devices; + std::vector devices; #ifndef __APPLE__ static const std::string backends[] = {"kompute", "cuda"}; for (const auto &backend: backends) { @@ -308,40 +320,40 @@ std::vector LLModel::Implementation::availableGPUDevices(siz return devices; } -int32_t LLModel::Implementation::maxContextLength(const std::string &modelPath) +int32_t LlamaCppBackendManager::maxContextLength(const std::string &modelPath) { auto *llama = constructGlobalLlama(); return llama ? llama->maxContextLength(modelPath) : -1; } -int32_t LLModel::Implementation::layerCount(const std::string &modelPath) +int32_t LlamaCppBackendManager::layerCount(const std::string &modelPath) { auto *llama = constructGlobalLlama(); return llama ? llama->layerCount(modelPath) : -1; } -bool LLModel::Implementation::isEmbeddingModel(const std::string &modelPath) +bool LlamaCppBackendManager::isEmbeddingModel(const std::string &modelPath) { auto *llama = constructGlobalLlama(); return llama && llama->isEmbeddingModel(modelPath); } -void LLModel::Implementation::setImplementationsSearchPath(const std::string& path) +void LlamaCppBackendManager::setImplementationsSearchPath(const std::string& path) { s_implementations_search_path = path; } -const std::string& LLModel::Implementation::implementationsSearchPath() +const std::string& LlamaCppBackendManager::implementationsSearchPath() { return s_implementations_search_path; } -bool LLModel::Implementation::hasSupportedCPU() +bool LlamaCppBackendManager::hasSupportedCPU() { return cpu_supports_avx() != 0; } -int LLModel::Implementation::cpuSupportsAVX2() +int LlamaCppBackendManager::cpuSupportsAVX2() { return cpu_supports_avx2(); } diff --git a/gpt4all-backend/llamacpp_backend_manager.h b/gpt4all-backend/llamacpp_backend_manager.h new file mode 100644 index 000000000000..81d0853f5635 --- /dev/null +++ b/gpt4all-backend/llamacpp_backend_manager.h @@ -0,0 +1,69 @@ +#pragma once + +#include "llamacpp_backend.h" + +#include +#include +#include + +class Dlhandle; + + +class LlamaCppBackendManager { +public: + class BadArchError : public std::runtime_error { + public: + BadArchError(std::string arch) + : runtime_error("Unsupported model architecture: " + arch) + , m_arch(std::move(arch)) + {} + + const std::string &arch() const noexcept { return m_arch; } + + private: + std::string m_arch; + }; + + class MissingImplementationError : public std::runtime_error { + public: + using std::runtime_error::runtime_error; + }; + + class UnsupportedModelError : public std::runtime_error { + public: + using std::runtime_error::runtime_error; + }; + + LlamaCppBackendManager(const LlamaCppBackendManager &) = delete; + LlamaCppBackendManager(LlamaCppBackendManager &&); + ~LlamaCppBackendManager(); + + std::string_view modelType() const { return m_modelType; } + std::string_view buildVariant() const { return m_buildVariant; } + + static LlamaCppBackend *construct(const std::string &modelPath, const std::string &backend = "auto", int n_ctx = 2048); + static std::vector availableGPUDevices(size_t memoryRequired = 0); + static int32_t maxContextLength(const std::string &modelPath); + static int32_t layerCount(const std::string &modelPath); + static bool isEmbeddingModel(const std::string &modelPath); + static void setImplementationsSearchPath(const std::string &path); + static const std::string &implementationsSearchPath(); + 
static bool hasSupportedCPU(); + // 0 for no, 1 for yes, -1 for non-x86_64 + static int cpuSupportsAVX2(); + +private: + LlamaCppBackendManager(Dlhandle &&); + + static const std::vector &implementationList(); + static const LlamaCppBackendManager *implementation(const char *fname, const std::string &buildVariant); + static LlamaCppBackend *constructGlobalLlama(const std::optional &backend = std::nullopt); + + char *(*m_getFileArch)(const char *fname); + bool (*m_isArchSupported)(const char *arch); + LlamaCppBackend *(*m_construct)(); + + std::string_view m_modelType; + std::string_view m_buildVariant; + Dlhandle *m_dlhandle; +}; diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h deleted file mode 100644 index 04a510dc740f..000000000000 --- a/gpt4all-backend/llmodel.h +++ /dev/null @@ -1,262 +0,0 @@ -#ifndef LLMODEL_H -#define LLMODEL_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -class Dlhandle; - -using namespace std::string_literals; - -#define LLMODEL_MAX_PROMPT_BATCH 128 - -class LLModel { -public: - using Token = int32_t; - - class BadArchError: public std::runtime_error { - public: - BadArchError(std::string arch) - : runtime_error("Unsupported model architecture: " + arch) - , m_arch(std::move(arch)) - {} - - const std::string &arch() const noexcept { return m_arch; } - - private: - std::string m_arch; - }; - - class MissingImplementationError: public std::runtime_error { - public: - using std::runtime_error::runtime_error; - }; - - class UnsupportedModelError: public std::runtime_error { - public: - using std::runtime_error::runtime_error; - }; - - struct GPUDevice { - const char *backend; - int index; - int type; - size_t heapSize; - std::string name; - std::string vendor; - - GPUDevice(const char *backend, int index, int type, size_t heapSize, std::string name, std::string vendor): - backend(backend), index(index), type(type), heapSize(heapSize), name(std::move(name)), - vendor(std::move(vendor)) {} - - std::string selectionName() const - { - assert(backend == "cuda"s || backend == "kompute"s); - return backendName() + ": " + name; - } - - std::string backendName() const { return backendIdToName(backend); } - - static std::string backendIdToName(const std::string &backend) { return s_backendNames.at(backend); } - - static std::string updateSelectionName(const std::string &name) { - if (name == "Auto" || name == "CPU" || name == "Metal") - return name; - auto it = std::find_if(s_backendNames.begin(), s_backendNames.end(), [&name](const auto &entry) { - return name.starts_with(entry.second + ": "); - }); - if (it != s_backendNames.end()) - return name; - return "Vulkan: " + name; // previously, there were only Vulkan devices - } - - private: - static inline const std::unordered_map s_backendNames { - {"cpu", "CPU"}, {"metal", "Metal"}, {"cuda", "CUDA"}, {"kompute", "Vulkan"}, - }; - }; - - class Implementation { - public: - Implementation(const Implementation &) = delete; - Implementation(Implementation &&); - ~Implementation(); - - std::string_view modelType() const { return m_modelType; } - std::string_view buildVariant() const { return m_buildVariant; } - - static LLModel *construct(const std::string &modelPath, const std::string &backend = "auto", int n_ctx = 2048); - static std::vector availableGPUDevices(size_t memoryRequired = 0); - static int32_t maxContextLength(const std::string &modelPath); - static int32_t layerCount(const std::string &modelPath); - static bool 
isEmbeddingModel(const std::string &modelPath); - static void setImplementationsSearchPath(const std::string &path); - static const std::string &implementationsSearchPath(); - static bool hasSupportedCPU(); - // 0 for no, 1 for yes, -1 for non-x86_64 - static int cpuSupportsAVX2(); - - private: - Implementation(Dlhandle &&); - - static const std::vector &implementationList(); - static const Implementation *implementation(const char *fname, const std::string &buildVariant); - static LLModel *constructGlobalLlama(const std::optional &backend = std::nullopt); - - char *(*m_getFileArch)(const char *fname); - bool (*m_isArchSupported)(const char *arch); - LLModel *(*m_construct)(); - - std::string_view m_modelType; - std::string_view m_buildVariant; - Dlhandle *m_dlhandle; - }; - - struct PromptContext { - std::vector tokens; // current tokens in the context window - int32_t n_past = 0; // number of tokens in past conversation - int32_t n_ctx = 0; // number of tokens possible in context window - int32_t n_predict = 200; - int32_t top_k = 40; - float top_p = 0.9f; - float min_p = 0.0f; - float temp = 0.9f; - int32_t n_batch = 9; - float repeat_penalty = 1.10f; - int32_t repeat_last_n = 64; // last n tokens to penalize - float contextErase = 0.5f; // percent of context to erase if we exceed the context window - }; - - using ProgressCallback = std::function; - - explicit LLModel() {} - virtual ~LLModel() {} - - virtual bool supportsEmbedding() const = 0; - virtual bool supportsCompletion() const = 0; - virtual bool loadModel(const std::string &modelPath, int n_ctx, int ngl) = 0; - virtual bool isModelBlacklisted(const std::string &modelPath) const { (void)modelPath; return false; }; - virtual bool isEmbeddingModel(const std::string &modelPath) const { (void)modelPath; return false; } - virtual bool isModelLoaded() const = 0; - virtual size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) = 0; - virtual size_t stateSize() const { return 0; } - virtual size_t saveState(uint8_t *dest) const { (void)dest; return 0; } - virtual size_t restoreState(const uint8_t *src) { (void)src; return 0; } - - // This method requires the model to return true from supportsCompletion otherwise it will throw - // an error - virtual void prompt(const std::string &prompt, - const std::string &promptTemplate, - std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &ctx, - bool special = false, - std::string *fakeReply = nullptr); - - using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend); - - virtual size_t embeddingSize() const { - throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings"); - } - // user-specified prefix - virtual void embed(const std::vector &texts, float *embeddings, std::optional prefix, - int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false, - EmbedCancelCallback *cancelCb = nullptr); - // automatic prefix - virtual void embed(const std::vector &texts, float *embeddings, bool isRetrieval, - int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false); - - virtual void setThreadCount(int32_t n_threads) { (void)n_threads; } - virtual int32_t threadCount() const { return 1; } - - const Implementation &implementation() const { - return *m_implementation; - } - - virtual std::vector availableGPUDevices(size_t memoryRequired) const { - (void)memoryRequired; - return {}; - } - - virtual bool 
initializeGPUDevice(size_t memoryRequired, const std::string &name) const { - (void)memoryRequired; - (void)name; - return false; - } - - virtual bool initializeGPUDevice(int device, std::string *unavail_reason = nullptr) const { - (void)device; - if (unavail_reason) { - *unavail_reason = "model has no GPU support"; - } - return false; - } - - virtual bool usingGPUDevice() const { return false; } - virtual const char *backendName() const { return "cpu"; } - virtual const char *gpuDeviceName() const { return nullptr; } - - void setProgressCallback(ProgressCallback callback) { m_progressCallback = callback; } - -protected: - // These are pure virtual because subclasses need to implement as the default implementation of - // 'prompt' above calls these functions - virtual std::vector tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0; - virtual bool isSpecialToken(Token id) const = 0; - virtual std::string tokenToString(Token id) const = 0; - virtual Token sampleToken(PromptContext &ctx) const = 0; - virtual bool evalTokens(PromptContext &ctx, const std::vector &tokens) const = 0; - virtual void shiftContext(PromptContext &promptCtx) = 0; - virtual int32_t contextLength() const = 0; - virtual const std::vector &endTokens() const = 0; - virtual bool shouldAddBOS() const = 0; - - virtual int32_t maxContextLength(std::string const &modelPath) const - { - (void)modelPath; - return -1; - } - - virtual int32_t layerCount(std::string const &modelPath) const - { - (void)modelPath; - return -1; - } - - const Implementation *m_implementation = nullptr; - - ProgressCallback m_progressCallback; - static bool staticProgressCallback(float progress, void* ctx) - { - LLModel* model = static_cast(ctx); - if (model && model->m_progressCallback) - return model->m_progressCallback(progress); - return true; - } - - bool decodePrompt(std::function promptCallback, - std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx, - std::vector embd_inp); - void generateResponse(std::function responseCallback, - bool allowContextShift, - PromptContext &promptCtx); - - Token m_tokenize_last_token = -1; // not serialized - - friend class LLMImplementation; -}; - -#endif // LLMODEL_H diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp index f3fd68ffa69c..18b59899e7b8 100644 --- a/gpt4all-backend/llmodel_c.cpp +++ b/gpt4all-backend/llmodel_c.cpp @@ -1,6 +1,8 @@ #include "llmodel_c.h" -#include "llmodel.h" +#include "llamacpp_backend.h" +#include "llamacpp_backend_manager.h" +#include "model_backend.h" #include #include @@ -15,8 +17,8 @@ #include struct LLModelWrapper { - LLModel *llModel = nullptr; - LLModel::PromptContext promptContext; + LlamaCppBackend *llModel = nullptr; + ModelBackend::PromptContext promptContext; ~LLModelWrapper() { delete llModel; } }; @@ -41,9 +43,9 @@ static void llmodel_set_error(const char **errptr, const char *message) llmodel_model llmodel_model_create2(const char *model_path, const char *backend, const char **error) { - LLModel *llModel; + LlamaCppBackend *llModel; try { - llModel = LLModel::Implementation::construct(model_path, backend); + llModel = LlamaCppBackendManager::construct(model_path, backend); } catch (const std::exception& e) { llmodel_set_error(error, e.what()); return nullptr; @@ -214,12 +216,12 @@ int32_t llmodel_threadCount(llmodel_model model) void llmodel_set_implementation_search_path(const char *path) { - LLModel::Implementation::setImplementationsSearchPath(path); + 
LlamaCppBackendManager::setImplementationsSearchPath(path); } const char *llmodel_get_implementation_search_path() { - return LLModel::Implementation::implementationsSearchPath().c_str(); + return LlamaCppBackendManager::implementationsSearchPath().c_str(); } // RAII wrapper around a C-style struct @@ -244,7 +246,7 @@ struct llmodel_gpu_device *llmodel_available_gpu_devices(size_t memoryRequired, { static thread_local std::unique_ptr c_devices; - auto devices = LLModel::Implementation::availableGPUDevices(memoryRequired); + auto devices = LlamaCppBackendManager::availableGPUDevices(memoryRequired); *num_devices = devices.size(); if (devices.empty()) { return nullptr; /* no devices */ } diff --git a/gpt4all-backend/llmodel_shared.h b/gpt4all-backend/llmodel_shared.h deleted file mode 100644 index 94a267bfa7c4..000000000000 --- a/gpt4all-backend/llmodel_shared.h +++ /dev/null @@ -1,49 +0,0 @@ -#pragma once - -#include - -#include -#include -#include - -struct llm_buffer { - uint8_t * addr = NULL; - size_t size = 0; - - void resize(size_t size) { - delete[] addr; - addr = new uint8_t[size]; - this->size = size; - } - - ~llm_buffer() { - delete[] addr; - } -}; - -struct llm_kv_cache { - struct ggml_tensor * k; - struct ggml_tensor * v; - - struct ggml_context * ctx = NULL; - - llm_buffer buf; - - int n; // number of tokens currently in the cache - - ~llm_kv_cache() { - if (ctx) { - ggml_free(ctx); - } - } -}; - -inline void ggml_graph_compute_g4a(llm_buffer& buf, ggml_cgraph * graph, int n_threads) -{ - struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); - if (plan.work_size > 0) { - buf.resize(plan.work_size); - plan.work_data = buf.addr; - } - ggml_graph_compute(graph, &plan); -} diff --git a/gpt4all-backend/model_backend.h b/gpt4all-backend/model_backend.h new file mode 100644 index 000000000000..467c4e8387e5 --- /dev/null +++ b/gpt4all-backend/model_backend.h @@ -0,0 +1,71 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#define LLMODEL_MAX_PROMPT_BATCH 128 + +class ModelBackend { +public: + using Token = int32_t; + + struct PromptContext { + std::vector tokens; // current tokens in the context window + int32_t n_past = 0; // number of tokens in past conversation + int32_t n_ctx = 0; // number of tokens possible in context window + int32_t n_predict = 200; + int32_t top_k = 40; + float top_p = 0.9f; + float min_p = 0.0f; + float temp = 0.9f; + int32_t n_batch = 9; + float repeat_penalty = 1.10f; + int32_t repeat_last_n = 64; // last n tokens to penalize + float contextErase = 0.5f; // percent of context to erase if we exceed the context window + }; + + virtual ~ModelBackend() {} + + virtual bool supportsCompletion() const { return true; } + virtual bool loadModel(const std::string &modelPath, int n_ctx, int ngl) = 0; + virtual bool isModelLoaded() const = 0; + virtual size_t stateSize() const { return 0; } + virtual size_t saveState(uint8_t *dest) const { (void)dest; return 0; } + virtual size_t restoreState(const uint8_t *src) { (void)src; return 0; } + + // This method requires the model to return true from supportsCompletion otherwise it will throw + // an error + virtual void prompt(const std::string &prompt, + const std::string &promptTemplate, + std::function promptCallback, + std::function responseCallback, + bool allowContextShift, + PromptContext &ctx, + bool special = false, + std::string *fakeReply = nullptr) = 0; + +protected: + explicit ModelBackend() {} +}; + +using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned 
nBatch, const char *backend); + +class EmbCapableBackend : virtual public ModelBackend { +public: + virtual bool supportsCompletion() const = 0; + virtual bool supportsEmbedding() const = 0; + virtual size_t embeddingSize() const = 0; + + // user-specified prefix + virtual void embed(const std::vector &texts, float *embeddings, std::optional prefix, + int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false, + EmbedCancelCallback *cancelCb = nullptr) = 0; + // automatic prefix + virtual void embed(const std::vector &texts, float *embeddings, bool isRetrieval, + int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) = 0; +}; diff --git a/gpt4all-bindings/python/setup.py b/gpt4all-bindings/python/setup.py index e92fba618b65..ed6f407191fd 100644 --- a/gpt4all-bindings/python/setup.py +++ b/gpt4all-bindings/python/setup.py @@ -55,7 +55,7 @@ def copy_prebuilt_C_lib(src_dir, dest_dir, dest_build_dir): # NOTE: You must provide correct path to the prebuilt llmodel C library. -# Specifically, the llmodel.h and C shared library are needed. +# Specifically, the model_backend.h and C shared library are needed. copy_prebuilt_C_lib(SRC_CLIB_DIRECTORY, DEST_CLIB_DIRECTORY, DEST_CLIB_BUILD_DIRECTORY) diff --git a/gpt4all-bindings/typescript/index.h b/gpt4all-bindings/typescript/index.h index db3ef11e6764..7726e8cf9f64 100644 --- a/gpt4all-bindings/typescript/index.h +++ b/gpt4all-bindings/typescript/index.h @@ -1,4 +1,4 @@ -#include "llmodel.h" +#include "model_backend.h" #include "llmodel_c.h" #include "prompt.h" #include diff --git a/gpt4all-bindings/typescript/prompt.h b/gpt4all-bindings/typescript/prompt.h index 49c43620368c..e1d0a5507c19 100644 --- a/gpt4all-bindings/typescript/prompt.h +++ b/gpt4all-bindings/typescript/prompt.h @@ -1,7 +1,7 @@ #ifndef PREDICT_WORKER_H #define PREDICT_WORKER_H -#include "llmodel.h" +#include "model_backend.h" #include "llmodel_c.h" #include "napi.h" #include diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 07acef15fa2b..325d21d1765b 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -109,7 +109,8 @@ endif() qt_add_executable(chat main.cpp chat.h chat.cpp - chatllm.h chatllm.cpp + llmodel.h llmodel.cpp + llamacpp_model.h llamacpp_model.cpp chatmodel.h chatlistmodel.h chatlistmodel.cpp chatapi.h chatapi.cpp chatviewtextprocessor.h chatviewtextprocessor.cpp @@ -326,18 +327,18 @@ install( # to the this component's dir for the finicky qt installer to work if (LLMODEL_KOMPUTE) set(MODEL_IMPL_TARGETS - llamamodel-mainline-kompute - llamamodel-mainline-kompute-avxonly + llamacpp-kompute + llamacpp-kompute-avxonly ) else() set(MODEL_IMPL_TARGETS - llamamodel-mainline-cpu - llamamodel-mainline-cpu-avxonly + llamacpp-cpu + llamacpp-cpu-avxonly ) endif() if (APPLE) - list(APPEND MODEL_IMPL_TARGETS llamamodel-mainline-metal) + list(APPEND MODEL_IMPL_TARGETS llamacpp-metal) endif() install( @@ -365,12 +366,12 @@ if(WIN32 AND GPT4ALL_SIGN_INSTALL) endif() if (LLMODEL_CUDA) - set_property(TARGET llamamodel-mainline-cuda llamamodel-mainline-cuda-avxonly + set_property(TARGET llamacpp-cuda llamacpp-cuda-avxonly APPEND PROPERTY INSTALL_RPATH "$ORIGIN") install( - TARGETS llamamodel-mainline-cuda - llamamodel-mainline-cuda-avxonly + TARGETS llamacpp-cuda + llamacpp-cuda-avxonly RUNTIME_DEPENDENCY_SET llama-cuda-deps LIBRARY DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN} # .so/.dylib RUNTIME DESTINATION lib COMPONENT ${COMPONENT_NAME_MAIN} # .dll diff 
--git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index a44022c0bc24..4ef701db05ee 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -1,6 +1,7 @@ #include "chat.h" #include "chatlistmodel.h" +#include "llamacpp_model.h" #include "mysettings.h" #include "network.h" #include "server.h" @@ -26,7 +27,7 @@ Chat::Chat(QObject *parent) , m_chatModel(new ChatModel(this)) , m_responseState(Chat::ResponseStopped) , m_creationDate(QDateTime::currentSecsSinceEpoch()) - , m_llmodel(new ChatLLM(this)) + , m_llmodel(new LlamaCppModel(this)) , m_collectionModel(new LocalDocsCollectionsModel(this)) { connectLLM(); @@ -55,31 +56,30 @@ Chat::~Chat() void Chat::connectLLM() { // Should be in different threads - connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection); - - connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); - connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection); - connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); - connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection); - connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); - connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection); - connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::QueuedConnection); - connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection); - connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection); + 
connect(m_llmodel, &LLModel::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); + connect(m_llmodel, &LLModel::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection); + + connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection); + connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection); + connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection); + connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection); + connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection); + connect(this, &Chat::resetResponseRequested, m_llmodel, &LLModel::resetResponse, Qt::QueuedConnection); + connect(this, &Chat::resetContextRequested, m_llmodel, &LLModel::resetContext, Qt::QueuedConnection); + connect(this, &Chat::processSystemPromptRequested, m_llmodel, &LLModel::processSystemPrompt, Qt::QueuedConnection); connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections); } @@ -276,25 +276,23 @@ void Chat::markForDeletion() void Chat::unloadModel() { stopGenerating(); - m_llmodel->setShouldBeLoaded(false); + m_llmodel->releaseModelAsync(); } void Chat::reloadModel() { - m_llmodel->setShouldBeLoaded(true); + m_llmodel->loadModelAsync(); } void Chat::forceUnloadModel() { stopGenerating(); - m_llmodel->setForceUnloadModel(true); - m_llmodel->setShouldBeLoaded(false); + m_llmodel->releaseModelAsync(/*unload*/ true); } void Chat::forceReloadModel() { - m_llmodel->setForceUnloadModel(true); - m_llmodel->setShouldBeLoaded(true); + m_llmodel->loadModelAsync(/*reload*/ true); } void Chat::trySwitchContextOfLoadedModel() @@ -344,17 +342,20 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed) QString Chat::deviceBackend() const { - return m_llmodel->deviceBackend(); + auto *llamacppmodel = dynamic_cast(m_llmodel); + return llamacppmodel ? 
llamacppmodel->deviceBackend() : QString(); } QString Chat::device() const { - return m_llmodel->device(); + auto *llamacppmodel = dynamic_cast(m_llmodel); + return llamacppmodel ? llamacppmodel->device() : QString(); } QString Chat::fallbackReason() const { - return m_llmodel->fallbackReason(); + auto *llamacppmodel = dynamic_cast(m_llmodel); + return llamacppmodel ? llamacppmodel->fallbackReason() : QString(); } void Chat::handleDatabaseResultsChanged(const QList &results) diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index 065c624eef31..2b4187bf82ce 100644 --- a/gpt4all-chat/chat.h +++ b/gpt4all-chat/chat.h @@ -1,9 +1,9 @@ #ifndef CHAT_H #define CHAT_H -#include "chatllm.h" #include "chatmodel.h" #include "database.h" // IWYU pragma: keep +#include "llmodel.h" #include "localdocsmodel.h" // IWYU pragma: keep #include "modellist.h" @@ -94,7 +94,7 @@ class Chat : public QObject Q_INVOKABLE void reloadModel(); Q_INVOKABLE void forceUnloadModel(); Q_INVOKABLE void forceReloadModel(); - Q_INVOKABLE void trySwitchContextOfLoadedModel(); + void trySwitchContextOfLoadedModel(); void unloadAndDeleteLater(); void markForDeletion(); @@ -145,7 +145,6 @@ public Q_SLOTS: void modelChangeRequested(const ModelInfo &modelInfo); void modelInfoChanged(); void restoringFromTextChanged(); - void loadDefaultModelRequested(); void loadModelRequested(const ModelInfo &modelInfo); void generateNameRequested(); void modelLoadingErrorChanged(); @@ -161,7 +160,7 @@ public Q_SLOTS: private Q_SLOTS: void handleResponseChanged(const QString &response); - void handleModelLoadingPercentageChanged(float); + void handleModelLoadingPercentageChanged(float loadingPercentage); void promptProcessing(); void generatingQuestions(); void responseStopped(qint64 promptResponseMs); @@ -191,7 +190,7 @@ private Q_SLOTS: bool m_responseInProgress = false; ResponseState m_responseState; qint64 m_creationDate; - ChatLLM *m_llmodel; + LLModel *m_llmodel; QList m_databaseResults; bool m_isServer = false; bool m_shouldDeleteLater = false; diff --git a/gpt4all-chat/chatapi.cpp b/gpt4all-chat/chatapi.cpp index b443f24c3ab7..740f1a141dfc 100644 --- a/gpt4all-chat/chatapi.cpp +++ b/gpt4all-chat/chatapi.cpp @@ -1,6 +1,6 @@ #include "chatapi.h" -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/model_backend.h" #include #include @@ -32,14 +32,6 @@ ChatAPI::ChatAPI() { } -size_t ChatAPI::requiredMem(const std::string &modelPath, int n_ctx, int ngl) -{ - Q_UNUSED(modelPath); - Q_UNUSED(n_ctx); - Q_UNUSED(ngl); - return 0; -} - bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl) { Q_UNUSED(modelPath); @@ -48,27 +40,14 @@ bool ChatAPI::loadModel(const std::string &modelPath, int n_ctx, int ngl) return true; } -void ChatAPI::setThreadCount(int32_t n_threads) -{ - Q_UNUSED(n_threads); - qt_noop(); -} - -int32_t ChatAPI::threadCount() const -{ - return 1; -} - -ChatAPI::~ChatAPI() -{ -} +ChatAPI::~ChatAPI() {} bool ChatAPI::isModelLoaded() const { return true; } -// All three of the state virtual functions are handled custom inside of chatllm save/restore +// All three of the state virtual functions are handled custom inside of LlamaCppModel save/restore size_t ChatAPI::stateSize() const { return 0; @@ -191,7 +170,7 @@ bool ChatAPI::callResponse(int32_t token, const std::string& string) } void ChatAPIWorker::request(const QString &apiKey, - LLModel::PromptContext *promptCtx, + ModelBackend::PromptContext *promptCtx, const QByteArray &array) { m_ctx = promptCtx; diff --git a/gpt4all-chat/chatapi.h 
b/gpt4all-chat/chatapi.h index 59b68f582108..45d50fe1fbbb 100644 --- a/gpt4all-chat/chatapi.h +++ b/gpt4all-chat/chatapi.h @@ -1,7 +1,7 @@ #ifndef CHATAPI_H #define CHATAPI_H -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/model_backend.h" #include #include @@ -33,7 +33,7 @@ class ChatAPIWorker : public QObject { QString currentResponse() const { return m_currentResponse; } void request(const QString &apiKey, - LLModel::PromptContext *promptCtx, + ModelBackend::PromptContext *promptCtx, const QByteArray &array); Q_SIGNALS: @@ -46,22 +46,19 @@ private Q_SLOTS: private: ChatAPI *m_chat; - LLModel::PromptContext *m_ctx; + ModelBackend::PromptContext *m_ctx; QNetworkAccessManager *m_networkManager; QString m_currentResponse; }; -class ChatAPI : public QObject, public LLModel { +class ChatAPI : public QObject, public ModelBackend { Q_OBJECT public: ChatAPI(); virtual ~ChatAPI(); - bool supportsEmbedding() const override { return false; } - bool supportsCompletion() const override { return true; } bool loadModel(const std::string &modelPath, int n_ctx, int ngl) override; bool isModelLoaded() const override; - size_t requiredMem(const std::string &modelPath, int n_ctx, int ngl) override; size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; @@ -74,9 +71,6 @@ class ChatAPI : public QObject, public LLModel { bool special, std::string *fakeReply) override; - void setThreadCount(int32_t n_threads) override; - int32_t threadCount() const override; - void setModelName(const QString &modelName) { m_modelName = modelName; } void setAPIKey(const QString &apiKey) { m_apiKey = apiKey; } void setRequestURL(const QString &requestURL) { m_requestURL = requestURL; } @@ -89,68 +83,9 @@ class ChatAPI : public QObject, public LLModel { Q_SIGNALS: void request(const QString &apiKey, - LLModel::PromptContext *ctx, + ModelBackend::PromptContext *ctx, const QByteArray &array); -protected: - // We have to implement these as they are pure virtual in base class, but we don't actually use - // them as they are only called from the default implementation of 'prompt' which we override and - // completely replace - - std::vector tokenize(PromptContext &ctx, const std::string &str, bool special) override - { - (void)ctx; - (void)str; - (void)special; - throw std::logic_error("not implemented"); - } - - bool isSpecialToken(Token id) const override - { - (void)id; - throw std::logic_error("not implemented"); - } - - std::string tokenToString(Token id) const override - { - (void)id; - throw std::logic_error("not implemented"); - } - - Token sampleToken(PromptContext &ctx) const override - { - (void)ctx; - throw std::logic_error("not implemented"); - } - - bool evalTokens(PromptContext &ctx, const std::vector &tokens) const override - { - (void)ctx; - (void)tokens; - throw std::logic_error("not implemented"); - } - - void shiftContext(PromptContext &promptCtx) override - { - (void)promptCtx; - throw std::logic_error("not implemented"); - } - - int32_t contextLength() const override - { - throw std::logic_error("not implemented"); - } - - const std::vector &endTokens() const override - { - throw std::logic_error("not implemented"); - } - - bool shouldAddBOS() const override - { - throw std::logic_error("not implemented"); - } - private: std::function m_responseCallback; QString m_modelName; diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h index 95f8e6f2a523..cc6c079d4a10 100644 --- 
a/gpt4all-chat/chatlistmodel.h +++ b/gpt4all-chat/chatlistmodel.h @@ -2,7 +2,7 @@ #define CHATLISTMODEL_H #include "chat.h" -#include "chatllm.h" +#include "llamacpp_model.h" #include "chatmodel.h" #include @@ -220,11 +220,11 @@ class ChatListModel : public QAbstractListModel int count() const { return m_chats.size(); } - // stop ChatLLM threads for clean shutdown + // stop LlamaCppModel threads for clean shutdown void destroyChats() { for (auto *chat: m_chats) { chat->destroy(); } - ChatLLM::destroyStore(); + LlamaCppModel::destroyStore(); } void removeChatFile(Chat *chat) const; diff --git a/gpt4all-chat/download.cpp b/gpt4all-chat/download.cpp index 47981d0b73c7..843fefd94a20 100644 --- a/gpt4all-chat/download.cpp +++ b/gpt4all-chat/download.cpp @@ -263,16 +263,17 @@ void Download::installModel(const QString &modelFile, const QString &apiKey) QFile file(filePath); if (file.open(QIODeviceBase::WriteOnly | QIODeviceBase::Text)) { - QJsonObject obj; QString modelName(modelFile); modelName.remove(0, 8); // strip "gpt4all-" prefix modelName.chop(7); // strip ".rmodel" extension - obj.insert("apiKey", apiKey); - obj.insert("modelName", modelName); - QJsonDocument doc(obj); + QJsonObject obj { + { "type", ... }, + { "apiKey", apiKey }, + { "modelName", modelName }, + }; QTextStream stream(&file); - stream << doc.toJson(); + stream << QJsonDocument(obj).toJson(); file.close(); ModelList::globalInstance()->updateModelsFromDirectory(); emit toastMessage(tr("Model \"%1\" is installed successfully.").arg(modelName)); @@ -312,14 +313,15 @@ void Download::installCompatibleModel(const QString &modelName, const QString &a QString filePath = MySettings::globalInstance()->modelPath() + modelFile; QFile file(filePath); if (file.open(QIODeviceBase::WriteOnly | QIODeviceBase::Text)) { - QJsonObject obj; - obj.insert("apiKey", apiKey); - obj.insert("modelName", modelName); - obj.insert("baseUrl", apiBaseUrl.toString()); - QJsonDocument doc(obj); + QJsonObject obj { + { "type", "openai-generic" }, + { "apiKey", apiKey }, + { "modelName", modelName }, + { "baseUrl", apiBaseUrl.toString() }, + }; QTextStream stream(&file); - stream << doc.toJson(); + stream << QJsonDocument(obj).toJson(); file.close(); ModelList::globalInstance()->updateModelsFromDirectory(); emit toastMessage(tr("Model \"%1 (%2)\" is installed successfully.").arg(modelName, baseUrl)); @@ -336,20 +338,26 @@ void Download::removeModel(const QString &modelFile) incompleteFile.remove(); } - bool shouldRemoveInstalled = false; + bool removedFromList = false; QFile file(filePath); if (file.exists()) { const ModelInfo info = ModelList::globalInstance()->modelInfoByFilename(modelFile); MySettings::globalInstance()->eraseModel(info); - shouldRemoveInstalled = info.installed && !info.isClone() && (info.isDiscovered() || info.isCompatibleApi || info.description() == "" /*indicates sideloaded*/); - if (shouldRemoveInstalled) + if ( + info.installed && !info.isClone() && ( + info.isDiscovered() || info.description() == "" /*indicates sideloaded*/ + || info.provider == ModelInfo::Provider::OpenAIGeneric + ) + ) { ModelList::globalInstance()->removeInstalled(info); + removedFromList = true; + } Network::globalInstance()->trackEvent("remove_model", { {"model", modelFile} }); file.remove(); emit toastMessage(tr("Model \"%1\" is removed.").arg(info.name())); } - if (!shouldRemoveInstalled) { + if (!removedFromList) { QVector<QPair<int, QVariant>> data { { ModelList::InstalledRole, false }, { ModelList::BytesReceivedRole, 0 }, diff --git a/gpt4all-chat/embllm.cpp
b/gpt4all-chat/embllm.cpp index 615a6ce4d925..57f5f3a87b73 100644 --- a/gpt4all-chat/embllm.cpp +++ b/gpt4all-chat/embllm.cpp @@ -3,7 +3,8 @@ #include "modellist.h" #include "mysettings.h" -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/llamacpp_backend.h" +#include "../gpt4all-backend/llamacpp_backend_manager.h" #include #include @@ -99,7 +100,7 @@ bool EmbeddingLLMWorker::loadModel() #endif try { - m_model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx); + m_model = LlamaCppBackendManager::construct(filePath.toStdString(), backend, n_ctx); } catch (const std::exception &e) { qWarning() << "embllm WARNING: Could not load embedding model:" << e.what(); return false; @@ -108,15 +109,15 @@ bool EmbeddingLLMWorker::loadModel() bool actualDeviceIsCPU = true; #if defined(Q_OS_MAC) && defined(__aarch64__) - if (m_model->implementation().buildVariant() == "metal") + if (m_model->manager().buildVariant() == "metal") actualDeviceIsCPU = false; #else if (requestedDevice != "CPU") { - const LLModel::GPUDevice *device = nullptr; - std::vector availableDevices = m_model->availableGPUDevices(0); + const LlamaCppBackend::GPUDevice *device = nullptr; + auto availableDevices = m_model->availableGPUDevices(0); if (requestedDevice != "Auto") { // Use the selected device - for (const LLModel::GPUDevice &d : availableDevices) { + for (const auto &d : availableDevices) { if (QString::fromStdString(d.selectionName()) == requestedDevice) { device = &d; break; @@ -145,7 +146,7 @@ bool EmbeddingLLMWorker::loadModel() if (backend == "cuda") { // For CUDA, make sure we don't use the GPU at all - ngl=0 still offloads matmuls try { - m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto", n_ctx); + m_model = LlamaCppBackendManager::construct(filePath.toStdString(), "auto", n_ctx); } catch (const std::exception &e) { qWarning() << "embllm WARNING: Could not load embedding model:" << e.what(); return false; @@ -192,7 +193,7 @@ std::vector EmbeddingLLMWorker::generateQueryEmbedding(const QString &tex try { m_model->embed({text.toStdString()}, embedding.data(), /*isRetrieval*/ true); } catch (const std::exception &e) { - qWarning() << "WARNING: LLModel::embed failed:" << e.what(); + qWarning() << "WARNING: LlamaCppBackend::embed failed:" << e.what(); return {}; } @@ -286,7 +287,7 @@ void EmbeddingLLMWorker::docEmbeddingsRequested(const QVector &c try { m_model->embed(batchTexts, result.data() + j * m_model->embeddingSize(), /*isRetrieval*/ false); } catch (const std::exception &e) { - qWarning() << "WARNING: LLModel::embed failed:" << e.what(); + qWarning() << "WARNING: LlamaCppBackend::embed failed:" << e.what(); return; } } diff --git a/gpt4all-chat/embllm.h b/gpt4all-chat/embllm.h index 91376650d05e..fda773e54537 100644 --- a/gpt4all-chat/embllm.h +++ b/gpt4all-chat/embllm.h @@ -13,7 +13,7 @@ #include #include -class LLModel; +class LlamaCppBackend; class QNetworkAccessManager; struct EmbeddingChunk { @@ -67,7 +67,7 @@ private Q_SLOTS: QString m_nomicAPIKey; QNetworkAccessManager *m_networkManager; std::vector m_lastResponse; - LLModel *m_model = nullptr; + LlamaCppBackend *m_model = nullptr; std::atomic m_stopGenerating; QThread m_workerThread; QMutex m_mutex; // guards m_model and m_nomicAPIKey diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/llamacpp_model.cpp similarity index 83% rename from gpt4all-chat/chatllm.cpp rename to gpt4all-chat/llamacpp_model.cpp index e9fb7f3132f9..9112d563216c 100644 --- a/gpt4all-chat/chatllm.cpp +++ 
b/gpt4all-chat/llamacpp_model.cpp @@ -1,4 +1,4 @@ -#include "chatllm.h" +#include "llamacpp_model.h" #include "chat.h" #include "chatapi.h" @@ -6,6 +6,8 @@ #include "mysettings.h" #include "network.h" +#include "../gpt4all-backend/llamacpp_backend_manager.h" + #include #include #include @@ -92,19 +94,18 @@ void LLModelStore::destroy() m_availableModel.reset(); } -void LLModelInfo::resetModel(ChatLLM *cllm, LLModel *model) { +void LLModelInfo::resetModel(LlamaCppModel *cllm, ModelBackend *model) +{ this->model.reset(model); fallbackReason.reset(); emit cllm->loadedModelInfoChanged(); } -ChatLLM::ChatLLM(Chat *parent, bool isServer) - : QObject{nullptr} - , m_promptResponseTokens(0) +LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer) + : m_promptResponseTokens(0) , m_promptTokens(0) , m_restoringFromText(false) , m_shouldBeLoaded(false) - , m_forceUnloadModel(false) , m_markedForDeletion(false) , m_stopGenerating(false) , m_timer(nullptr) @@ -115,29 +116,31 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) , m_restoreStateFromText(false) { moveToThread(&m_llmThread); - connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, + connect( + this, &LlamaCppModel::requestLoadModel, this, &LlamaCppModel::loadModel + ); + connect(this, &LlamaCppModel::requestReleaseModel, this, &LlamaCppModel::releaseModel); + connect(this, &LlamaCppModel::trySwitchContextRequested, this, &LlamaCppModel::trySwitchContextOfLoadedModel, Qt::QueuedConnection); // explicitly queued - connect(this, &ChatLLM::trySwitchContextRequested, this, &ChatLLM::trySwitchContextOfLoadedModel, - Qt::QueuedConnection); // explicitly queued - connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); - connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); - connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged); - connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged); + connect(parent, &Chat::idChanged, this, &LlamaCppModel::handleChatIdChanged); + connect(&m_llmThread, &QThread::started, this, &LlamaCppModel::handleThreadStarted); + connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &LlamaCppModel::handleForceMetalChanged); + connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &LlamaCppModel::handleDeviceChanged); // The following are blocking operations and will block the llm thread - connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, + connect(this, &LlamaCppModel::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, Qt::BlockingQueuedConnection); m_llmThread.setObjectName(parent->id()); m_llmThread.start(); } -ChatLLM::~ChatLLM() +LlamaCppModel::~LlamaCppModel() { destroy(); } -void ChatLLM::destroy() +void LlamaCppModel::destroy() { m_stopGenerating = true; m_llmThread.quit(); @@ -150,52 +153,40 @@ void ChatLLM::destroy() } } -void ChatLLM::destroyStore() +void LlamaCppModel::destroyStore() { LLModelStore::globalInstance()->destroy(); } -void ChatLLM::handleThreadStarted() +void LlamaCppModel::handleThreadStarted() { m_timer = new TokenTimer(this); - connect(m_timer, &TokenTimer::report, this, &ChatLLM::reportSpeed); + connect(m_timer, &TokenTimer::report, this, &LlamaCppModel::reportSpeed); emit threadStarted(); } -void ChatLLM::handleForceMetalChanged(bool forceMetal) +void 
LlamaCppModel::handleForceMetalChanged(bool forceMetal) { #if defined(Q_OS_MAC) && defined(__aarch64__) m_forceMetal = forceMetal; if (isModelLoaded() && m_shouldBeLoaded) { m_reloadingToChangeVariant = true; - unloadModel(); - reloadModel(); + loadModel(/*reload*/ true); m_reloadingToChangeVariant = false; } #endif } -void ChatLLM::handleDeviceChanged() +void LlamaCppModel::handleDeviceChanged() { if (isModelLoaded() && m_shouldBeLoaded) { m_reloadingToChangeVariant = true; - unloadModel(); - reloadModel(); + loadModel(/*reload*/ true); m_reloadingToChangeVariant = false; } } -bool ChatLLM::loadDefaultModel() -{ - ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo(); - if (defaultModel.filename().isEmpty()) { - emit modelLoadingError(u"Could not find any model to load"_qs); - return false; - } - return loadModel(defaultModel); -} - -void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) +void LlamaCppModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) { // We're trying to see if the store already has the model fully loaded that we wish to use // and if so we just acquire it from the store and switch the context and return true. If the @@ -239,7 +230,7 @@ void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) processSystemPrompt(); } -bool ChatLLM::loadModel(const ModelInfo &modelInfo) +bool LlamaCppModel::loadModel(const ModelInfo &modelInfo) { // This is a complicated method because N different possible threads are interested in the outcome // of this method. Why? Because we have a main/gui thread trying to monitor the state of N different @@ -386,7 +377,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) /* Returns false if the model should no longer be loaded (!m_shouldBeLoaded). * Otherwise returns true, even on error. 
*/ -bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps) +bool LlamaCppModel::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps) { QElapsedTimer modelLoadTimer; modelLoadTimer.start(); @@ -412,19 +403,20 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro QString filePath = modelInfo.dirpath + modelInfo.filename(); - auto construct = [this, &filePath, &modelInfo, &modelLoadProps, n_ctx](std::string const &backend) { + auto construct = [this, &filePath, &modelInfo, &modelLoadProps, n_ctx](std::string const &backend) -> LlamaCppBackend * { + LlamaCppBackend *lcppmodel; QString constructError; m_llModelInfo.resetModel(this); try { - auto *model = LLModel::Implementation::construct(filePath.toStdString(), backend, n_ctx); - m_llModelInfo.resetModel(this, model); - } catch (const LLModel::MissingImplementationError &e) { + lcppmodel = LlamaCppBackendManager::construct(filePath.toStdString(), backend, n_ctx); + m_llModelInfo.resetModel(this, lcppmodel); + } catch (const LlamaCppBackendManager::MissingImplementationError &e) { modelLoadProps.insert("error", "missing_model_impl"); constructError = e.what(); - } catch (const LLModel::UnsupportedModelError &e) { + } catch (const LlamaCppBackendManager::UnsupportedModelError &e) { modelLoadProps.insert("error", "unsupported_model_file"); constructError = e.what(); - } catch (const LLModel::BadArchError &e) { + } catch (const LlamaCppBackendManager::BadArchError &e) { constructError = e.what(); modelLoadProps.insert("error", "unsupported_model_arch"); modelLoadProps.insert("model_arch", QString::fromStdString(e.arch())); @@ -435,21 +427,22 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); resetModel(); emit modelLoadingError(u"Error loading %1: %2"_s.arg(modelInfo.filename(), constructError)); - return false; + return nullptr; } - m_llModelInfo.model->setProgressCallback([this](float progress) -> bool { + lcppmodel->setProgressCallback([this](float progress) -> bool { progress = std::max(progress, std::numeric_limits::min()); // keep progress above zero emit modelLoadingPercentageChanged(progress); return m_shouldBeLoaded; }); - return true; + return lcppmodel; }; - if (!construct(backend)) + auto *lcppmodel = construct(backend); + if (!lcppmodel) return true; - if (m_llModelInfo.model->isModelBlacklisted(filePath.toStdString())) { + if (lcppmodel->isModelBlacklisted(filePath.toStdString())) { static QSet warned; auto fname = modelInfo.filename(); if (!warned.contains(fname)) { @@ -460,16 +453,16 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro } } - auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) { + auto approxDeviceMemGB = [](const LlamaCppBackend::GPUDevice *dev) { float memGB = dev->heapSize / float(1024 * 1024 * 1024); return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place }; - std::vector availableDevices; - const LLModel::GPUDevice *defaultDevice = nullptr; + std::vector availableDevices; + const LlamaCppBackend::GPUDevice *defaultDevice = nullptr; { - const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString(), n_ctx, ngl); - availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory); + const size_t requiredMemory = lcppmodel->requiredMem(filePath.toStdString(), n_ctx, ngl); + availableDevices = lcppmodel->availableGPUDevices(requiredMemory); // Pick the 
best device // NB: relies on the fact that Kompute devices are listed first if (!availableDevices.empty() && availableDevices.front().type == 2 /*a discrete gpu*/) { @@ -485,14 +478,14 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro bool actualDeviceIsCPU = true; #if defined(Q_OS_MAC) && defined(__aarch64__) - if (m_llModelInfo.model->implementation().buildVariant() == "metal") + if (lcppmodel->manager().buildVariant() == "metal") actualDeviceIsCPU = false; #else if (requestedDevice != "CPU") { const auto *device = defaultDevice; if (requestedDevice != "Auto") { // Use the selected device - for (const LLModel::GPUDevice &d : availableDevices) { + for (const auto &d : availableDevices) { if (QString::fromStdString(d.selectionName()) == requestedDevice) { device = &d; break; @@ -503,7 +496,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro std::string unavail_reason; if (!device) { // GPU not available - } else if (!m_llModelInfo.model->initializeGPUDevice(device->index, &unavail_reason)) { + } else if (!lcppmodel->initializeGPUDevice(device->index, &unavail_reason)) { m_llModelInfo.fallbackReason = QString::fromStdString(unavail_reason); } else { actualDeviceIsCPU = false; @@ -512,7 +505,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro } #endif - bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl); + bool success = lcppmodel->loadModel(filePath.toStdString(), n_ctx, ngl); if (!m_shouldBeLoaded) { m_llModelInfo.resetModel(this); @@ -531,10 +524,13 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed"); // For CUDA, make sure we don't use the GPU at all - ngl=0 still offloads matmuls - if (backend == "cuda" && !construct("auto")) - return true; + if (backend == "cuda") { + lcppmodel = construct("auto"); + if (!lcppmodel) + return true; + } - success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0); + success = lcppmodel->loadModel(filePath.toStdString(), n_ctx, 0); if (!m_shouldBeLoaded) { m_llModelInfo.resetModel(this); @@ -544,7 +540,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro emit modelLoadingPercentageChanged(0.0f); return false; } - } else if (!m_llModelInfo.model->usingGPUDevice()) { + } else if (!lcppmodel->usingGPUDevice()) { // ggml_vk_init was not called in llama.cpp // We might have had to fallback to CPU after load if the model is not possible to accelerate // for instance if the quantization method is not supported on Vulkan yet @@ -562,7 +558,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro return true; } - switch (m_llModelInfo.model->implementation().modelType()[0]) { + switch (lcppmodel->manager().modelType()[0]) { case 'L': m_llModelType = LLModelType::LLAMA_; break; default: { @@ -576,43 +572,15 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro modelLoadProps.insert("$duration", modelLoadTimer.elapsed() / 1000.); return true; -}; - -bool ChatLLM::isModelLoaded() const -{ - return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded(); -} - -std::string remove_leading_whitespace(const std::string& input) -{ - auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { - return !std::isspace(c); - }); - - if (first_non_whitespace == input.end()) - return std::string(); - - return 
std::string(first_non_whitespace, input.end()); } -std::string trim_whitespace(const std::string& input) +bool LlamaCppModel::isModelLoaded() const { - auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { - return !std::isspace(c); - }); - - if (first_non_whitespace == input.end()) - return std::string(); - - auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) { - return !std::isspace(c); - }).base(); - - return std::string(first_non_whitespace, last_non_whitespace); + return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded(); } // FIXME(jared): we don't actually have to re-decode the prompt to generate a new response -void ChatLLM::regenerateResponse() +void LlamaCppModel::regenerateResponse() { // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning // of n_past is of the number of prompt/response pairs, rather than for total tokens. @@ -628,7 +596,7 @@ void ChatLLM::regenerateResponse() emit responseChanged(QString::fromStdString(m_response)); } -void ChatLLM::resetResponse() +void LlamaCppModel::resetResponse() { m_promptTokens = 0; m_promptResponseTokens = 0; @@ -636,46 +604,43 @@ void ChatLLM::resetResponse() emit responseChanged(QString::fromStdString(m_response)); } -void ChatLLM::resetContext() +void LlamaCppModel::resetContext() { resetResponse(); m_processedSystemPrompt = false; - m_ctx = LLModel::PromptContext(); + m_ctx = ModelBackend::PromptContext(); } -QString ChatLLM::response() const +QString LlamaCppModel::response() const { return QString::fromStdString(remove_leading_whitespace(m_response)); } -ModelInfo ChatLLM::modelInfo() const -{ - return m_modelInfo; -} - -void ChatLLM::setModelInfo(const ModelInfo &modelInfo) +void LlamaCppModel::setModelInfo(const ModelInfo &modelInfo) { m_modelInfo = modelInfo; emit modelInfoChanged(modelInfo); } -void ChatLLM::acquireModel() { +void LlamaCppModel::acquireModel() +{ m_llModelInfo = LLModelStore::globalInstance()->acquireModel(); emit loadedModelInfoChanged(); } -void ChatLLM::resetModel() { +void LlamaCppModel::resetModel() +{ m_llModelInfo = {}; emit loadedModelInfoChanged(); } -void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo) +void LlamaCppModel::modelChangeRequested(const ModelInfo &modelInfo) { m_shouldBeLoaded = true; loadModel(modelInfo); } -bool ChatLLM::handlePrompt(int32_t token) +bool LlamaCppModel::handlePrompt(int32_t token) { // m_promptResponseTokens is related to last prompt/response not // the entire context window which we can reset on regenerate prompt @@ -688,7 +653,7 @@ bool ChatLLM::handlePrompt(int32_t token) return !m_stopGenerating; } -bool ChatLLM::handleResponse(int32_t token, const std::string &response) +bool LlamaCppModel::handleResponse(int32_t token, const std::string &response) { #if defined(DEBUG) printf("%s", response.c_str()); @@ -712,7 +677,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response) return !m_stopGenerating; } -bool ChatLLM::prompt(const QList &collectionList, const QString &prompt) +bool LlamaCppModel::prompt(const QList &collectionList, const QString &prompt) { if (m_restoreStateFromText) { Q_ASSERT(m_state.isEmpty()); @@ -734,7 +699,7 @@ bool ChatLLM::prompt(const QList &collectionList, const QString &prompt repeat_penalty, repeat_penalty_tokens); } -bool ChatLLM::promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, +bool LlamaCppModel::promptInternal(const QList 
&collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens) { @@ -762,8 +727,8 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString int n_threads = MySettings::globalInstance()->threadCount(); m_stopGenerating = false; - auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1); - auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1, + auto promptFunc = std::bind(&LlamaCppModel::handlePrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&LlamaCppModel::handleResponse, this, std::placeholders::_1, std::placeholders::_2); emit promptProcessing(); m_ctx.n_predict = n_predict; @@ -774,11 +739,15 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_ctx.n_batch = n_batch; m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_last_n = repeat_penalty_tokens; - m_llModelInfo.model->setThreadCount(n_threads); + + if (auto *lcppmodel = dynamic_cast(m_llModelInfo.model.get())) + lcppmodel->setThreadCount(n_threads); + #if defined(DEBUG) printf("%s", qPrintable(prompt)); fflush(stdout); #endif + QElapsedTimer totalTime; totalTime.start(); m_timer->start(); @@ -812,38 +781,61 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString return true; } -void ChatLLM::setShouldBeLoaded(bool b) +void LlamaCppModel::loadModelAsync(bool reload) { -#if defined(DEBUG_MODEL_LOADING) - qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get(); -#endif - m_shouldBeLoaded = b; // atomic - emit shouldBeLoadedChanged(); + m_shouldBeLoaded = true; // atomic + emit requestLoadModel(reload); +} + +void LlamaCppModel::releaseModelAsync(bool unload) +{ + m_shouldBeLoaded = false; // atomic + emit requestReleaseModel(unload); } -void ChatLLM::requestTrySwitchContext() +void LlamaCppModel::requestTrySwitchContext() { m_shouldBeLoaded = true; // atomic emit trySwitchContextRequested(modelInfo()); } -void ChatLLM::handleShouldBeLoadedChanged() +void LlamaCppModel::loadModel(bool reload) { - if (m_shouldBeLoaded) - reloadModel(); - else - unloadModel(); + Q_ASSERT(m_shouldBeLoaded); + if (m_isServer) + return; // server managed models directly + + if (reload) + releaseModel(/*unload*/ true); + else if (isModelLoaded()) + return; // already loaded + +#if defined(DEBUG_MODEL_LOADING) + qDebug() << "loadModel" << m_llmThread.objectName() << m_llModelInfo.model.get(); +#endif + ModelInfo m = modelInfo(); + if (m.name().isEmpty()) { + ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo(); + if (defaultModel.filename().isEmpty()) { + emit modelLoadingError(u"Could not find any model to load"_s); + return; + } + m = defaultModel; + } + loadModel(m); } -void ChatLLM::unloadModel() +void LlamaCppModel::releaseModel(bool unload) { if (!isModelLoaded() || m_isServer) return; - if (!m_forceUnloadModel || !m_shouldBeLoaded) + if (unload && m_shouldBeLoaded) { + // reloading the model, don't show unloaded status + emit modelLoadingPercentageChanged(std::numeric_limits::min()); // small positive value + } else { emit modelLoadingPercentageChanged(0.0f); - else - emit modelLoadingPercentageChanged(std::numeric_limits::min()); // small non-zero positive value + } if (!m_markedForDeletion) saveState(); @@ -852,34 +844,15 @@ void ChatLLM::unloadModel() qDebug() << "unloadModel" << m_llmThread.objectName() << 
m_llModelInfo.model.get(); #endif - if (m_forceUnloadModel) { + if (unload) { m_llModelInfo.resetModel(this); - m_forceUnloadModel = false; } LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); m_pristineLoadedState = false; } -void ChatLLM::reloadModel() -{ - if (isModelLoaded() && m_forceUnloadModel) - unloadModel(); // we unload first if we are forcing an unload - - if (isModelLoaded() || m_isServer) - return; - -#if defined(DEBUG_MODEL_LOADING) - qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get(); -#endif - const ModelInfo m = modelInfo(); - if (m.name().isEmpty()) - loadDefaultModel(); - else - loadModel(m); -} - -void ChatLLM::generateName() +void LlamaCppModel::generateName() { Q_ASSERT(isModelLoaded()); if (!isModelLoaded()) @@ -887,14 +860,14 @@ void ChatLLM::generateName() const QString chatNamePrompt = MySettings::globalInstance()->modelChatNamePrompt(m_modelInfo); if (chatNamePrompt.trimmed().isEmpty()) { - qWarning() << "ChatLLM: not generating chat name because prompt is empty"; + qWarning() << "LlamaCppModel: not generating chat name because prompt is empty"; return; } auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); - auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1); - auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2); - LLModel::PromptContext ctx = m_ctx; + auto promptFunc = std::bind(&LlamaCppModel::handleNamePrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&LlamaCppModel::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2); + ModelBackend::PromptContext ctx = m_ctx; m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, /*allowContextShift*/ false, ctx); std::string trimmed = trim_whitespace(m_nameResponse); @@ -905,12 +878,12 @@ void ChatLLM::generateName() m_pristineLoadedState = false; } -void ChatLLM::handleChatIdChanged(const QString &id) +void LlamaCppModel::handleChatIdChanged(const QString &id) { m_llmThread.setObjectName(id); } -bool ChatLLM::handleNamePrompt(int32_t token) +bool LlamaCppModel::handleNamePrompt(int32_t token) { #if defined(DEBUG) qDebug() << "name prompt" << m_llmThread.objectName() << token; @@ -919,7 +892,7 @@ bool ChatLLM::handleNamePrompt(int32_t token) return !m_stopGenerating; } -bool ChatLLM::handleNameResponse(int32_t token, const std::string &response) +bool LlamaCppModel::handleNameResponse(int32_t token, const std::string &response) { #if defined(DEBUG) qDebug() << "name response" << m_llmThread.objectName() << token << response; @@ -933,7 +906,7 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response) return words.size() <= 3; } -bool ChatLLM::handleQuestionPrompt(int32_t token) +bool LlamaCppModel::handleQuestionPrompt(int32_t token) { #if defined(DEBUG) qDebug() << "question prompt" << m_llmThread.objectName() << token; @@ -942,7 +915,7 @@ bool ChatLLM::handleQuestionPrompt(int32_t token) return !m_stopGenerating; } -bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response) +bool LlamaCppModel::handleQuestionResponse(int32_t token, const std::string &response) { #if defined(DEBUG) qDebug() << "question response" << m_llmThread.objectName() << token << response; @@ -971,7 +944,7 @@ bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response) return true; } -void 
ChatLLM::generateQuestions(qint64 elapsed) +void LlamaCppModel::generateQuestions(qint64 elapsed) { Q_ASSERT(isModelLoaded()); if (!isModelLoaded()) { @@ -988,9 +961,9 @@ void ChatLLM::generateQuestions(qint64 elapsed) emit generatingQuestions(); m_questionResponse.clear(); auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); - auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1); - auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2); - LLModel::PromptContext ctx = m_ctx; + auto promptFunc = std::bind(&LlamaCppModel::handleQuestionPrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&LlamaCppModel::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2); + ModelBackend::PromptContext ctx = m_ctx; QElapsedTimer totalTime; totalTime.start(); m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc, @@ -1000,7 +973,7 @@ void ChatLLM::generateQuestions(qint64 elapsed) } -bool ChatLLM::handleSystemPrompt(int32_t token) +bool LlamaCppModel::handleSystemPrompt(int32_t token) { #if defined(DEBUG) qDebug() << "system prompt" << m_llmThread.objectName() << token << m_stopGenerating; @@ -1009,7 +982,7 @@ bool ChatLLM::handleSystemPrompt(int32_t token) return !m_stopGenerating; } -bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token) +bool LlamaCppModel::handleRestoreStateFromTextPrompt(int32_t token) { #if defined(DEBUG) qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating; @@ -1020,7 +993,7 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token) // this function serialized the cached model state to disk. // we want to also serialize n_ctx, and read it at load time. 
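The comment above notes that n_ctx should eventually be persisted along with the cached KV state. Purely as an illustration of that idea (not part of this patch), a version-gated QDataStream round trip of an n_ctx field could look like the sketch below; the version constant and field layout are assumptions.

#include <QBuffer>
#include <QDataStream>

// Sketch: write a stream version followed by n_ctx, then read it back,
// skipping the field when an older stream version is encountered.
int main()
{
    QBuffer buf;
    buf.open(QIODevice::ReadWrite);

    QDataStream out(&buf);
    const qint32 version = 2;    // hypothetical stream version that introduces n_ctx
    const qint32 n_ctx = 4096;   // context size to persist alongside the KV cache
    out << version << n_ctx;

    buf.seek(0);
    QDataStream in(&buf);
    qint32 readVersion = 0, readNCtx = -1;
    in >> readVersion;
    if (readVersion >= 2)
        in >> readNCtx;          // older streams simply lack the field
    return readNCtx == n_ctx ? 0 : 1;
}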
-bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) +bool LlamaCppModel::serialize(QDataStream &stream, int version, bool serializeKV) { if (version > 1) { stream << m_llModelType; @@ -1060,7 +1033,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) return stream.status() == QDataStream::Ok; } -bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) +bool LlamaCppModel::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) { if (version > 1) { int internalStateVersion; @@ -1140,7 +1113,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, return stream.status() == QDataStream::Ok; } -void ChatLLM::saveState() +void LlamaCppModel::saveState() { if (!isModelLoaded() || m_pristineLoadedState) return; @@ -1162,7 +1135,7 @@ void ChatLLM::saveState() m_llModelInfo.model->saveState(static_cast(reinterpret_cast(m_state.data()))); } -void ChatLLM::restoreState() +void LlamaCppModel::restoreState() { if (!isModelLoaded()) return; @@ -1203,7 +1176,7 @@ void ChatLLM::restoreState() } } -void ChatLLM::processSystemPrompt() +void LlamaCppModel::processSystemPrompt() { Q_ASSERT(isModelLoaded()); if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText || m_isServer) @@ -1217,9 +1190,9 @@ void ChatLLM::processSystemPrompt() // Start with a whole new context m_stopGenerating = false; - m_ctx = LLModel::PromptContext(); + m_ctx = ModelBackend::PromptContext(); - auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1); + auto promptFunc = std::bind(&LlamaCppModel::handleSystemPrompt, this, std::placeholders::_1); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); @@ -1238,11 +1211,15 @@ void ChatLLM::processSystemPrompt() m_ctx.n_batch = n_batch; m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_last_n = repeat_penalty_tokens; - m_llModelInfo.model->setThreadCount(n_threads); + + if (auto *lcppmodel = dynamic_cast(m_llModelInfo.model.get())) + lcppmodel->setThreadCount(n_threads); + #if defined(DEBUG) printf("%s", qPrintable(QString::fromStdString(systemPrompt))); fflush(stdout); #endif + auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response // use "%1%2" and not "%1" to avoid implicit whitespace m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true); @@ -1256,7 +1233,7 @@ void ChatLLM::processSystemPrompt() m_pristineLoadedState = false; } -void ChatLLM::processRestoreStateFromText() +void LlamaCppModel::processRestoreStateFromText() { Q_ASSERT(isModelLoaded()); if (!isModelLoaded() || !m_restoreStateFromText || m_isServer) @@ -1266,9 +1243,9 @@ void ChatLLM::processRestoreStateFromText() emit restoringFromTextChanged(); m_stopGenerating = false; - m_ctx = LLModel::PromptContext(); + m_ctx = ModelBackend::PromptContext(); - auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1); + auto promptFunc = std::bind(&LlamaCppModel::handleRestoreStateFromTextPrompt, this, std::placeholders::_1); const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); @@ -1288,7 +1265,9 @@ void ChatLLM::processRestoreStateFromText() m_ctx.n_batch = n_batch; 
m_ctx.repeat_penalty = repeat_penalty; m_ctx.repeat_last_n = repeat_penalty_tokens; - m_llModelInfo.model->setThreadCount(n_threads); + + if (auto *lcppmodel = dynamic_cast(m_llModelInfo.model.get())) + lcppmodel->setThreadCount(n_threads); auto it = m_stateFromText.begin(); while (it < m_stateFromText.end()) { diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/llamacpp_model.h similarity index 65% rename from gpt4all-chat/chatllm.h rename to gpt4all-chat/llamacpp_model.h index d123358ad58e..07a9c91b35cc 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/llamacpp_model.h @@ -1,10 +1,11 @@ -#ifndef CHATLLM_H -#define CHATLLM_H +#pragma once #include "database.h" // IWYU pragma: keep +#include "llmodel.h" #include "modellist.h" -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/llamacpp_backend.h" +#include "../gpt4all-backend/model_backend.h" #include #include @@ -26,6 +27,8 @@ using namespace Qt::Literals::StringLiterals; +class Chat; +class LlamaCppModel; class QDataStream; // NOTE: values serialized to disk, do not change or reuse @@ -36,17 +39,15 @@ enum LLModelType { BERT_ = 3, // no longer used }; -class ChatLLM; - struct LLModelInfo { - std::unique_ptr model; + std::unique_ptr model; QFileInfo fileInfo; std::optional fallbackReason; - // NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which + // NOTE: This does not store the model type or name on purpose as this is left for LlamaCppModel which // must be able to serialize the information even if it is in the unloaded state - void resetModel(ChatLLM *cllm, LLModel *model = nullptr); + void resetModel(LlamaCppModel *cllm, ModelBackend *model = nullptr); }; class TokenTimer : public QObject { @@ -89,54 +90,47 @@ private Q_SLOTS: quint32 m_tokens; }; -class Chat; -class ChatLLM : public QObject +class LlamaCppModel : public LLModel { Q_OBJECT - Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged) Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged) + public: - ChatLLM(Chat *parent, bool isServer = false); - virtual ~ChatLLM(); + LlamaCppModel(Chat *parent, bool isServer = false); + ~LlamaCppModel() override; - void destroy(); + void destroy() override; static void destroyStore(); - bool isModelLoaded() const; - void regenerateResponse(); - void resetResponse(); - void resetContext(); + void regenerateResponse() override; + void resetResponse() override; + void resetContext() override; - void stopGenerating() { m_stopGenerating = true; } + void stopGenerating() override { m_stopGenerating = true; } - bool shouldBeLoaded() const { return m_shouldBeLoaded; } - void setShouldBeLoaded(bool b); - void requestTrySwitchContext(); - void setForceUnloadModel(bool b) { m_forceUnloadModel = b; } - void setMarkedForDeletion(bool b) { m_markedForDeletion = b; } + void loadModelAsync(bool reload = false) override; + void releaseModelAsync(bool unload = false) override; + void requestTrySwitchContext() override; + void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; } - QString response() const; + void setModelInfo(const ModelInfo &info) override; - ModelInfo modelInfo() const; - void setModelInfo(const ModelInfo &info); - - bool restoringFromText() const { return m_restoringFromText; } - - void acquireModel(); - void resetModel(); + bool 
restoringFromText() const override { return m_restoringFromText; } QString deviceBackend() const { - if (!isModelLoaded()) return QString(); - std::string name = LLModel::GPUDevice::backendIdToName(m_llModelInfo.model->backendName()); + auto *lcppmodel = dynamic_cast(m_llModelInfo.model.get()); + if (!isModelLoaded() && !lcppmodel) return QString(); + std::string name = LlamaCppBackend::GPUDevice::backendIdToName(lcppmodel->backendName()); return QString::fromStdString(name); } QString device() const { - if (!isModelLoaded()) return QString(); - const char *name = m_llModelInfo.model->gpuDeviceName(); + auto *lcppmodel = dynamic_cast(m_llModelInfo.model.get()); + if (!isModelLoaded() || !lcppmodel) return QString(); + const char *name = lcppmodel->gpuDeviceName(); return name ? QString(name) : u"CPU"_s; } @@ -147,55 +141,25 @@ class ChatLLM : public QObject return m_llModelInfo.fallbackReason.value_or(u""_s); } - QString generatedName() const { return QString::fromStdString(m_nameResponse); } - - bool serialize(QDataStream &stream, int version, bool serializeKV); - bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV); - void setStateFromText(const QVector> &stateFromText) { m_stateFromText = stateFromText; } + bool serialize(QDataStream &stream, int version, bool serializeKV) override; + bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) override; + void setStateFromText(const QVector> &stateFromText) override { m_stateFromText = stateFromText; } public Q_SLOTS: - bool prompt(const QList &collectionList, const QString &prompt); - bool loadDefaultModel(); - void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo); - bool loadModel(const ModelInfo &modelInfo); - void modelChangeRequested(const ModelInfo &modelInfo); - void unloadModel(); - void reloadModel(); - void generateName(); - void generateQuestions(qint64 elapsed); - void handleChatIdChanged(const QString &id); - void handleShouldBeLoadedChanged(); - void handleThreadStarted(); - void handleForceMetalChanged(bool forceMetal); - void handleDeviceChanged(); - void processSystemPrompt(); - void processRestoreStateFromText(); + bool prompt(const QList &collectionList, const QString &prompt) override; + bool loadModel(const ModelInfo &modelInfo) override; + void modelChangeRequested(const ModelInfo &modelInfo) override; + void generateName() override; + void processSystemPrompt() override; Q_SIGNALS: - void restoringFromTextChanged(); - void loadedModelInfoChanged(); - void modelLoadingPercentageChanged(float); - void modelLoadingError(const QString &error); - void modelLoadingWarning(const QString &warning); - void responseChanged(const QString &response); - void promptProcessing(); - void generatingQuestions(); - void responseStopped(qint64 promptResponseMs); - void generatedNameChanged(const QString &name); - void generatedQuestionFinished(const QString &generatedQuestion); - void stateChanged(); - void threadStarted(); - void shouldBeLoadedChanged(); - void trySwitchContextRequested(const ModelInfo &modelInfo); - void trySwitchContextOfLoadedModelCompleted(int value); - void requestRetrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); - void reportSpeed(const QString &speed); - void reportDevice(const QString &device); - void reportFallbackReason(const QString &fallbackReason); - void databaseResultsChanged(const QList&); - void modelInfoChanged(const ModelInfo &modelInfo); + void requestLoadModel(bool reload); + 
void requestReleaseModel(bool unload); protected: + bool isModelLoaded() const; + void acquireModel(); + void resetModel(); bool promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t repeat_penalty_tokens); @@ -212,14 +176,33 @@ public Q_SLOTS: void saveState(); void restoreState(); + // used by Server class + ModelInfo modelInfo() const { return m_modelInfo; } + QString response() const; + QString generatedName() const { return QString::fromStdString(m_nameResponse); } + +protected Q_SLOTS: + void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo); + void loadModel(bool reload = false); + void releaseModel(bool unload = false); + void generateQuestions(qint64 elapsed); + void handleChatIdChanged(const QString &id); + void handleThreadStarted(); + void handleForceMetalChanged(bool forceMetal); + void handleDeviceChanged(); + void processRestoreStateFromText(); + +private: + bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); + protected: - LLModel::PromptContext m_ctx; + // used by Server quint32 m_promptTokens; quint32 m_promptResponseTokens; + std::atomic m_shouldBeLoaded; private: - bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); - + ModelBackend::PromptContext m_ctx; std::string m_response; std::string m_nameResponse; QString m_questionResponse; @@ -230,9 +213,7 @@ public Q_SLOTS: QByteArray m_state; QThread m_llmThread; std::atomic m_stopGenerating; - std::atomic m_shouldBeLoaded; std::atomic m_restoringFromText; // status indication - std::atomic m_forceUnloadModel; std::atomic m_markedForDeletion; bool m_isServer; bool m_forceMetal; @@ -240,10 +221,8 @@ public Q_SLOTS: bool m_processedSystemPrompt; bool m_restoreStateFromText; // m_pristineLoadedState is set if saveSate is unnecessary, either because: - // - an unload was queued during LLModel::restoreState() + // - an unload was queued during ModelBackend::restoreState() // - the chat will be restored from text and hasn't been interacted with yet bool m_pristineLoadedState = false; QVector> m_stateFromText; }; - -#endif // CHATLLM_H diff --git a/gpt4all-chat/llm.cpp b/gpt4all-chat/llm.cpp index 13820030393e..03679030f66b 100644 --- a/gpt4all-chat/llm.cpp +++ b/gpt4all-chat/llm.cpp @@ -1,6 +1,6 @@ #include "llm.h" -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/llamacpp_backend_manager.h" #include "../gpt4all-backend/sysinfo.h" #include @@ -30,7 +30,7 @@ LLM *LLM::globalInstance() LLM::LLM() : QObject{nullptr} - , m_compatHardware(LLModel::Implementation::hasSupportedCPU()) + , m_compatHardware(LlamaCppBackendManager::hasSupportedCPU()) { QNetworkInformation::loadDefaultBackend(); auto * netinfo = QNetworkInformation::instance(); diff --git a/gpt4all-chat/llmodel.cpp b/gpt4all-chat/llmodel.cpp new file mode 100644 index 000000000000..4b6ffddbe3a6 --- /dev/null +++ b/gpt4all-chat/llmodel.cpp @@ -0,0 +1,34 @@ +#include "llmodel.h" + +#include +#include +#include + + +std::string remove_leading_whitespace(const std::string &input) +{ + auto first_non_whitespace = std::find_if(input.begin(), input.end(), [](unsigned char c) { + return !std::isspace(c); + }); + + if (first_non_whitespace == input.end()) + return std::string(); + + return std::string(first_non_whitespace, input.end()); +} + +std::string trim_whitespace(const std::string &input) +{ + auto first_non_whitespace = 
std::find_if(input.begin(), input.end(), [](unsigned char c) { + return !std::isspace(c); + }); + + if (first_non_whitespace == input.end()) + return std::string(); + + auto last_non_whitespace = std::find_if(input.rbegin(), input.rend(), [](unsigned char c) { + return !std::isspace(c); + }).base(); + + return std::string(first_non_whitespace, last_non_whitespace); +} diff --git a/gpt4all-chat/llmodel.h b/gpt4all-chat/llmodel.h new file mode 100644 index 000000000000..e2f915d97e8c --- /dev/null +++ b/gpt4all-chat/llmodel.h @@ -0,0 +1,78 @@ +#pragma once + +#include "database.h" // IWYU pragma: keep +#include "modellist.h" // IWYU pragma: keep + +#include +#include +#include +#include +#include + +class Chat; +class QDataStream; + +class LLModel : public QObject +{ + Q_OBJECT + Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged) + +protected: + LLModel() = default; + +public: + virtual ~LLModel() = default; + + virtual void destroy() {} + virtual void regenerateResponse() = 0; + virtual void resetResponse() = 0; + virtual void resetContext() = 0; + + virtual void stopGenerating() = 0; + + virtual void loadModelAsync(bool reload = false) = 0; + virtual void releaseModelAsync(bool unload = false) = 0; + virtual void requestTrySwitchContext() = 0; + virtual void setMarkedForDeletion(bool b) = 0; + + virtual void setModelInfo(const ModelInfo &info) = 0; + + virtual bool restoringFromText() const = 0; + + virtual bool serialize(QDataStream &stream, int version, bool serializeKV) = 0; + virtual bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) = 0; + virtual void setStateFromText(const QVector> &stateFromText) = 0; + +public Q_SLOTS: + virtual bool prompt(const QList &collectionList, const QString &prompt) = 0; + virtual bool loadModel(const ModelInfo &modelInfo) = 0; + virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0; + virtual void generateName() = 0; + virtual void processSystemPrompt() = 0; + +Q_SIGNALS: + void restoringFromTextChanged(); + void loadedModelInfoChanged(); + void modelLoadingPercentageChanged(float loadingPercentage); + void modelLoadingError(const QString &error); + void modelLoadingWarning(const QString &warning); + void responseChanged(const QString &response); + void promptProcessing(); + void generatingQuestions(); + void responseStopped(qint64 promptResponseMs); + void generatedNameChanged(const QString &name); + void generatedQuestionFinished(const QString &generatedQuestion); + void stateChanged(); + void threadStarted(); + void trySwitchContextRequested(const ModelInfo &modelInfo); + void trySwitchContextOfLoadedModelCompleted(int value); + void requestRetrieveFromDB(const QList &collections, const QString &text, int retrievalSize, QList *results); + void reportSpeed(const QString &speed); + void reportDevice(const QString &device); + void reportFallbackReason(const QString &fallbackReason); + void databaseResultsChanged(const QList &results); + void modelInfoChanged(const ModelInfo &modelInfo); +}; + +std::string remove_leading_whitespace(const std::string &input); +std::string trim_whitespace(const std::string &input); diff --git a/gpt4all-chat/main.cpp b/gpt4all-chat/main.cpp index 4546a95bcf32..9f848a2f70be 100644 --- a/gpt4all-chat/main.cpp +++ b/gpt4all-chat/main.cpp @@ -8,7 +8,7 @@ #include "mysettings.h" #include "network.h" -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/llamacpp_backend_manager.h" #include #include @@ -46,7 +46,7 @@ int main(int 
argc, char *argv[]) if (LLM::directoryExists(frameworksDir)) llmodelSearchPaths += ";" + frameworksDir; #endif - LLModel::Implementation::setImplementationsSearchPath(llmodelSearchPaths.toStdString()); + LlamaCppBackendManager::setImplementationsSearchPath(llmodelSearchPaths.toStdString()); // Set the local and language translation before the qml engine has even been started. This will // use the default system locale unless the user has explicitly set it to use a different one. @@ -87,7 +87,7 @@ int main(int argc, char *argv[]) int res = app.exec(); - // Make sure ChatLLM threads are joined before global destructors run. + // Make sure LlamaCppModel threads are joined before global destructors run. // Otherwise, we can get a heap-use-after-free inside of llama.cpp. ChatListModel::globalInstance()->destroyChats(); diff --git a/gpt4all-chat/modellist.cpp b/gpt4all-chat/modellist.cpp index 580b615ff4e6..9564bcda2baa 100644 --- a/gpt4all-chat/modellist.cpp +++ b/gpt4all-chat/modellist.cpp @@ -4,7 +4,7 @@ #include "mysettings.h" #include "network.h" -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/llamacpp_backend_manager.h" #include #include @@ -36,6 +36,8 @@ #include #include #include +#include +#include #include #include @@ -43,8 +45,33 @@ using namespace Qt::Literals::StringLiterals; //#define USE_LOCAL_MODELSJSON + static const QStringList FILENAME_BLACKLIST { u"gpt4all-nomic-embed-text-v1.rmodel"_s }; +// Maps "type" of current .rmodel format to a provider. +static const QHash<QString, ModelInfo::Provider> RMODEL_TYPES { + { u"openai"_s, ModelInfo::Provider::OpenAI }, + { u"mistral"_s, ModelInfo::Provider::Mistral }, + { u"openai-generic"_s, ModelInfo::Provider::OpenAIGeneric }, +}; + +// For backwards compatibility only. Do not add to this list. +static const QHash<QString, ModelInfo::Provider> BUILTIN_RMODEL_FILENAMES { + { u"gpt4all-gpt-3.5-turbo.rmodel"_s, ModelInfo::Provider::OpenAI }, + { u"gpt4all-gpt-4.rmodel"_s, ModelInfo::Provider::OpenAI }, + { u"gpt4all-mistral-tiny.rmodel"_s, ModelInfo::Provider::Mistral }, + { u"gpt4all-mistral-small.rmodel"_s, ModelInfo::Provider::Mistral }, + { u"gpt4all-mistral-medium.rmodel"_s, ModelInfo::Provider::Mistral }, +}; + +static ModelInfo::Provider getBuiltinRmodelFilename(const QString &filename) +{ + auto provider = BUILTIN_RMODEL_FILENAMES.value(filename, ModelInfo::INVALID_PROVIDER); + if (provider == ModelInfo::INVALID_PROVIDER) + throw std::invalid_argument("unrecognized rmodel filename: " + filename.toStdString()); + return provider; +} + QString ModelInfo::id() const { return m_id; @@ -258,7 +285,7 @@ int ModelInfo::maxContextLength() const if (!installed || isOnline) return -1; if (m_maxContextLength != -1) return m_maxContextLength; auto path = (dirpath + filename()).toStdString(); - int n_ctx = LLModel::Implementation::maxContextLength(path); + int n_ctx = LlamaCppBackendManager::maxContextLength(path); if (n_ctx < 0) { n_ctx = 4096; // fallback value } @@ -282,7 +309,7 @@ int ModelInfo::maxGpuLayers() const if (!installed || isOnline) return -1; if (m_maxGpuLayers != -1) return m_maxGpuLayers; auto path = (dirpath + filename()).toStdString(); - int layers = LLModel::Implementation::layerCount(path); + int layers = LlamaCppBackendManager::layerCount(path); if (layers < 0) { layers = 100; // fallback value } @@ -514,16 +541,18 @@ ModelList::ModelList() QCoreApplication::instance()->installEventFilter(this); } -QString ModelList::compatibleModelNameHash(QUrl baseUrl, QString modelName) { +QString ModelList::compatibleModelNameHash(QUrl baseUrl, QString modelName) +{
QCryptographicHash sha256(QCryptographicHash::Sha256); sha256.addData((baseUrl.toString() + "_" + modelName).toUtf8()); return sha256.result().toHex(); -}; +} -QString ModelList::compatibleModelFilename(QUrl baseUrl, QString modelName) { +QString ModelList::compatibleModelFilename(QUrl baseUrl, QString modelName) +{ QString hash(compatibleModelNameHash(baseUrl, modelName)); return QString(u"gpt4all-%1-capi.rmodel"_s).arg(hash); -}; +} bool ModelList::eventFilter(QObject *obj, QEvent *ev) { @@ -692,6 +721,8 @@ int ModelList::rowCount(const QModelIndex &parent) const QVariant ModelList::dataInternal(const ModelInfo *info, int role) const { switch (role) { + case ProviderRole: + return info->provider(); case IdRole: return info->id(); case NameRole: @@ -701,7 +732,7 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const case DirpathRole: return info->dirpath; case FilesizeRole: - return info->filesize; + return info->filesize(); case HashRole: return info->hash; case HashAlgorithmRole: @@ -712,10 +743,6 @@ QVariant ModelList::dataInternal(const ModelInfo *info, int role) const return info->installed; case DefaultRole: return info->isDefault; - case OnlineRole: - return info->isOnline; - case CompatibleApiRole: - return info->isCompatibleApi; case DescriptionRole: return info->description(); case RequiresVersionRole: @@ -844,6 +871,8 @@ void ModelList::updateData(const QString &id, const QVector const int role = d.first; const QVariant value = d.second; switch (role) { + case ProviderRole: + info->m_provider = value.value(); case IdRole: { if (info->id() != value.toString()) { @@ -859,21 +888,17 @@ void ModelList::updateData(const QString &id, const QVector case DirpathRole: info->dirpath = value.toString(); break; case FilesizeRole: - info->filesize = value.toString(); break; + info->m_filesize = value.toULongLong(); break; case HashRole: info->hash = value.toByteArray(); break; case HashAlgorithmRole: - info->hashAlgorithm = static_cast(value.toInt()); break; + info->hashAlgorithm = value.value(); break; case CalcHashRole: info->calcHash = value.toBool(); break; case InstalledRole: info->installed = value.toBool(); break; case DefaultRole: info->isDefault = value.toBool(); break; - case OnlineRole: - info->isOnline = value.toBool(); break; - case CompatibleApiRole: - info->isCompatibleApi = value.toBool(); break; case DescriptionRole: info->setDescription(value.toString()); break; case RequiresVersionRole: @@ -997,7 +1022,7 @@ void ModelList::updateData(const QString &id, const QVector && (info->isDiscovered() || info->description().isEmpty())) { // read GGUF and decide based on model architecture - info->isEmbeddingModel = LLModel::Implementation::isEmbeddingModel(modelPath.toStdString()); + info->isEmbeddingModel = LlamaCppBackendManager::isEmbeddingModel(modelPath.toStdString()); info->checkedEmbeddingModel = true; } @@ -1088,13 +1113,12 @@ QString ModelList::clone(const ModelInfo &model) addModel(id); QVector> data { + { ModelList::ProviderRole, model.provider }, { ModelList::InstalledRole, model.installed }, { ModelList::IsCloneRole, true }, { ModelList::NameRole, uniqueModelName(model) }, { ModelList::FilenameRole, model.filename() }, { ModelList::DirpathRole, model.dirpath }, - { ModelList::OnlineRole, model.isOnline }, - { ModelList::CompatibleApiRole, model.isCompatibleApi }, { ModelList::IsEmbeddingModelRole, model.isEmbeddingModel }, { ModelList::TemperatureRole, model.temperature() }, { ModelList::TopPRole, model.topP() }, @@ -1127,9 +1151,9 @@ void 
ModelList::removeClone(const ModelInfo &model) void ModelList::removeInstalled(const ModelInfo &model) { + Q_ASSERT(model.provider == ModelInfo::Provider::LlamaCpp || model.provider == ModelInfo::Provider::OpenAIGeneric); Q_ASSERT(model.installed); Q_ASSERT(!model.isClone()); - Q_ASSERT(model.isDiscovered() || model.isCompatibleApi || model.description() == "" /*indicates sideloaded*/); removeInternal(model); emit layoutChanged(); } @@ -1208,132 +1232,183 @@ bool ModelList::modelExists(const QString &modelFilename) const return false; } -void ModelList::updateModelsFromDirectory() +static void updateOldRemoteModels(const QString &path) { - const QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); - const QString localPath = MySettings::globalInstance()->modelPath(); + QDirIterator it(path, QDir::Files, QDirIterator::Subdirectories); + while (it.hasNext()) { + QFileInfo info = it.nextFileInfo(); + QString filename = info.fileName(); + if (!filename.startsWith("chatgpt-") || !filename.endsWith(".txt")) + continue; - auto updateOldRemoteModels = [&](const QString& path) { - QDirIterator it(path, QDirIterator::Subdirectories); - while (it.hasNext()) { - it.next(); - if (!it.fileInfo().isDir()) { - QString filename = it.fileName(); - if (filename.startsWith("chatgpt-") && filename.endsWith(".txt")) { - QString apikey; - QString modelname(filename); - modelname.chop(4); // strip ".txt" extension - modelname.remove(0, 8); // strip "chatgpt-" prefix - QFile file(path + filename); - if (file.open(QIODevice::ReadWrite)) { - QTextStream in(&file); - apikey = in.readAll(); - file.close(); - } + QString apikey; + QString modelname(filename); + modelname.chop(4); // strip ".txt" extension + modelname.remove(0, 8); // strip "chatgpt-" prefix + QFile file(info.filePath()); + if (!file.open(QFile::ReadOnly)) { + qWarning(tr("cannot open \"%s\": %s"), info.filePath(), file.errorString()); + continue; + } - QJsonObject obj; - obj.insert("apiKey", apikey); - obj.insert("modelName", modelname); - QJsonDocument doc(obj); - - auto newfilename = u"gpt4all-%1.rmodel"_s.arg(modelname); - QFile newfile(path + newfilename); - if (newfile.open(QIODevice::ReadWrite)) { - QTextStream out(&newfile); - out << doc.toJson(); - newfile.close(); - } - file.remove(); - } - } + { + QTextStream in(&file); + apikey = in.readAll(); + file.close(); } - }; - auto processDirectory = [&](const QString& path) { - QDirIterator it(path, QDir::Files, QDirIterator::Subdirectories); - while (it.hasNext()) { - it.next(); + QJsonObject obj { + { "type", "openai" }, + { "apiKey", apikey }, + { "modelName", modelname }, + }; - QString filename = it.fileName(); - if (filename.startsWith("incomplete") || FILENAME_BLACKLIST.contains(filename)) - continue; - if (!filename.endsWith(".gguf") && !filename.endsWith(".rmodel")) - continue; + QFile newfile(u"%1/gpt4all-%2.rmodel"_s.arg(info.dir().path(), modelname)); + if (!newfile.open(QFile::ReadWrite)) { + qWarning(tr("cannot create \"%s\": %s"), newfile.fileName(), file.errorString()); + continue; + } - QVector modelsById; - { - QMutexLocker locker(&m_mutex); - for (ModelInfo *info : m_models) - if (info->filename() == filename) - modelsById.append(info->id()); - } + QTextStream out(&newfile); + out << QJsonDocument(obj).toJson(); + newfile.close(); + file.remove(); + } +} - if (modelsById.isEmpty()) { - if (!contains(filename)) - addModel(filename); - modelsById.append(filename); - } +[[nodiscard]] +static bool parseRemoteModel(QVector> &props, const QFileInfo &info) +{ + 
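+    // Sketch of the .rmodel contents this function accepts (hypothetical values; "baseUrl" is only
+    // meaningful for type "openai-generic"):
+    //   { "type": "openai-generic", "apiKey": "sk-...", "modelName": "my-model",
+    //     "baseUrl": "https://api.example.com/v1" }
+    // Legacy files listed in BUILTIN_RMODEL_FILENAMES may omit "type"; their provider is inferred
+    // from the filename instead.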
QJsonObject remoteModel; + { + QFile file(info.filePath()); + if (!file.open(QFile::ReadOnly)) { + qWarning(tr("cannot open \"%s\": %s"), info.filePath(), file.errorString()); + return false; + } + QJsonDocument doc = QJsonDocument::fromJson(file.readAll()); + remoteModel = doc.object(); + } - QFileInfo info = it.fileInfo(); + ModelInfo::Provider provider; + QString remoteModelName, remoteApiKey; + { + const auto INVALID = ModelInfo::INVALID_PROVIDER; - bool isOnline(filename.endsWith(".rmodel")); - bool isCompatibleApi(filename.endsWith("-capi.rmodel")); + std::optional providerJson; + if (auto type = remoteModel["type"]; type.type() != QJsonValue::Unknown) + providerJson.reset(RMODEL_TYPES.value(type, INVALID)); - QString name; - QString description; - if (isCompatibleApi) { - QJsonObject obj; - { - QFile file(path + filename); - bool success = file.open(QIODeviceBase::ReadOnly); - (void)success; - Q_ASSERT(success); - QJsonDocument doc = QJsonDocument::fromJson(file.readAll()); - obj = doc.object(); - } - { - QString apiKey(obj["apiKey"].toString()); - QString baseUrl(obj["baseUrl"].toString()); - QString modelName(obj["modelName"].toString()); - apiKey = apiKey.length() < 10 ? "*****" : apiKey.left(5) + "*****"; - name = tr("%1 (%2)").arg(modelName, baseUrl); - description = tr("OpenAI-Compatible API Model
" - "
  • API Key: %1
  • " - "
  • Base URL: %2
  • " - "
  • Model Name: %3
") - .arg(apiKey, baseUrl, modelName); - } - } + auto apiKey = remoteModel["apiKey"]; + auto modelName = remoteModel["modelName"]; + if (modelName.type() != QJsonValue::String || apiKey.type() != QJsonValue::String || providerJson == INVALID) { + qWarning(tr("bad rmodel \"%s\": unrecognized format"), info.filePath()); + return false; + } + remoteModelName = modelName.toString(); + remoteApiKey = apiKey.toString(); - for (const QString &id : modelsById) { - QVector> data { - { InstalledRole, true }, - { FilenameRole, filename }, - { OnlineRole, isOnline }, - { CompatibleApiRole, isCompatibleApi }, - { DirpathRole, info.dir().absolutePath() + "/" }, - { FilesizeRole, toFileSize(info.size()) }, - }; - if (isCompatibleApi) { - // The data will be saved to "GPT4All.ini". - data.append({ NameRole, name }); - // The description is hard-coded into "GPT4All.ini" due to performance issue. - // If the description goes to be dynamic from its .rmodel file, it will get high I/O usage while using the ModelList. - data.append({ DescriptionRole, description }); - // Prompt template should be clear while using ChatML format which is using in most of OpenAI-Compatible API server. - data.append({ PromptTemplateRole, "%1" }); - } - updateData(id, data); - } + if (providerJson) { + provider = providerJson.value(); + } else if (auto builtin = BUILTIN_RMODEL_FILENAMES.value(filename, INVALID); builtin != INVALID) { + provider = builtin; + } else { + goto bad_data; } + } + + QString name; + QString description; + if (provider == ModelInfo::Provider::OpenAIGeneric) { + auto baseUrl = remoteModel["baseUrl"]; + if (baseUrl.type() != QJsonValue::String) + goto bad_data; + + QString apiKey = remoteApiKey; + apiKey = apiKey.length() < 10 ? "*****" : apiKey.left(5) + "*****"; + QString baseUrl(remoteModel["baseUrl"].toString()); + name = tr("%1 (%2)").arg(remoteModelName, baseUrl); + description = tr("OpenAI-Compatible API Model
" + "
  • API Key: %1
  • " + "
  • Base URL: %2
  • " + "
  • Model Name: %3
") + .arg(apiKey, baseUrl, remoteModelName); + + // The description is hard-coded into "GPT4All.ini" due to performance issue. + // If the description goes to be dynamic from its .rmodel file, it will get high I/O usage while using the ModelList. + props << QVector> { + { NameRole, name }, + { DescriptionRole, description }, + // Prompt template should be clear while using ChatML format which is using in most of OpenAI-Compatible API server. + { PromptTemplateRole, "%1" }, + }; + } + + props << QVector> { + { ProviderRole, provider }, }; + return true; + +bad_data: + qWarning(tr("bad rmodel \"%s\": unrecognized data"), info.filePath()); + return false; +} + +void ModelList::processModelDirectory(const QString &path) +{ + QDirIterator it(path, QDir::Files, QDirIterator::Subdirectories); + while (it.hasNext()) { + QFileInfo info = it.nextFileInfo(); + QString filename = info.fileName(); + if (filename.startsWith("incomplete") || FILENAME_BLACKLIST.contains(filename)) + continue; + if (!filename.endsWith(".gguf") && !filename.endsWith(".rmodel")) + continue; + + QVector> props; + if (!filename.endswith(".rmodel")) { + props.emplaceBack(ProviderRole, ModelInfo::Provider::LlamaCpp); + } else if (!parseRemoteModel(props, info)) + continue; + + QVector modelsById; + { + QMutexLocker locker(&m_mutex); + for (ModelInfo *info : m_models) + if (info->filename() == filename) + modelsById.append(info->id()); + } + + if (modelsById.isEmpty()) { + if (!contains(filename)) + addModel(filename); + modelsById.append(filename); + } + + for (const QString &id : modelsById) { + props << QVector> { + { ProviderRole, provider }, + { InstalledRole, true }, + { FilenameRole, filename }, + { DirpathRole, info.dir().absolutePath() + "/" }, + { FilesizeRole, info.size() }, + }; + updateData(id, props); + } + } +} + +void ModelList::updateModelsFromDirectory() +{ + const QString exePath = QCoreApplication::applicationDirPath() + QDir::separator(); + const QString localPath = MySettings::globalInstance()->modelPath(); updateOldRemoteModels(exePath); - processDirectory(exePath); + processModelDirectory(exePath); if (localPath != exePath) { updateOldRemoteModels(localPath); - processDirectory(localPath); + processModelDirectory(localPath); } } @@ -1500,8 +1575,6 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) if (!versionRemoved.isEmpty() && Download::compareAppVersions(versionRemoved, currentVersion) <= 0) continue; - modelFilesize = ModelList::toFileSize(modelFilesize.toULongLong()); - const QString id = modelName; Q_ASSERT(!id.isEmpty()); @@ -1514,7 +1587,7 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) QVector> data { { ModelList::NameRole, modelName }, { ModelList::FilenameRole, modelFilename }, - { ModelList::FilesizeRole, modelFilesize }, + { ModelList::FilesizeRole, modelFilesize.toULongLong() }, { ModelList::HashRole, modelHash }, { ModelList::HashAlgorithmRole, ModelInfo::Md5 }, { ModelList::DefaultRole, isDefault }, @@ -1570,9 +1643,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) if (!contains(id)) addModel(id); QVector> data { + { ModelList::ProviderRole, BUILTIN_RMODEL_FILENAMES. }, { ModelList::NameRole, modelName }, { ModelList::FilenameRole, modelFilename }, - { ModelList::FilesizeRole, "minimal" }, + { ModelList::FilesizeRole, 0 }, { ModelList::OnlineRole, true }, { ModelList::DescriptionRole, tr("OpenAI's ChatGPT model GPT-3.5 Turbo
%1").arg(chatGPTDesc) }, @@ -1598,9 +1672,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) if (!contains(id)) addModel(id); QVector> data { + { ModelList::ProviderRole, getBuiltinRmodelFilename(modelFilename) }, { ModelList::NameRole, modelName }, { ModelList::FilenameRole, modelFilename }, - { ModelList::FilesizeRole, "minimal" }, + { ModelList::FilesizeRole, 0 }, { ModelList::OnlineRole, true }, { ModelList::DescriptionRole, tr("OpenAI's ChatGPT model GPT-4
%1 %2").arg(chatGPTDesc).arg(chatGPT4Warn) }, @@ -1629,9 +1704,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) if (!contains(id)) addModel(id); QVector> data { + { ModelList::ProviderRole, getBuiltinRmodelFilename(modelFilename) }, { ModelList::NameRole, modelName }, { ModelList::FilenameRole, modelFilename }, - { ModelList::FilesizeRole, "minimal" }, + { ModelList::FilesizeRole, 0 }, { ModelList::OnlineRole, true }, { ModelList::DescriptionRole, tr("Mistral Tiny model
%1").arg(mistralDesc) }, @@ -1654,9 +1730,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) if (!contains(id)) addModel(id); QVector> data { + { ModelList::ProviderRole, getBuiltinRmodelFilename(modelFilename) }, { ModelList::NameRole, modelName }, { ModelList::FilenameRole, modelFilename }, - { ModelList::FilesizeRole, "minimal" }, + { ModelList::FilesizeRole, 0 }, { ModelList::OnlineRole, true }, { ModelList::DescriptionRole, tr("Mistral Small model
%1").arg(mistralDesc) }, @@ -1680,10 +1757,10 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) if (!contains(id)) addModel(id); QVector> data { + { ModelList::ProviderRole, getBuiltinRmodelFilename(modelFilename) }, { ModelList::NameRole, modelName }, { ModelList::FilenameRole, modelFilename }, - { ModelList::FilesizeRole, "minimal" }, - { ModelList::OnlineRole, true }, + { ModelList::FilesizeRole, 0 }, { ModelList::DescriptionRole, tr("Mistral Medium model
%1").arg(mistralDesc) }, { ModelList::RequiresVersionRole, "2.7.4" }, @@ -1709,10 +1786,9 @@ void ModelList::parseModelsJsonFile(const QByteArray &jsonData, bool save) if (!contains(id)) addModel(id); QVector> data { + { ModelList::ProviderRole, ModelInfo::Provider::OpenAIGeneric }, { ModelList::NameRole, modelName }, - { ModelList::FilesizeRole, "minimal" }, - { ModelList::OnlineRole, true }, - { ModelList::CompatibleApiRole, true }, + { ModelList::FilesizeRole, 0 }, { ModelList::DescriptionRole, tr("Connect to OpenAI-compatible API server
%1").arg(compatibleDesc) }, { ModelList::RequiresVersionRole, "2.7.4" }, @@ -2126,7 +2202,6 @@ void ModelList::handleDiscoveryItemFinished() // QString locationHeader = reply->header(QNetworkRequest::LocationHeader).toString(); QString modelFilename = reply->request().attribute(QNetworkRequest::UserMax).toString(); - QString modelFilesize = ModelList::toFileSize(QString(linkedSizeHeader).toULongLong()); QString description = tr("Created by %1.
    " "
  • Published on %2." @@ -2151,7 +2226,7 @@ void ModelList::handleDiscoveryItemFinished() QVector> data { { ModelList::NameRole, modelName }, { ModelList::FilenameRole, modelFilename }, - { ModelList::FilesizeRole, modelFilesize }, + { ModelList::FilesizeRole, QString(linkedSizeHeader).toULongLong() }, { ModelList::DescriptionRole, description }, { ModelList::IsDiscoveredRole, true }, { ModelList::UrlRole, reply->request().url() }, diff --git a/gpt4all-chat/modellist.h b/gpt4all-chat/modellist.h index 7c13da8ef4fd..2515110f4f65 100644 --- a/gpt4all-chat/modellist.h +++ b/gpt4all-chat/modellist.h @@ -22,20 +22,21 @@ using namespace Qt::Literals::StringLiterals; + struct ModelInfo { Q_GADGET + Q_PROPERTY(Provider provider READ provider) Q_PROPERTY(QString id READ id WRITE setId) Q_PROPERTY(QString name READ name WRITE setName) Q_PROPERTY(QString filename READ filename WRITE setFilename) Q_PROPERTY(QString dirpath MEMBER dirpath) - Q_PROPERTY(QString filesize MEMBER filesize) + Q_PROPERTY(QString filesize READ filesize) Q_PROPERTY(QByteArray hash MEMBER hash) Q_PROPERTY(HashAlgorithm hashAlgorithm MEMBER hashAlgorithm) Q_PROPERTY(bool calcHash MEMBER calcHash) Q_PROPERTY(bool installed MEMBER installed) Q_PROPERTY(bool isDefault MEMBER isDefault) - Q_PROPERTY(bool isOnline MEMBER isOnline) - Q_PROPERTY(bool isCompatibleApi MEMBER isCompatibleApi) + Q_PROPERTY(bool isOnline READ isOnline) Q_PROPERTY(QString description READ description WRITE setDescription) Q_PROPERTY(QString requiresVersion MEMBER requiresVersion) Q_PROPERTY(QString versionRemoved MEMBER versionRemoved) @@ -76,10 +77,27 @@ struct ModelInfo { Q_PROPERTY(QDateTime recency READ recency WRITE setRecency) public: - enum HashAlgorithm { + enum class Provider { + LlamaCpp, + // Pre-configured model from openai.com or mistral.ai + OpenAI, + Mistral, + // Model with a custom endpoint configured by the user (stored in *-capi.rmodel) + OpenAIGeneric, + }; + Q_ENUM(Provider) + + // Not a valid member of the Provider enum. Used as a sentinel with Qt containers. + static constexpr Provider INVALID_PROVIDER = Provider(-1); + + enum class HashAlgorithm { Md5, Sha256 }; + Q_ENUM(HashAlgorithm) + + Provider provider() const { return m_provider; } + bool isOnline() const { return m_provider != Provider::LlamaCpp; } QString id() const; void setId(const QString &id); @@ -90,9 +108,27 @@ struct ModelInfo { QString filename() const; void setFilename(const QString &name); + QString filesize() const + { + qsizetype sz = m_filesize; + if (!sz) + return u"minimal"_s; + if (sz < 1024) + return u"%1 bytes"_s.arg(sz); + if (sz < 1024 * 1024) + return u"%1 KB"_s.arg(qreal(sz) / 1024, 0, 'g', 3); + if (sz < 1024 * 1024 * 1024) + return u"%1 MB"_s.arg(qreal(sz) / (1024 * 1024), 0, 'g', 3); + + return u"%1 GB"_s.arg(qreal(sz) / (1024 * 1024 * 1024), 0, 'g', 3); + } + QString description() const; void setDescription(const QString &d); + /* For built-in OpenAI-compatible models, this is the full completions endpoint URL. + * For custom OpenAI-compatible models (Provider::OpenAIGeneric), this is not set. + * For discovered models (isDiscovered), this is the resolved URL of the GGUF file. 
*/ QString url() const; void setUrl(const QString &u); @@ -118,23 +154,11 @@ struct ModelInfo { void setRecency(const QDateTime &r); QString dirpath; - QString filesize; QByteArray hash; HashAlgorithm hashAlgorithm; bool calcHash = false; bool installed = false; bool isDefault = false; - // Differences between 'isOnline' and 'isCompatibleApi' in ModelInfo: - // 'isOnline': - // - Indicates whether this is a online model. - // - Linked with the ModelList, fetching info from it. - bool isOnline = false; - // 'isCompatibleApi': - // - Indicates whether the model is using the OpenAI-compatible API which user custom. - // - When the property is true, 'isOnline' should also be true. - // - Does not link to the ModelList directly; instead, fetches info from the *-capi.rmodel file and works standalone. - // - Still needs to copy data from gpt4all.ini and *-capi.rmodel to the ModelList in memory while application getting started(as custom .gguf models do). - bool isCompatibleApi = false; QString requiresVersion; QString versionRemoved; qint64 bytesReceived = 0; @@ -150,9 +174,7 @@ struct ModelInfo { bool isEmbeddingModel = false; bool checkedEmbeddingModel = false; - bool operator==(const ModelInfo &other) const { - return m_id == other.m_id; - } + bool operator==(const ModelInfo &other) const { return m_id == other.m_id; } double temperature() const; void setTemperature(double t); @@ -190,9 +212,11 @@ struct ModelInfo { private: QVariantMap getFields() const; + Provider m_provider; QString m_id; QString m_name; QString m_filename; + qsizetype m_filesize; QString m_description; QString m_url; QString m_quant; @@ -209,9 +233,9 @@ struct ModelInfo { int m_maxLength = 4096; int m_promptBatchSize = 128; int m_contextLength = 2048; - mutable int m_maxContextLength = -1; + mutable int m_maxContextLength = -1; // cache int m_gpuLayers = 100; - mutable int m_maxGpuLayers = -1; + mutable int m_maxGpuLayers = -1; // cache double m_repeatPenalty = 1.18; int m_repeatPenaltyTokens = 64; QString m_promptTemplate = "### Human:\n%1\n\n### Assistant:\n"; @@ -219,6 +243,7 @@ struct ModelInfo { QString m_chatNamePrompt = "Describe the above conversation in seven words or less."; QString m_suggestedFollowUpPrompt = "Suggest three very short factual follow-up questions that have not been answered yet or cannot be found inspired by the previous conversation and excerpts."; friend class MySettings; + friend class ModelList; }; Q_DECLARE_METATYPE(ModelInfo) @@ -350,55 +375,55 @@ class ModelList : public QAbstractListModel QHash roleNames() const override { - QHash roles; - roles[IdRole] = "id"; - roles[NameRole] = "name"; - roles[FilenameRole] = "filename"; - roles[DirpathRole] = "dirpath"; - roles[FilesizeRole] = "filesize"; - roles[HashRole] = "hash"; - roles[HashAlgorithmRole] = "hashAlgorithm"; - roles[CalcHashRole] = "calcHash"; - roles[InstalledRole] = "installed"; - roles[DefaultRole] = "isDefault"; - roles[OnlineRole] = "isOnline"; - roles[CompatibleApiRole] = "isCompatibleApi"; - roles[DescriptionRole] = "description"; - roles[RequiresVersionRole] = "requiresVersion"; - roles[VersionRemovedRole] = "versionRemoved"; - roles[UrlRole] = "url"; - roles[BytesReceivedRole] = "bytesReceived"; - roles[BytesTotalRole] = "bytesTotal"; - roles[TimestampRole] = "timestamp"; - roles[SpeedRole] = "speed"; - roles[DownloadingRole] = "isDownloading"; - roles[IncompleteRole] = "isIncomplete"; - roles[DownloadErrorRole] = "downloadError"; - roles[OrderRole] = "order"; - roles[RamrequiredRole] = "ramrequired"; - 
roles[ParametersRole] = "parameters"; - roles[QuantRole] = "quant"; - roles[TypeRole] = "type"; - roles[IsCloneRole] = "isClone"; - roles[IsDiscoveredRole] = "isDiscovered"; - roles[IsEmbeddingModelRole] = "isEmbeddingModel"; - roles[TemperatureRole] = "temperature"; - roles[TopPRole] = "topP"; - roles[MinPRole] = "minP"; - roles[TopKRole] = "topK"; - roles[MaxLengthRole] = "maxLength"; - roles[PromptBatchSizeRole] = "promptBatchSize"; - roles[ContextLengthRole] = "contextLength"; - roles[GpuLayersRole] = "gpuLayers"; - roles[RepeatPenaltyRole] = "repeatPenalty"; - roles[RepeatPenaltyTokensRole] = "repeatPenaltyTokens"; - roles[PromptTemplateRole] = "promptTemplate"; - roles[SystemPromptRole] = "systemPrompt"; - roles[ChatNamePromptRole] = "chatNamePrompt"; - roles[SuggestedFollowUpPromptRole] = "suggestedFollowUpPrompt"; - roles[LikesRole] = "likes"; - roles[DownloadsRole] = "downloads"; - roles[RecencyRole] = "recency"; + static const QHash roles { + { ProviderRole, "provider" }, + { IdRole, "id" }, + { NameRole, "name" }, + { FilenameRole, "filename" }, + { DirpathRole, "dirpath" }, + { FilesizeRole, "filesize" }, + { HashRole, "hash" }, + { HashAlgorithmRole, "hashAlgorithm" }, + { CalcHashRole, "calcHash" }, + { InstalledRole, "installed" }, + { DefaultRole, "isDefault" }, + { DescriptionRole, "description" }, + { RequiresVersionRole, "requiresVersion" }, + { VersionRemovedRole, "versionRemoved" }, + { UrlRole, "url" }, + { BytesReceivedRole, "bytesReceived" }, + { BytesTotalRole, "bytesTotal" }, + { TimestampRole, "timestamp" }, + { SpeedRole, "speed" }, + { DownloadingRole, "isDownloading" }, + { IncompleteRole, "isIncomplete" }, + { DownloadErrorRole, "downloadError" }, + { OrderRole, "order" }, + { RamrequiredRole, "ramrequired" }, + { ParametersRole, "parameters" }, + { QuantRole, "quant" }, + { TypeRole, "type" }, + { IsCloneRole, "isClone" }, + { IsDiscoveredRole, "isDiscovered" }, + { IsEmbeddingModelRole, "isEmbeddingModel" }, + { TemperatureRole, "temperature" }, + { TopPRole, "topP" }, + { MinPRole, "minP" }, + { TopKRole, "topK" }, + { MaxLengthRole, "maxLength" }, + { PromptBatchSizeRole, "promptBatchSize" }, + { ContextLengthRole, "contextLength" }, + { GpuLayersRole, "gpuLayers" }, + { RepeatPenaltyRole, "repeatPenalty" }, + { RepeatPenaltyTokensRole, "repeatPenaltyTokens" }, + { PromptTemplateRole, "promptTemplate" }, + { SystemPromptRole, "systemPrompt" }, + { ChatNamePromptRole, "chatNamePrompt" }, + { SuggestedFollowUpPromptRole, "suggestedFollowUpPrompt" }, + { LikesRole, "likes" }, + { DownloadsRole, "downloads" }, + { RecencyRole, "recency" }, + }; return roles; } @@ -418,7 +443,8 @@ class ModelList : public QAbstractListModel Q_INVOKABLE bool isUniqueName(const QString &name) const; Q_INVOKABLE QString clone(const ModelInfo &model); Q_INVOKABLE void removeClone(const ModelInfo &model); - Q_INVOKABLE void removeInstalled(const ModelInfo &model); + // Delist a model that is about to be removed from the model dir + void removeInstalled(const ModelInfo &model); ModelInfo defaultModelInfo() const; void addModel(const QString &id); @@ -430,18 +456,6 @@ class ModelList : public QAbstractListModel InstalledModels *selectableModels() const { return m_selectableModels; } DownloadableModels *downloadableModels() const { return m_downloadableModels; } - static inline QString toFileSize(quint64 sz) { - if (sz < 1024) { - return u"%1 bytes"_s.arg(sz); - } else if (sz < 1024 * 1024) { - return u"%1 KB"_s.arg(qreal(sz) / 1024, 0, 'g', 3); - } else if (sz < 1024 * 1024 * 1024) { 
- return u"%1 MB"_s.arg(qreal(sz) / (1024 * 1024), 0, 'g', 3); - } else { - return u"%1 GB"_s.arg(qreal(sz) / (1024 * 1024 * 1024), 0, 'g', 3); - } - } - QString incompleteDownloadPath(const QString &modelFile); bool asyncModelRequestOngoing() const { return m_asyncModelRequestOngoing; } @@ -502,6 +516,7 @@ private Q_SLOTS: void parseModelsJsonFile(const QByteArray &jsonData, bool save); void parseDiscoveryJsonFile(const QByteArray &jsonData); QString uniqueModelName(const ModelInfo &model) const; + void processModelDirectory(const QString &path); private: mutable QMutex m_mutex; diff --git a/gpt4all-chat/mysettings.cpp b/gpt4all-chat/mysettings.cpp index b29ec431f302..d7db791eb17c 100644 --- a/gpt4all-chat/mysettings.cpp +++ b/gpt4all-chat/mysettings.cpp @@ -1,6 +1,7 @@ #include "mysettings.h" -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/llamacpp_backend.h" +#include "../gpt4all-backend/llamacpp_backend_manager.h" #include #include @@ -95,8 +96,8 @@ static QStringList getDevices(bool skipKompute = false) #if defined(Q_OS_MAC) && defined(__aarch64__) deviceList << "Metal"; #else - std::vector devices = LLModel::Implementation::availableGPUDevices(); - for (LLModel::GPUDevice &d : devices) { + auto devices = LlamaCppBackendManager::availableGPUDevices(); + for (auto &d : devices) { if (!skipKompute || strcmp(d.backend, "kompute")) deviceList << QString::fromStdString(d.selectionName()); } @@ -512,7 +513,7 @@ QString MySettings::device() auto device = value.toString(); if (!device.isEmpty()) { auto deviceStr = device.toStdString(); - auto newNameStr = LLModel::GPUDevice::updateSelectionName(deviceStr); + auto newNameStr = LlamaCppBackend::GPUDevice::updateSelectionName(deviceStr); if (newNameStr != deviceStr) { auto newName = QString::fromStdString(newNameStr); qWarning() << "updating device name:" << device << "->" << newName; diff --git a/gpt4all-chat/network.cpp b/gpt4all-chat/network.cpp index e7ee616cd2ca..2b5f332c9a6f 100644 --- a/gpt4all-chat/network.cpp +++ b/gpt4all-chat/network.cpp @@ -9,7 +9,7 @@ #include "modellist.h" #include "mysettings.h" -#include "../gpt4all-backend/llmodel.h" +#include "../gpt4all-backend/llamacpp_backend_manager.h" #include #include @@ -99,7 +99,8 @@ Network *Network::globalInstance() return networkInstance(); } -bool Network::isHttpUrlValid(QUrl url) { +bool Network::isHttpUrlValid(QUrl url) +{ if (!url.isValid()) return false; QString scheme(url.scheme()); @@ -290,7 +291,7 @@ void Network::sendStartup() {"display", u"%1x%2"_s.arg(display->size().width()).arg(display->size().height())}, {"ram", LLM::globalInstance()->systemTotalRAMInGB()}, {"cpu", getCPUModel()}, - {"cpu_supports_avx2", LLModel::Implementation::cpuSupportsAVX2()}, + {"cpu_supports_avx2", LlamaCppBackendManager::cpuSupportsAVX2()}, {"datalake_active", mySettings->networkIsActive()}, }); sendIpify(); diff --git a/gpt4all-chat/ollama_model.cpp b/gpt4all-chat/ollama_model.cpp new file mode 100644 index 000000000000..f0f29cdb8f0b --- /dev/null +++ b/gpt4all-chat/ollama_model.cpp @@ -0,0 +1,670 @@ +#include "ollama_model.h" + +#include "chat.h" +#include "chatapi.h" +#include "localdocs.h" +#include "mysettings.h" +#include "network.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Qt::Literals::StringLiterals; + + +#define OLLAMA_INTERNAL_STATE_VERSION 0 + 
+OllamaModel::OllamaModel() + : m_shouldBeLoaded(false) + , m_forceUnloadModel(false) + , m_markedForDeletion(false) + , m_stopGenerating(false) + , m_timer(new TokenTimer(this)) + , m_processedSystemPrompt(false) +{ + connect(this, &OllamaModel::shouldBeLoadedChanged, this, &OllamaModel::handleShouldBeLoadedChanged); + connect(this, &OllamaModel::trySwitchContextRequested, this, &OllamaModel::trySwitchContextOfLoadedModel); + connect(m_timer, &TokenTimer::report, this, &OllamaModel::reportSpeed); + + // The following are blocking operations and will block the llm thread + connect(this, &OllamaModel::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, + Qt::BlockingQueuedConnection); +} + +OllamaModel::~OllamaModel() +{ + destroy(); +} + +void OllamaModel::destroy() +{ + // TODO(jared): cancel pending network requests +} + +void OllamaModel::destroyStore() +{ + LLModelStore::globalInstance()->destroy(); +} + +bool OllamaModel::loadDefaultModel() +{ + ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo(); + if (defaultModel.filename().isEmpty()) { + emit modelLoadingError(u"Could not find any model to load"_s); + return false; + } + return loadModel(defaultModel); +} + +void OllamaModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) +{ + // no-op: we require the model to be explicitly loaded for now. +} + +bool OllamaModel::loadModel(const ModelInfo &modelInfo) +{ + // We're already loaded with this model + if (isModelLoaded() && this->modelInfo() == modelInfo) + return true; + + // reset status + emit modelLoadingPercentageChanged(std::numeric_limits::min()); // small non-zero positive value + emit modelLoadingError(""); + + QString filePath = modelInfo.dirpath + modelInfo.filename(); + QFileInfo fileInfo(filePath); + + // We have a live model, but it isn't the one we want + bool alreadyAcquired = isModelLoaded(); + if (alreadyAcquired) { + resetContext(); + m_llModelInfo.resetModel(this); + } else { + // This is a blocking call that tries to retrieve the model we need from the model store. + // If it succeeds, then we just have to restore state. If the store has never had a model + // returned to it, then the modelInfo.model pointer should be null which will happen on startup + acquireModel(); + // At this point it is possible that while we were blocked waiting to acquire the model from the + // store, that our state was changed to not be loaded. If this is the case, release the model + // back into the store and quit loading + if (!m_shouldBeLoaded) { + LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); + emit modelLoadingPercentageChanged(0.0f); + return false; + } + + // Check if the store just gave us exactly the model we were looking for + if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo) { + restoreState(); + emit modelLoadingPercentageChanged(1.0f); + setModelInfo(modelInfo); + Q_ASSERT(!m_modelInfo.filename().isEmpty()); + if (m_modelInfo.filename().isEmpty()) + emit modelLoadingError(u"Modelinfo is left null for %1"_s.arg(modelInfo.filename())); + else + processSystemPrompt(); + return true; + } else { + // Release the memory since we have to switch to a different model. 
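+            // (Only the backend instance is dropped here; the acquired store slot in m_llModelInfo
+            // is kept and reused for the model requested below.)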
+ m_llModelInfo.resetModel(this); + } + } + + // Guarantee we've released the previous models memory + Q_ASSERT(!m_llModelInfo.model); + + // Store the file info in the modelInfo in case we have an error loading + m_llModelInfo.fileInfo = fileInfo; + + if (fileInfo.exists()) { + QVariantMap modelLoadProps; + + // TODO(jared): load the model here +#if 0 + if (modelInfo.isOnline) { + QString apiKey; + QString requestUrl; + QString modelName; + { + QFile file(filePath); + bool success = file.open(QIODeviceBase::ReadOnly); + (void)success; + Q_ASSERT(success); + QJsonDocument doc = QJsonDocument::fromJson(file.readAll()); + QJsonObject obj = doc.object(); + apiKey = obj["apiKey"].toString(); + modelName = obj["modelName"].toString(); + if (modelInfo.isCompatibleApi) { + QString baseUrl(obj["baseUrl"].toString()); + QUrl apiUrl(QUrl::fromUserInput(baseUrl)); + if (!Network::isHttpUrlValid(apiUrl)) + return false; + + QString currentPath(apiUrl.path()); + QString suffixPath("%1/chat/completions"); + apiUrl.setPath(suffixPath.arg(currentPath)); + requestUrl = apiUrl.toString(); + } else { + requestUrl = modelInfo.url(); + } + } + ChatAPI *model = new ChatAPI(); + model->setModelName(modelName); + model->setRequestURL(requestUrl); + model->setAPIKey(apiKey); + m_llModelInfo.resetModel(this, model); + } else if (!loadNewModel(modelInfo, modelLoadProps)) { + return false; // m_shouldBeLoaded became false + } +#endif + + restoreState(); + emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f); + emit loadedModelInfoChanged(); + + modelLoadProps.insert("model", modelInfo.filename()); + Network::globalInstance()->trackChatEvent("model_load", modelLoadProps); + } else { + LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store + resetModel(); + emit modelLoadingError(u"Could not find file for model %1"_s.arg(modelInfo.filename())); + } + + if (m_llModelInfo.model) { + setModelInfo(modelInfo); + processSystemPrompt(); + } + return bool(m_llModelInfo.model); +} + +bool OllamaModel::isModelLoaded() const +{ + return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded(); +} + +// FIXME(jared): we don't actually have to re-decode the prompt to generate a new response +void OllamaModel::regenerateResponse() +{ + m_ctx.n_past = std::max(0, m_ctx.n_past - m_promptResponseTokens); + m_ctx.tokens.erase(m_ctx.tokens.end() - m_promptResponseTokens, m_ctx.tokens.end()); + m_promptResponseTokens = 0; + m_promptTokens = 0; + m_response = std::string(); + emit responseChanged(QString::fromStdString(m_response)); +} + +void OllamaModel::resetResponse() +{ + m_promptTokens = 0; + m_promptResponseTokens = 0; + m_response = std::string(); + emit responseChanged(QString::fromStdString(m_response)); +} + +void OllamaModel::resetContext() +{ + resetResponse(); + m_processedSystemPrompt = false; + m_ctx = ModelBackend::PromptContext(); +} + +QString OllamaModel::response() const +{ + return QString::fromStdString(remove_leading_whitespace(m_response)); +} + +void OllamaModel::setModelInfo(const ModelInfo &modelInfo) +{ + m_modelInfo = modelInfo; + emit modelInfoChanged(modelInfo); +} + +void OllamaModel::acquireModel() +{ + m_llModelInfo = LLModelStore::globalInstance()->acquireModel(); + emit loadedModelInfoChanged(); +} + +void OllamaModel::resetModel() +{ + m_llModelInfo = {}; + emit loadedModelInfoChanged(); +} + +void OllamaModel::modelChangeRequested(const ModelInfo &modelInfo) +{ + m_shouldBeLoaded = true; + loadModel(modelInfo); +} + +bool 
OllamaModel::handlePrompt(int32_t token) +{ + // m_promptResponseTokens is related to last prompt/response not + // the entire context window which we can reset on regenerate prompt + ++m_promptTokens; + ++m_promptResponseTokens; + m_timer->start(); + return !m_stopGenerating; +} + +bool OllamaModel::handleResponse(int32_t token, const std::string &response) +{ + // check for error + if (token < 0) { + m_response.append(response); + emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response))); + return false; + } + + // m_promptResponseTokens is related to last prompt/response not + // the entire context window which we can reset on regenerate prompt + ++m_promptResponseTokens; + m_timer->inc(); + Q_ASSERT(!response.empty()); + m_response.append(response); + emit responseChanged(QString::fromStdString(remove_leading_whitespace(m_response))); + return !m_stopGenerating; +} + +bool OllamaModel::prompt(const QList &collectionList, const QString &prompt) +{ + if (!m_processedSystemPrompt) + processSystemPrompt(); + const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); + const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); + const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); + const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); + const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo); + const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); + const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); + const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); + const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo); + return promptInternal(collectionList, prompt, promptTemplate, n_predict, top_k, top_p, min_p, temp, n_batch, + repeat_penalty, repeat_penalty_tokens); +} + +bool OllamaModel::promptInternal(const QList &collectionList, const QString &prompt, const QString &promptTemplate, + int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, + int32_t repeat_penalty_tokens) +{ + if (!isModelLoaded()) + return false; + + QList databaseResults; + const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize(); + if (!collectionList.isEmpty()) { + emit requestRetrieveFromDB(collectionList, prompt, retrievalSize, &databaseResults); // blocks + emit databaseResultsChanged(databaseResults); + } + + // Augment the prompt template with the results if any + QString docsContext; + if (!databaseResults.isEmpty()) { + QStringList results; + for (const ResultInfo &info : databaseResults) + results << u"Collection: %1\nPath: %2\nExcerpt: %3"_s.arg(info.collection, info.path, info.text); + + // FIXME(jared): use a Jinja prompt template instead of hardcoded Alpaca-style localdocs template + docsContext = u"### Context:\n%1\n\n"_s.arg(results.join("\n\n")); + } + + int n_threads = MySettings::globalInstance()->threadCount(); + + m_stopGenerating = false; + auto promptFunc = std::bind(&OllamaModel::handlePrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&OllamaModel::handleResponse, this, std::placeholders::_1, + std::placeholders::_2); + emit promptProcessing(); + m_ctx.n_predict = n_predict; + m_ctx.top_k = top_k; + m_ctx.top_p = top_p; + m_ctx.min_p = min_p; + m_ctx.temp = temp; + m_ctx.n_batch = n_batch; + 
m_ctx.repeat_penalty = repeat_penalty; + m_ctx.repeat_last_n = repeat_penalty_tokens; + + QElapsedTimer totalTime; + totalTime.start(); + m_timer->start(); + if (!docsContext.isEmpty()) { + auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response + m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, + /*allowContextShift*/ true, m_ctx); + m_ctx.n_predict = old_n_predict; // now we are ready for a response + } + m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, + /*allowContextShift*/ true, m_ctx); + + m_timer->stop(); + qint64 elapsed = totalTime.elapsed(); + std::string trimmed = trim_whitespace(m_response); + if (trimmed != m_response) { + m_response = trimmed; + emit responseChanged(QString::fromStdString(m_response)); + } + + SuggestionMode mode = MySettings::globalInstance()->suggestionMode(); + if (mode == SuggestionMode::On || (!databaseResults.isEmpty() && mode == SuggestionMode::LocalDocsOnly)) + generateQuestions(elapsed); + else + emit responseStopped(elapsed); + + return true; +} + +void OllamaModel::setShouldBeLoaded(bool value, bool forceUnload) +{ + m_shouldBeLoaded = b; // atomic + emit shouldBeLoadedChanged(forceUnload); +} + +void OllamaModel::requestTrySwitchContext() +{ + m_shouldBeLoaded = true; // atomic + emit trySwitchContextRequested(modelInfo()); +} + +void OllamaModel::handleShouldBeLoadedChanged() +{ + if (m_shouldBeLoaded) + reloadModel(); + else + unloadModel(); +} + +void OllamaModel::unloadModel() +{ + if (!isModelLoaded()) + return; + + if (!m_forceUnloadModel || !m_shouldBeLoaded) + emit modelLoadingPercentageChanged(0.0f); + else + emit modelLoadingPercentageChanged(std::numeric_limits::min()); // small non-zero positive value + + if (!m_markedForDeletion) + saveState(); + + if (m_forceUnloadModel) { + m_llModelInfo.resetModel(this); + m_forceUnloadModel = false; + } + + LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); +} + +void OllamaModel::reloadModel() +{ + if (isModelLoaded() && m_forceUnloadModel) + unloadModel(); // we unload first if we are forcing an unload + + if (isModelLoaded()) + return; + + const ModelInfo m = modelInfo(); + if (m.name().isEmpty()) + loadDefaultModel(); + else + loadModel(m); +} + +void OllamaModel::generateName() +{ + Q_ASSERT(isModelLoaded()); + if (!isModelLoaded()) + return; + + const QString chatNamePrompt = MySettings::globalInstance()->modelChatNamePrompt(m_modelInfo); + if (chatNamePrompt.trimmed().isEmpty()) { + qWarning() << "OllamaModel: not generating chat name because prompt is empty"; + return; + } + + auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); + auto promptFunc = std::bind(&OllamaModel::handleNamePrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&OllamaModel::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2); + ModelBackend::PromptContext ctx = m_ctx; + m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(), + promptFunc, responseFunc, /*allowContextShift*/ false, ctx); + std::string trimmed = trim_whitespace(m_nameResponse); + if (trimmed != m_nameResponse) { + m_nameResponse = trimmed; + emit generatedNameChanged(QString::fromStdString(m_nameResponse)); + } +} + +bool OllamaModel::handleNamePrompt(int32_t token) +{ + Q_UNUSED(token); + return !m_stopGenerating; +} + +bool OllamaModel::handleNameResponse(int32_t token, const std::string 
&response) +{ + Q_UNUSED(token); + + m_nameResponse.append(response); + emit generatedNameChanged(QString::fromStdString(m_nameResponse)); + QString gen = QString::fromStdString(m_nameResponse).simplified(); + QStringList words = gen.split(' ', Qt::SkipEmptyParts); + return words.size() <= 3; +} + +bool OllamaModel::handleQuestionPrompt(int32_t token) +{ + Q_UNUSED(token); + return !m_stopGenerating; +} + +bool OllamaModel::handleQuestionResponse(int32_t token, const std::string &response) +{ + Q_UNUSED(token); + + // add token to buffer + m_questionResponse.append(response); + + // match whole question sentences + // FIXME: This only works with response by the model in english which is not ideal for a multi-language + // model. + static const QRegularExpression reQuestion(R"(\b(What|Where|How|Why|When|Who|Which|Whose|Whom)\b[^?]*\?)"); + + // extract all questions from response + int lastMatchEnd = -1; + for (const auto &match : reQuestion.globalMatch(m_questionResponse)) { + lastMatchEnd = match.capturedEnd(); + emit generatedQuestionFinished(match.captured()); + } + + // remove processed input from buffer + if (lastMatchEnd != -1) + m_questionResponse.erase(m_questionResponse.cbegin(), m_questionResponse.cbegin() + lastMatchEnd); + + return true; +} + +void OllamaModel::generateQuestions(qint64 elapsed) +{ + Q_ASSERT(isModelLoaded()); + if (!isModelLoaded()) { + emit responseStopped(elapsed); + return; + } + + const std::string suggestedFollowUpPrompt = MySettings::globalInstance()->modelSuggestedFollowUpPrompt(m_modelInfo).toStdString(); + if (QString::fromStdString(suggestedFollowUpPrompt).trimmed().isEmpty()) { + emit responseStopped(elapsed); + return; + } + + emit generatingQuestions(); + m_questionResponse.clear(); + auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); + auto promptFunc = std::bind(&OllamaModel::handleQuestionPrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&OllamaModel::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2); + ModelBackend::PromptContext ctx = m_ctx; + QElapsedTimer totalTime; + totalTime.start(); + m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc, + /*allowContextShift*/ false, ctx); + elapsed += totalTime.elapsed(); + emit responseStopped(elapsed); +} + + +bool OllamaModel::handleSystemPrompt(int32_t token) +{ + Q_UNUSED(token); + return !m_stopGenerating; +} + +// this function serialized the cached model state to disk. +// we want to also serialize n_ctx, and read it at load time. 
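+// Stream layout, in order: internal state version, response text, generated chat name,
+// prompt/response token count, n_ctx, then the qCompress()ed saved state blob.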
+bool OllamaModel::serialize(QDataStream &stream, int version, bool serializeKV) +{ + Q_UNUSED(serializeKV); + + if (version < 10) + throw std::out_of_range("ollama not avaliable until chat version 10, attempted to serialize version " + std::to_string(version)); + + stream << OLLAMA_INTERNAL_STATE_VERSION; + + stream << response(); + stream << generatedName(); + // TODO(jared): do not save/restore m_promptResponseTokens, compute the appropriate value instead + stream << m_promptResponseTokens; + + stream << m_ctx.n_ctx; + saveState(); + QByteArray compressed = qCompress(m_state); + stream << compressed; + + return stream.status() == QDataStream::Ok; +} + +bool OllamaModel::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) +{ + Q_UNUSED(deserializeKV); + Q_UNUSED(discardKV); + + Q_ASSERT(version >= 10); + + int internalStateVersion; + stream >> internalStateVersion; // for future use + + QString response; + stream >> response; + m_response = response.toStdString(); + QString nameResponse; + stream >> nameResponse; + m_nameResponse = nameResponse.toStdString(); + stream >> m_promptResponseTokens; + + uint32_t n_ctx; + stream >> n_ctx; + m_ctx.n_ctx = n_ctx; + + QByteArray compressed; + stream >> compressed; + m_state = qUncompress(compressed); + + return stream.status() == QDataStream::Ok; +} + +void OllamaModel::saveState() +{ + if (!isModelLoaded()) + return; + + // m_llModelType == LLModelType::API_ + m_state.clear(); + QDataStream stream(&m_state, QIODeviceBase::WriteOnly); + stream.setVersion(QDataStream::Qt_6_4); + ChatAPI *chatAPI = static_cast(m_llModelInfo.model.get()); + stream << chatAPI->context(); + // end API +} + +void OllamaModel::restoreState() +{ + if (!isModelLoaded()) + return; + + // m_llModelType == LLModelType::API_ + QDataStream stream(&m_state, QIODeviceBase::ReadOnly); + stream.setVersion(QDataStream::Qt_6_4); + ChatAPI *chatAPI = static_cast(m_llModelInfo.model.get()); + QList context; + stream >> context; + chatAPI->setContext(context); + m_state.clear(); + m_state.squeeze(); + // end API +} + +void OllamaModel::processSystemPrompt() +{ + Q_ASSERT(isModelLoaded()); + if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText) + return; + + const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString(); + if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) { + m_processedSystemPrompt = true; + return; + } + + // Start with a whole new context + m_stopGenerating = false; + m_ctx = ModelBackend::PromptContext(); + + auto promptFunc = std::bind(&OllamaModel::handleSystemPrompt, this, std::placeholders::_1); + + const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); + const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); + const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); + const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo); + const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); + const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); + const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); + const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo); + int n_threads = MySettings::globalInstance()->threadCount(); + m_ctx.n_predict = n_predict; + m_ctx.top_k = top_k; + m_ctx.top_p = top_p; + m_ctx.min_p = min_p; + m_ctx.temp = 
temp; + m_ctx.n_batch = n_batch; + m_ctx.repeat_penalty = repeat_penalty; + m_ctx.repeat_last_n = repeat_penalty_tokens; + + auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response + // use "%1%2" and not "%1" to avoid implicit whitespace + m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true); + m_ctx.n_predict = old_n_predict; + + m_processedSystemPrompt = m_stopGenerating == false; +} diff --git a/gpt4all-chat/ollama_model.h b/gpt4all-chat/ollama_model.h new file mode 100644 index 000000000000..1582b617e432 --- /dev/null +++ b/gpt4all-chat/ollama_model.h @@ -0,0 +1,51 @@ +#pragma once + +#include "database.h" // IWYU pragma: keep +#include "llmodel.h" +#include "modellist.h" // IWYU pragma: keep + +#include +#include +#include +#include +#include + +class Chat; +class QDataStream; + + +class OllamaModel : public LLModel +{ + Q_OBJECT + +public: + OllamaModel(); + ~OllamaModel() override = default; + + void regenerateResponse() override; + void resetResponse() override; + void resetContext() override; + + void stopGenerating() override; + + void setShouldBeLoaded(bool b) override; + void requestTrySwitchContext() override; + void setForceUnloadModel(bool b) override; + void setMarkedForDeletion(bool b) override; + + void setModelInfo(const ModelInfo &info) override; + + bool restoringFromText() const override; + + bool serialize(QDataStream &stream, int version, bool serializeKV) override; + bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) override; + void setStateFromText(const QVector> &stateFromText) override; + +public Q_SLOTS: + bool prompt(const QList &collectionList, const QString &prompt) override; + bool loadDefaultModel() override; + bool loadModel(const ModelInfo &modelInfo) override; + void modelChangeRequested(const ModelInfo &modelInfo) override; + void generateName() override; + void processSystemPrompt() override; +}; diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp index c8485d93e11d..dd498dbd08d3 100644 --- a/gpt4all-chat/server.cpp +++ b/gpt4all-chat/server.cpp @@ -71,7 +71,7 @@ static inline QJsonObject resultToJson(const ResultInfo &info) } Server::Server(Chat *chat) - : ChatLLM(chat, true /*isServer*/) + : LlamaCppModel(chat, true /*isServer*/) , m_chat(chat) , m_server(nullptr) { @@ -352,7 +352,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re emit requestServerNewPromptResponsePair(actualPrompt); // blocks // load the new model if necessary - setShouldBeLoaded(true); + m_shouldBeLoaded = true; if (modelInfo.filename().isEmpty()) { std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl; diff --git a/gpt4all-chat/server.h b/gpt4all-chat/server.h index 689f0b6061e3..6686e1527045 100644 --- a/gpt4all-chat/server.h +++ b/gpt4all-chat/server.h @@ -1,7 +1,7 @@ #ifndef SERVER_H #define SERVER_H -#include "chatllm.h" +#include "llamacpp_model.h" #include "database.h" #include @@ -13,7 +13,7 @@ class Chat; class QHttpServer; -class Server : public ChatLLM +class Server : public LlamaCppModel { Q_OBJECT