Patch Ollama Streaming chunk issues (#500)
Replace stream/sync chats with Langchain interface for now
connect #499
ref: #495 (comment)
timothycarambat authored Dec 28, 2023
1 parent d748167 commit 2a1202d
Showing 3 changed files with 79 additions and 132 deletions.
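
The change drops the hand-rolled fetch calls against Ollama's /api/chat and instead pipes LangChain's ChatOllama wrapper into a StringOutputParser, letting LangChain absorb the streaming chunk issues referenced above. A minimal sketch of that pattern (illustrative only, not part of the diff; the endpoint, model tag, and prompt are placeholders):

// Sketch of the ChatOllama + StringOutputParser pattern this commit adopts.
// Endpoint, model tag, and prompt below are placeholders.
const { ChatOllama } = require("langchain/chat_models/ollama");
const { StringOutputParser } = require("langchain/schema/output_parser");
const { HumanMessage } = require("langchain/schema");

async function demo() {
  const model = new ChatOllama({
    baseUrl: "http://127.0.0.1:11434", // assumed local Ollama endpoint
    model: "llama2", // placeholder model tag
    temperature: 0.7,
  });
  const chain = model.pipe(new StringOutputParser());

  // Non-streaming: invoke() resolves to a plain string.
  const text = await chain.invoke([new HumanMessage({ content: "Hello!" })]);
  console.log(text);

  // Streaming: stream() yields string chunks we can iterate directly.
  const stream = await chain.stream([new HumanMessage({ content: "Hello!" })]);
  for await (const chunk of stream) process.stdout.write(chunk);
}

demo().catch(console.error);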
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -1,6 +1,7 @@
 {
   "cSpell.words": [
     "Dockerized",
+    "Langchain",
     "Ollama",
     "openai",
     "Qdrant",
175 changes: 74 additions & 101 deletions server/utils/AiProviders/ollama/index.js
@@ -1,4 +1,5 @@
 const { chatPrompt } = require("../../chats");
+const { StringOutputParser } = require("langchain/schema/output_parser");
 
 // Docs: https://github.com/jmorganca/ollama/blob/main/docs/api.md
 class OllamaAILLM {
@@ -21,6 +22,42 @@ class OllamaAILLM {
     this.embedder = embedder;
   }
 
+  #ollamaClient({ temperature = 0.07 }) {
+    const { ChatOllama } = require("langchain/chat_models/ollama");
+    return new ChatOllama({
+      baseUrl: this.basePath,
+      model: this.model,
+      temperature,
+    });
+  }
+
+  // For streaming we use Langchain's wrapper to handle weird chunks
+  // or otherwise absorb headaches that can arise from Ollama models
+  #convertToLangchainPrototypes(chats = []) {
+    const {
+      HumanMessage,
+      SystemMessage,
+      AIMessage,
+    } = require("langchain/schema");
+    const langchainChats = [];
+    for (const chat of chats) {
+      switch (chat.role) {
+        case "system":
+          langchainChats.push(new SystemMessage({ content: chat.content }));
+          break;
+        case "user":
+          langchainChats.push(new HumanMessage({ content: chat.content }));
+          break;
+        case "assistant":
+          langchainChats.push(new AIMessage({ content: chat.content }));
+          break;
+        default:
+          break;
+      }
+    }
+    return langchainChats;
+  }
+
   streamingEnabled() {
     return "streamChat" in this && "streamGetChatCompletion" in this;
   }
@@ -63,37 +100,21 @@ Context:
   }
 
   async sendChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
-    const textResponse = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        model: this.model,
-        stream: false,
-        options: {
-          temperature: Number(workspace?.openAiTemp ?? 0.7),
-        },
-        messages: await this.compressMessages(
-          {
-            systemPrompt: chatPrompt(workspace),
-            userPrompt: prompt,
-            chatHistory,
-          },
-          rawHistory
-        ),
-      }),
-    })
-      .then((res) => {
-        if (!res.ok)
-          throw new Error(`Ollama:sendChat ${res.status} ${res.statusText}`);
-        return res.json();
-      })
-      .then((data) => data?.message?.content)
-      .catch((e) => {
-        console.error(e);
-        throw new Error(`Ollama::sendChat failed with: ${error.message}`);
-      });
+    const messages = await this.compressMessages(
+      {
+        systemPrompt: chatPrompt(workspace),
+        userPrompt: prompt,
+        chatHistory,
+      },
+      rawHistory
+    );
+
+    const model = this.#ollamaClient({
+      temperature: Number(workspace?.openAiTemp ?? 0.7),
+    });
+    const textResponse = await model
+      .pipe(new StringOutputParser())
+      .invoke(this.#convertToLangchainPrototypes(messages));
 
     if (!textResponse.length)
       throw new Error(`Ollama::sendChat text response was empty.`);
@@ -102,63 +123,29 @@ Context:
   }
 
   async streamChat(chatHistory = [], prompt, workspace = {}, rawHistory = []) {
-    const response = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        model: this.model,
-        stream: true,
-        options: {
-          temperature: Number(workspace?.openAiTemp ?? 0.7),
-        },
-        messages: await this.compressMessages(
-          {
-            systemPrompt: chatPrompt(workspace),
-            userPrompt: prompt,
-            chatHistory,
-          },
-          rawHistory
-        ),
-      }),
-    }).catch((e) => {
-      console.error(e);
-      throw new Error(`Ollama:streamChat ${error.message}`);
-    });
-
-    return { type: "ollamaStream", response };
+    const messages = await this.compressMessages(
+      {
+        systemPrompt: chatPrompt(workspace),
+        userPrompt: prompt,
+        chatHistory,
+      },
+      rawHistory
+    );
+
+    const model = this.#ollamaClient({
+      temperature: Number(workspace?.openAiTemp ?? 0.7),
+    });
+    const stream = await model
+      .pipe(new StringOutputParser())
+      .stream(this.#convertToLangchainPrototypes(messages));
+    return stream;
   }
 
   async getChatCompletion(messages = null, { temperature = 0.7 }) {
-    const textResponse = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        model: this.model,
-        messages,
-        stream: false,
-        options: {
-          temperature,
-        },
-      }),
-    })
-      .then((res) => {
-        if (!res.ok)
-          throw new Error(
-            `Ollama:getChatCompletion ${res.status} ${res.statusText}`
-          );
-        return res.json();
-      })
-      .then((data) => data?.message?.content)
-      .catch((e) => {
-        console.error(e);
-        throw new Error(
-          `Ollama::getChatCompletion failed with: ${error.message}`
-        );
-      });
+    const model = this.#ollamaClient({ temperature });
+    const textResponse = await model
+      .pipe(new StringOutputParser())
+      .invoke(this.#convertToLangchainPrototypes(messages));
 
     if (!textResponse.length)
      throw new Error(`Ollama::getChatCompletion text response was empty.`);
@@ -167,25 +154,11 @@ Context:
   }
 
   async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
-    const response = await fetch(`${this.basePath}/api/chat`, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-      },
-      body: JSON.stringify({
-        model: this.model,
-        stream: true,
-        messages,
-        options: {
-          temperature,
-        },
-      }),
-    }).catch((e) => {
-      console.error(e);
-      throw new Error(`Ollama:streamGetChatCompletion ${error.message}`);
-    });
-
-    return { type: "ollamaStream", response };
+    const model = this.#ollamaClient({ temperature });
+    const stream = await model
+      .pipe(new StringOutputParser())
+      .stream(this.#convertToLangchainPrototypes(messages));
+    return stream;
   }
 
   // Simple wrapper for dynamic embedder & normalize interface for all LLM implementations
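
The new #convertToLangchainPrototypes helper keeps the provider's role-tagged {role, content} message format and maps it onto LangChain's message classes just before invoking the model. A standalone sketch of the same mapping (illustrative; the toLangchainMessages name is hypothetical):

const {
  HumanMessage,
  SystemMessage,
  AIMessage,
} = require("langchain/schema");

// Same role -> class mapping as the private #convertToLangchainPrototypes
// helper above; messages with unknown roles are silently dropped.
function toLangchainMessages(chats = []) {
  const mapping = {
    system: SystemMessage,
    user: HumanMessage,
    assistant: AIMessage,
  };
  return chats
    .filter((chat) => chat.role in mapping)
    .map((chat) => new mapping[chat.role]({ content: chat.content }));
}

// Example:
// toLangchainMessages([
//   { role: "system", content: "You are a helpful assistant." },
//   { role: "user", content: "Hello there" },
// ]); // -> [SystemMessage, HumanMessage]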
35 changes: 4 additions & 31 deletions server/utils/chats/stream.js
@@ -232,46 +232,19 @@ function handleStreamResponses(response, stream, responseProps) {
     });
   }
 
-  if (stream?.type === "ollamaStream") {
-    return new Promise(async (resolve) => {
-      let fullText = "";
-      for await (const dataChunk of stream.response.body) {
-        const chunk = JSON.parse(Buffer.from(dataChunk).toString());
-        fullText += chunk.message.content;
-        writeResponseChunk(response, {
-          uuid,
-          sources: [],
-          type: "textResponseChunk",
-          textResponse: chunk.message.content,
-          close: false,
-          error: false,
-        });
-      }
-
-      writeResponseChunk(response, {
-        uuid,
-        sources,
-        type: "textResponseChunk",
-        textResponse: "",
-        close: true,
-        error: false,
-      });
-      resolve(fullText);
-    });
-  }
-
-  // If stream is not a regular OpenAI Stream (like if using native model)
+  // If stream is not a regular OpenAI Stream (like if using native model, Ollama, or most LangChain interfaces)
   // we can just iterate the stream content instead.
   if (!stream.hasOwnProperty("data")) {
     return new Promise(async (resolve) => {
       let fullText = "";
       for await (const chunk of stream) {
-        fullText += chunk.content;
+        const content = chunk.hasOwnProperty("content") ? chunk.content : chunk;
+        fullText += content;
         writeResponseChunk(response, {
           uuid,
           sources: [],
           type: "textResponseChunk",
-          textResponse: chunk.content,
+          textResponse: content,
           close: false,
           error: false,
         });
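
Because the Ollama stream is now piped through StringOutputParser, the generic handler receives plain string chunks, while other non-OpenAI streams can still yield objects carrying a content field; the updated guard covers both shapes. A minimal sketch of that normalization (illustrative; accumulateStream and onChunk are hypothetical stand-ins for the handler and writeResponseChunk):

// Sketch of the chunk normalization the updated handler performs.
async function accumulateStream(stream, onChunk) {
  let fullText = "";
  for await (const chunk of stream) {
    // StringOutputParser streams yield plain strings; other LangChain or
    // native-model streams may yield objects shaped like { content: "..." }.
    const content = chunk.hasOwnProperty("content") ? chunk.content : chunk;
    fullText += content;
    onChunk(content);
  }
  return fullText;
}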
