Per workspace model selection (#582)
* WIP model selection per workspace (migrations and openai saves properly)

* revert OpenAiOption

* add support for models per workspace for anthropic, localAi, ollama, openAi, and togetherAi

* remove unneeded comments

* update logic for when LLMProvider is reset; reset AI provider files to match master

* remove frontend/api reset of workspace chat and move logic to updateENV
add postUpdate callbacks to envs

* set preferred model for chat on class instantiation

* remove extra param

* linting

* remove unused var

* refactor chat model selection on workspace

* linting

* add fallback for base path to localai models

---------

Co-authored-by: timothycarambat <[email protected]>
shatfield4 and timothycarambat authored Jan 17, 2024
1 parent bf503ee commit 90df375
Showing 24 changed files with 263 additions and 53 deletions.
@@ -0,0 +1,120 @@
import useGetProviderModels, {
DISABLED_PROVIDERS,
} from "./useGetProviderModels";

export default function ChatModelSelection({
settings,
workspace,
setHasChanges,
}) {
const { defaultModels, customModels, loading } = useGetProviderModels(
settings?.LLMProvider
);
if (DISABLED_PROVIDERS.includes(settings?.LLMProvider)) return null;

if (loading) {
return (
<div>
<div className="flex flex-col">
<label
htmlFor="name"
className="block text-sm font-medium text-white"
>
Chat model
</label>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
The specific chat model that will be used for this workspace. If
empty, will use the system LLM preference.
</p>
</div>
<select
name="chatModel"
required={true}
disabled={true}
className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
>
<option disabled={true} selected={true}>
-- waiting for models --
</option>
</select>
</div>
);
}

return (
<div>
<div className="flex flex-col">
<label htmlFor="name" className="block text-sm font-medium text-white">
Chat model{" "}
<span className="font-normal">({settings?.LLMProvider})</span>
</label>
<p className="text-white text-opacity-60 text-xs font-medium py-1.5">
The specific chat model that will be used for this workspace. If
empty, will use the system LLM preference.
</p>
</div>

<select
name="chatModel"
required={true}
onChange={() => {
setHasChanges(true);
}}
className="bg-zinc-900 text-white text-sm rounded-lg focus:ring-blue-500 focus:border-blue-500 block w-full p-2.5"
>
<option disabled={true} selected={workspace?.chatModel === null}>
System default
</option>
{defaultModels.length > 0 && (
<optgroup label="General models">
{defaultModels.map((model) => {
return (
<option
key={model}
value={model}
selected={workspace?.chatModel === model}
>
{model}
</option>
);
})}
</optgroup>
)}
{Array.isArray(customModels) && customModels.length > 0 && (
<optgroup label="Custom models">
{customModels.map((model) => {
return (
<option
key={model.id}
value={model.id}
selected={workspace?.chatModel === model.id}
>
{model.id}
</option>
);
})}
</optgroup>
)}
{/* For providers like TogetherAi where we partition model by creator entity. */}
{!Array.isArray(customModels) &&
Object.keys(customModels).length > 0 && (
<>
{Object.entries(customModels).map(([organization, models]) => (
<optgroup key={organization} label={organization}>
{models.map((model) => (
<option
key={model.id}
value={model.id}
selected={workspace?.chatModel === model.id}
>
{model.name}
</option>
))}
</optgroup>
))}
</>
)}
</select>
</div>
);
}
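A note on the branching above: customModels can arrive either as a flat array or, for TogetherAi, as an object keyed by creator organization, which is why the component checks Array.isArray(customModels). A minimal sketch of the two shapes the markup assumes (example values are illustrative, not taken from this commit):

// Flat list (LocalAI, Ollama, fine-tuned OpenAI models, ...):
// rendered as one <option> per entry under "Custom models".
const flat = [{ id: "mistral-7b-instruct" }, { id: "llama-2-7b-chat" }];

// Grouped map (TogetherAi): rendered as one <optgroup> per organization.
const grouped = {
  mistralai: [{ id: "mistralai/Mistral-7B-Instruct-v0.1", name: "Mistral (7B) Instruct" }],
  togethercomputer: [{ id: "togethercomputer/llama-2-70b-chat", name: "LLaMA-2 Chat (70B)" }],
};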
@@ -0,0 +1,49 @@
import System from "@/models/system";
import { useEffect, useState } from "react";

// Providers which cannot use this feature for workspace<>model selection
export const DISABLED_PROVIDERS = ["azure", "lmstudio"];
const PROVIDER_DEFAULT_MODELS = {
openai: ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-32k"],
gemini: ["gemini-pro"],
anthropic: ["claude-2", "claude-instant-1"],
azure: [],
lmstudio: [],
localai: [],
ollama: [],
togetherai: [],
native: [],
};

// For togetherAi, which has a large model list - we subgroup the options
// by their creator organization (eg: Meta, Mistral, etc)
// which makes selection easier to read.
function groupModels(models) {
return models.reduce((acc, model) => {
acc[model.organization] = acc[model.organization] || [];
acc[model.organization].push(model);
return acc;
}, {});
}

export default function useGetProviderModels(provider = null) {
const [defaultModels, setDefaultModels] = useState([]);
const [customModels, setCustomModels] = useState([]);
const [loading, setLoading] = useState(true);

useEffect(() => {
async function fetchProviderModels() {
if (!provider) return;
const { models = [] } = await System.customModels(provider);
if (PROVIDER_DEFAULT_MODELS.hasOwnProperty(provider))
setDefaultModels(PROVIDER_DEFAULT_MODELS[provider]);
provider === "togetherai"
? setCustomModels(groupModels(models))
: setCustomModels(models);
setLoading(false);
}
fetchProviderModels();
}, [provider]);

return { defaultModels, customModels, loading };
}
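For context, the hook assumes System.customModels(provider) resolves to an object with a models array, where each entry carries at least an id, and for TogetherAi also a name and an organization (consumed by groupModels above and by the <option> markup in ChatModelSelection). A hedged sketch of that assumed response shape, inferred from the code rather than shown in this diff:

// System.customModels("togetherai") ~> {
//   models: [
//     { id: "mistralai/Mistral-7B-Instruct-v0.1",
//       name: "Mistral (7B) Instruct",
//       organization: "mistralai" },
//     ...
//   ]
// }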
@@ -6,6 +6,7 @@ import System from "../../../../models/system";
import PreLoader from "../../../Preloader";
import { useParams } from "react-router-dom";
import showToast from "../../../../utils/toast";
import ChatModelPreference from "./ChatModelPreference";

// Ensure that a type is correct before sending the body
// to the backend.
@@ -26,7 +27,7 @@ function castToType(key, value) {
return definitions[key].cast(value);
}

export default function WorkspaceSettings({ active, workspace }) {
export default function WorkspaceSettings({ active, workspace, settings }) {
const { slug } = useParams();
const formEl = useRef(null);
const [saving, setSaving] = useState(false);
@@ -99,6 +100,11 @@ export default function WorkspaceSettings({ active, workspace }) {
<div className="flex">
<div className="flex flex-col gap-y-4 w-1/2">
<div className="w-3/4 flex flex-col gap-y-4">
<ChatModelPreference
settings={settings}
workspace={workspace}
setHasChanges={setHasChanges}
/>
<div>
<div className="flex flex-col">
<label
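The definitions map referenced by castToType is collapsed in this view. Since FormData yields string values only, numeric workspace settings must be cast before being sent to the backend; a hypothetical sketch of the map's shape (field names illustrative, not confirmed by this diff):

// Hypothetical example of the elided definitions map:
const definitions = {
  openAiTemp: {
    cast: (value) => (isNaN(Number(value)) ? null : Number(value)),
  },
  openAiHistory: {
    cast: (value) => (isNaN(Number(value)) ? 20 : Number(value)),
  },
};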
1 change: 1 addition & 0 deletions frontend/src/components/Modals/MangeWorkspace/index.jsx
@@ -117,6 +117,7 @@ const ManageWorkspace = ({ hideModal = noop, providedSlug = null }) => {
<WorkspaceSettings
active={selectedTab === "settings"} // To force reload live sub-components like VectorCount
workspace={workspace}
settings={settings}
/>
</div>
</Suspense>
6 changes: 2 additions & 4 deletions frontend/src/pages/GeneralSettings/LLMPreference/index.jsx
@@ -30,19 +30,17 @@ export default function GeneralLLMPreference() {
const [hasChanges, setHasChanges] = useState(false);
const [settings, setSettings] = useState(null);
const [loading, setLoading] = useState(true);

const [searchQuery, setSearchQuery] = useState("");
const [filteredLLMs, setFilteredLLMs] = useState([]);
const [selectedLLM, setSelectedLLM] = useState(null);

const isHosted = window.location.hostname.includes("useanything.com");

const handleSubmit = async (e) => {
e.preventDefault();
const form = e.target;
const data = {};
const data = { LLMProvider: selectedLLM };
const formData = new FormData(form);
data.LLMProvider = selectedLLM;

for (var [key, value] of formData.entries()) data[key] = value;
const { error } = await System.updateSystem(data);
setSaving(true);
2 changes: 1 addition & 1 deletion server/endpoints/api/system/index.js
@@ -139,7 +139,7 @@ function apiSystemEndpoints(app) {
*/
try {
const body = reqBody(request);
const { newValues, error } = updateENV(body);
const { newValues, error } = await updateENV(body);
if (process.env.NODE_ENV === "production") await dumpENV();
response.status(200).json({ newValues, error });
} catch (e) {
6 changes: 3 additions & 3 deletions server/endpoints/system.js
@@ -290,7 +290,7 @@ function systemEndpoints(app) {
}

const body = reqBody(request);
const { newValues, error } = updateENV(body);
const { newValues, error } = await updateENV(body);
if (process.env.NODE_ENV === "production") await dumpENV();
response.status(200).json({ newValues, error });
} catch (e) {
@@ -312,7 +312,7 @@
}

const { usePassword, newPassword } = reqBody(request);
const { error } = updateENV(
const { error } = await updateENV(
{
AuthToken: usePassword ? newPassword : "",
JWTSecret: usePassword ? v4() : "",
@@ -355,7 +355,7 @@
message_limit: 25,
});

updateENV(
await updateENV(
{
AuthToken: "",
JWTSecret: process.env.JWT_SECRET || v4(),
15 changes: 15 additions & 0 deletions server/models/workspace.js
@@ -14,6 +14,7 @@ const Workspace = {
"lastUpdatedAt",
"openAiPrompt",
"similarityThreshold",
"chatModel",
],

new: async function (name = null, creatorId = null) {
@@ -191,6 +192,20 @@ const Workspace = {
return { success: false, error: error.message };
}
},

resetWorkspaceChatModels: async () => {
try {
await prisma.workspaces.updateMany({
data: {
chatModel: null,
},
});
return { success: true, error: null };
} catch (error) {
console.error("Error resetting workspace chat models:", error.message);
return { success: false, error: error.message };
}
},
};

module.exports = { Workspace };
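The updateENV changes themselves are in one of the files not loaded on this page, but the commit message ("add postUpdate callbacks to envs") together with the new await updateENV(...) call sites suggests a pattern roughly like the following sketch; the key-mapping shape is assumed, not taken from this diff:

// Sketch: each settable ENV key may declare async postUpdate hooks that run
// after the value is written. Switching the LLM provider clears stale
// per-workspace model picks via Workspace.resetWorkspaceChatModels() above.
const KEY_MAPPING = {
  LLMProvider: {
    envKey: "LLM_PROVIDER",
    checks: [],
    postUpdate: [async () => await Workspace.resetWorkspaceChatModels()],
  },
  // ...other keys...
};

This is also why the endpoints above now await updateENV: the post-update side effects are asynchronous.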
2 changes: 2 additions & 0 deletions server/prisma/migrations/20240113013409_init/migration.sql
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "workspaces" ADD COLUMN "chatModel" TEXT;
1 change: 1 addition & 0 deletions server/prisma/schema.prisma
@@ -93,6 +93,7 @@ model workspaces {
lastUpdatedAt DateTime @default(now())
openAiPrompt String?
similarityThreshold Float? @default(0.25)
chatModel String?
workspace_users workspace_users[]
documents workspace_documents[]
}
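With chatModel added to the writable field list in server/models/workspace.js above, persisting a per-workspace pick is an ordinary update. A minimal sketch, assuming the model's existing update helper:

// Save an explicit model for one workspace...
await Workspace.update(workspace.id, { chatModel: "gpt-4" });
// ...or clear it so chats fall back to the system-wide LLM preference.
await Workspace.update(workspace.id, { chatModel: null });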
5 changes: 3 additions & 2 deletions server/utils/AiProviders/anthropic/index.js
@@ -2,7 +2,7 @@ const { v4 } = require("uuid");
const { chatPrompt } = require("../../chats");

class AnthropicLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
if (!process.env.ANTHROPIC_API_KEY)
throw new Error("No Anthropic API key was set.");

@@ -12,7 +12,8 @@
apiKey: process.env.ANTHROPIC_API_KEY,
});
this.anthropic = anthropic;
this.model = process.env.ANTHROPIC_MODEL_PREF || "claude-2";
this.model =
modelPreference || process.env.ANTHROPIC_MODEL_PREF || "claude-2";
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,
2 changes: 1 addition & 1 deletion server/utils/AiProviders/azureOpenAi/index.js
@@ -2,7 +2,7 @@ const { AzureOpenAiEmbedder } = require("../../EmbeddingEngines/azureOpenAi");
const { chatPrompt } = require("../../chats");

class AzureOpenAiLLM {
constructor(embedder = null) {
constructor(embedder = null, _modelPreference = null) {
const { OpenAIClient, AzureKeyCredential } = require("@azure/openai");
if (!process.env.AZURE_OPENAI_ENDPOINT)
throw new Error("No Azure API endpoint was set.");
5 changes: 3 additions & 2 deletions server/utils/AiProviders/gemini/index.js
@@ -1,14 +1,15 @@
const { chatPrompt } = require("../../chats");

class GeminiLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
if (!process.env.GEMINI_API_KEY)
throw new Error("No Gemini API key was set.");

// Docs: https://ai.google.dev/tutorials/node_quickstart
const { GoogleGenerativeAI } = require("@google/generative-ai");
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
this.model = process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
this.model =
modelPreference || process.env.GEMINI_LLM_MODEL_PREF || "gemini-pro";
this.gemini = genAI.getGenerativeModel({ model: this.model });
this.limits = {
history: this.promptWindowLimit() * 0.15,
4 changes: 2 additions & 2 deletions server/utils/AiProviders/lmStudio/index.js
@@ -2,7 +2,7 @@ const { chatPrompt } = require("../../chats");

// hybrid of openAi LLM chat completion for LMStudio
class LMStudioLLM {
constructor(embedder = null) {
constructor(embedder = null, _modelPreference = null) {
if (!process.env.LMSTUDIO_BASE_PATH)
throw new Error("No LMStudio API Base Path was set.");

@@ -12,7 +12,7 @@
});
this.lmstudio = new OpenAIApi(config);
// When using LMStudios inference server - the model param is not required so
// we can stub it here.
// we can stub it here. LMStudio can only run one model at a time.
this.model = "model-placeholder";
this.limits = {
history: this.promptWindowLimit() * 0.15,
4 changes: 2 additions & 2 deletions server/utils/AiProviders/localAi/index.js
@@ -1,7 +1,7 @@
const { chatPrompt } = require("../../chats");

class LocalAiLLM {
constructor(embedder = null) {
constructor(embedder = null, modelPreference = null) {
if (!process.env.LOCAL_AI_BASE_PATH)
throw new Error("No LocalAI Base Path was set.");

@@ -15,7 +15,7 @@
: {}),
});
this.openai = new OpenAIApi(config);
this.model = process.env.LOCAL_AI_MODEL_PREF;
this.model = modelPreference || process.env.LOCAL_AI_MODEL_PREF;
this.limits = {
history: this.promptWindowLimit() * 0.15,
system: this.promptWindowLimit() * 0.15,
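The remaining provider files among the 24 changed (ollama, openAi, togetherAi, and others) follow the same pattern as the diffs above: the constructor gains a modelPreference parameter that wins over the provider's env default, while disabled providers (azure, lmstudio) take a stubbed _modelPreference. Presumably the chat pipeline passes the workspace's saved model when instantiating the provider; a hedged sketch of that wiring, with the getLLMProvider and getEmbedder helper names assumed rather than shown in the loaded diff:

// Sketch: resolve the active provider class, seeding it with the workspace's
// preference so that this.model = modelPreference || ENV pref || default.
function getLLMProvider(modelPreference = null) {
  switch (process.env.LLM_PROVIDER) {
    case "anthropic":
      return new AnthropicLLM(getEmbedder(), modelPreference);
    case "localai":
      return new LocalAiLLM(getEmbedder(), modelPreference);
    // ...other providers...
    default:
      throw new Error("Unknown LLM provider.");
  }
}

// Called from the workspace chat flow with the workspace's saved choice:
const LLMConnector = getLLMProvider(workspace?.chatModel ?? null);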
