From 6d5968bf7e694aae8e68bbb4cd2e0e6f491f47d1 Mon Sep 17 00:00:00 2001 From: Timothy Carambat Date: Thu, 28 Dec 2023 14:42:34 -0800 Subject: [PATCH] Llm chore cleanup (#501) * move internal functions to private in class simplify lc message convertor * Fix hanging Context text when none is present --- server/utils/AiProviders/azureOpenAi/index.js | 20 +++-- server/utils/AiProviders/gemini/index.js | 22 +++-- server/utils/AiProviders/lmStudio/index.js | 20 +++-- server/utils/AiProviders/localAi/index.js | 20 +++-- server/utils/AiProviders/native/index.js | 82 +++++++++---------- server/utils/AiProviders/ollama/index.js | 43 +++++----- server/utils/AiProviders/openAi/index.js | 20 +++-- 7 files changed, 129 insertions(+), 98 deletions(-) diff --git a/server/utils/AiProviders/azureOpenAi/index.js b/server/utils/AiProviders/azureOpenAi/index.js index 82e28204bc..83ac3c4cd5 100644 --- a/server/utils/AiProviders/azureOpenAi/index.js +++ b/server/utils/AiProviders/azureOpenAi/index.js @@ -27,6 +27,18 @@ class AzureOpenAiLLM { this.embedder = !embedder ? new AzureOpenAiEmbedder() : embedder; } + #appendContext(contextTexts = []) { + if (!contextTexts || !contextTexts.length) return ""; + return ( + "\nContext:\n" + + contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("") + ); + } + streamingEnabled() { return "streamChat" in this && "streamGetChatCompletion" in this; } @@ -55,13 +67,7 @@ class AzureOpenAiLLM { }) { const prompt = { role: "system", - content: `${systemPrompt} -Context: - ${contextTexts - .map((text, i) => { - return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; - }) - .join("")}`, + content: `${systemPrompt}${this.#appendContext(contextTexts)}`, }; return [prompt, ...chatHistory, { role: "user", content: userPrompt }]; } diff --git a/server/utils/AiProviders/gemini/index.js b/server/utils/AiProviders/gemini/index.js index d0a76c550a..03388e3e20 100644 --- a/server/utils/AiProviders/gemini/index.js +++ b/server/utils/AiProviders/gemini/index.js @@ -1,4 +1,3 @@ -const { v4 } = require("uuid"); const { chatPrompt } = require("../../chats"); class GeminiLLM { @@ -22,7 +21,18 @@ class GeminiLLM { "INVALID GEMINI LLM SETUP. No embedding engine has been set. Go to instance settings and set up an embedding interface to use Gemini as your LLM." ); this.embedder = embedder; - this.answerKey = v4().split("-")[0]; + } + + #appendContext(contextTexts = []) { + if (!contextTexts || !contextTexts.length) return ""; + return ( + "\nContext:\n" + + contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("") + ); } streamingEnabled() { @@ -57,13 +67,7 @@ class GeminiLLM { }) { const prompt = { role: "system", - content: `${systemPrompt} -Context: - ${contextTexts - .map((text, i) => { - return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; - }) - .join("")}`, + content: `${systemPrompt}${this.#appendContext(contextTexts)}`, }; return [ prompt, diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js index 4d9770e665..28c107df08 100644 --- a/server/utils/AiProviders/lmStudio/index.js +++ b/server/utils/AiProviders/lmStudio/index.js @@ -27,6 +27,18 @@ class LMStudioLLM { this.embedder = embedder; } + #appendContext(contextTexts = []) { + if (!contextTexts || !contextTexts.length) return ""; + return ( + "\nContext:\n" + + contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("") + ); + } + streamingEnabled() { return "streamChat" in this && "streamGetChatCompletion" in this; } @@ -54,13 +66,7 @@ class LMStudioLLM { }) { const prompt = { role: "system", - content: `${systemPrompt} -Context: - ${contextTexts - .map((text, i) => { - return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; - }) - .join("")}`, + content: `${systemPrompt}${this.#appendContext(contextTexts)}`, }; return [prompt, ...chatHistory, { role: "user", content: userPrompt }]; } diff --git a/server/utils/AiProviders/localAi/index.js b/server/utils/AiProviders/localAi/index.js index 6c7a3263fb..84954c9942 100644 --- a/server/utils/AiProviders/localAi/index.js +++ b/server/utils/AiProviders/localAi/index.js @@ -29,6 +29,18 @@ class LocalAiLLM { this.embedder = embedder; } + #appendContext(contextTexts = []) { + if (!contextTexts || !contextTexts.length) return ""; + return ( + "\nContext:\n" + + contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("") + ); + } + streamingEnabled() { return "streamChat" in this && "streamGetChatCompletion" in this; } @@ -54,13 +66,7 @@ class LocalAiLLM { }) { const prompt = { role: "system", - content: `${systemPrompt} -Context: - ${contextTexts - .map((text, i) => { - return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; - }) - .join("")}`, + content: `${systemPrompt}${this.#appendContext(contextTexts)}`, }; return [prompt, ...chatHistory, { role: "user", content: userPrompt }]; } diff --git a/server/utils/AiProviders/native/index.js b/server/utils/AiProviders/native/index.js index 89589d112e..faac4fa030 100644 --- a/server/utils/AiProviders/native/index.js +++ b/server/utils/AiProviders/native/index.js @@ -1,8 +1,6 @@ -const os = require("os"); const fs = require("fs"); const path = require("path"); const { NativeEmbedder } = require("../../EmbeddingEngines/native"); -const { HumanMessage, SystemMessage, AIMessage } = require("langchain/schema"); const { chatPrompt } = require("../../chats"); // Docs: https://api.js.langchain.com/classes/chat_models_llama_cpp.ChatLlamaCpp.html @@ -29,12 +27,6 @@ class NativeLLM { : path.resolve(__dirname, `../../../storage/models/downloaded`) ); - // Set ENV for if llama.cpp needs to rebuild at runtime and machine is not - // running Apple Silicon. - process.env.NODE_LLAMA_CPP_METAL = os - .cpus() - .some((cpu) => cpu.model.includes("Apple")); - // Make directory when it does not exist in existing installations if (!fs.existsSync(this.cacheDir)) fs.mkdirSync(this.cacheDir); } @@ -56,12 +48,46 @@ class NativeLLM { // If the model has been loaded once, it is in the memory now // so we can skip re-loading it and instead go straight to inference. // Note: this will break temperature setting hopping between workspaces with different temps. - async llamaClient({ temperature = 0.7 }) { + async #llamaClient({ temperature = 0.7 }) { if (global.llamaModelInstance) return global.llamaModelInstance; await this.#initializeLlamaModel(temperature); return global.llamaModelInstance; } + #convertToLangchainPrototypes(chats = []) { + const { + HumanMessage, + SystemMessage, + AIMessage, + } = require("langchain/schema"); + const langchainChats = []; + const roleToMessageMap = { + system: SystemMessage, + user: HumanMessage, + assistant: AIMessage, + }; + + for (const chat of chats) { + if (!roleToMessageMap.hasOwnProperty(chat.role)) continue; + const MessageClass = roleToMessageMap[chat.role]; + langchainChats.push(new MessageClass({ content: chat.content })); + } + + return langchainChats; + } + + #appendContext(contextTexts = []) { + if (!contextTexts || !contextTexts.length) return ""; + return ( + "\nContext:\n" + + contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("") + ); + } + streamingEnabled() { return "streamChat" in this && "streamGetChatCompletion" in this; } @@ -84,13 +110,7 @@ class NativeLLM { }) { const prompt = { role: "system", - content: `${systemPrompt} -Context: - ${contextTexts - .map((text, i) => { - return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; - }) - .join("")}`, + content: `${systemPrompt}${this.#appendContext(contextTexts)}`, }; return [prompt, ...chatHistory, { role: "user", content: userPrompt }]; } @@ -111,7 +131,7 @@ Context: rawHistory ); - const model = await this.llamaClient({ + const model = await this.#llamaClient({ temperature: Number(workspace?.openAiTemp ?? 0.7), }); const response = await model.call(messages); @@ -124,7 +144,7 @@ Context: } async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) { - const model = await this.llamaClient({ + const model = await this.#llamaClient({ temperature: Number(workspace?.openAiTemp ?? 0.7), }); const messages = await this.compressMessages( @@ -140,13 +160,13 @@ Context: } async getChatCompletion(messages = null, { temperature = 0.7 }) { - const model = await this.llamaClient({ temperature }); + const model = await this.#llamaClient({ temperature }); const response = await model.call(messages); return response.content; } async streamGetChatCompletion(messages = null, { temperature = 0.7 }) { - const model = await this.llamaClient({ temperature }); + const model = await this.#llamaClient({ temperature }); const responseStream = await model.stream(messages); return responseStream; } @@ -167,27 +187,7 @@ Context: messageArray, rawHistory ); - return this.convertToLangchainPrototypes(compressedMessages); - } - - convertToLangchainPrototypes(chats = []) { - const langchainChats = []; - for (const chat of chats) { - switch (chat.role) { - case "system": - langchainChats.push(new SystemMessage({ content: chat.content })); - break; - case "user": - langchainChats.push(new HumanMessage({ content: chat.content })); - break; - case "assistant": - langchainChats.push(new AIMessage({ content: chat.content })); - break; - default: - break; - } - } - return langchainChats; + return this.#convertToLangchainPrototypes(compressedMessages); } } diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js index f160e5d36f..55205c23d9 100644 --- a/server/utils/AiProviders/ollama/index.js +++ b/server/utils/AiProviders/ollama/index.js @@ -40,24 +40,33 @@ class OllamaAILLM { AIMessage, } = require("langchain/schema"); const langchainChats = []; + const roleToMessageMap = { + system: SystemMessage, + user: HumanMessage, + assistant: AIMessage, + }; + for (const chat of chats) { - switch (chat.role) { - case "system": - langchainChats.push(new SystemMessage({ content: chat.content })); - break; - case "user": - langchainChats.push(new HumanMessage({ content: chat.content })); - break; - case "assistant": - langchainChats.push(new AIMessage({ content: chat.content })); - break; - default: - break; - } + if (!roleToMessageMap.hasOwnProperty(chat.role)) continue; + const MessageClass = roleToMessageMap[chat.role]; + langchainChats.push(new MessageClass({ content: chat.content })); } + return langchainChats; } + #appendContext(contextTexts = []) { + if (!contextTexts || !contextTexts.length) return ""; + return ( + "\nContext:\n" + + contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("") + ); + } + streamingEnabled() { return "streamChat" in this && "streamGetChatCompletion" in this; } @@ -83,13 +92,7 @@ class OllamaAILLM { }) { const prompt = { role: "system", - content: `${systemPrompt} -Context: - ${contextTexts - .map((text, i) => { - return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; - }) - .join("")}`, + content: `${systemPrompt}${this.#appendContext(contextTexts)}`, }; return [prompt, ...chatHistory, { role: "user", content: userPrompt }]; } diff --git a/server/utils/AiProviders/openAi/index.js b/server/utils/AiProviders/openAi/index.js index 4646427196..ccc7ba0e9b 100644 --- a/server/utils/AiProviders/openAi/index.js +++ b/server/utils/AiProviders/openAi/index.js @@ -24,6 +24,18 @@ class OpenAiLLM { this.embedder = !embedder ? new OpenAiEmbedder() : embedder; } + #appendContext(contextTexts = []) { + if (!contextTexts || !contextTexts.length) return ""; + return ( + "\nContext:\n" + + contextTexts + .map((text, i) => { + return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; + }) + .join("") + ); + } + streamingEnabled() { return "streamChat" in this && "streamGetChatCompletion" in this; } @@ -68,13 +80,7 @@ class OpenAiLLM { }) { const prompt = { role: "system", - content: `${systemPrompt} -Context: - ${contextTexts - .map((text, i) => { - return `[CONTEXT ${i}]:\n${text}\n[END CONTEXT ${i}]\n\n`; - }) - .join("")}`, + content: `${systemPrompt}${this.#appendContext(contextTexts)}`, }; return [prompt, ...chatHistory, { role: "user", content: userPrompt }]; }