Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { render, screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { ModelStatusUpdate } from "@/controllers/API/queries/models/use-update-enabled-models";
import ModelSelection from "../components/ModelSelection";
import { Model } from "../components/types";

Expand Down Expand Up @@ -173,4 +174,90 @@ describe("ModelSelection", () => {
).not.toBeInTheDocument();
});
});

describe("Optimistic UI with pendingUpdates", () => {
it("should show pending enabled state when pendingUpdates has enabled=true", () => {
const pendingUpdates = new Map<string, ModelStatusUpdate>();
pendingUpdates.set("OpenAI:gpt-3.5-turbo", {
provider: "OpenAI",
model_id: "gpt-3.5-turbo",
enabled: true,
});

render(
<ModelSelection {...defaultProps} pendingUpdates={pendingUpdates} />,
);

// gpt-3.5-turbo is false in server state but true in pending
const toggle = screen.getByTestId("llm-toggle-gpt-3.5-turbo");
expect(toggle).toHaveAttribute("data-state", "checked");
});

it("should show pending disabled state when pendingUpdates has enabled=false", () => {
const pendingUpdates = new Map<string, ModelStatusUpdate>();
pendingUpdates.set("OpenAI:gpt-4", {
provider: "OpenAI",
model_id: "gpt-4",
enabled: false,
});

render(
<ModelSelection {...defaultProps} pendingUpdates={pendingUpdates} />,
);

// gpt-4 is true in server state but false in pending
const toggle = screen.getByTestId("llm-toggle-gpt-4");
expect(toggle).toHaveAttribute("data-state", "unchecked");
});

it("should use server state when model is not in pendingUpdates", () => {
const pendingUpdates = new Map<string, ModelStatusUpdate>();
// Only gpt-3.5-turbo is in pending, gpt-4 should use server state

render(
<ModelSelection {...defaultProps} pendingUpdates={pendingUpdates} />,
);

// gpt-4 is true in server state (mockEnabledModels)
const gpt4Toggle = screen.getByTestId("llm-toggle-gpt-4");
expect(gpt4Toggle).toHaveAttribute("data-state", "checked");

// gpt-3.5-turbo is false in server state
const gpt35Toggle = screen.getByTestId("llm-toggle-gpt-3.5-turbo");
expect(gpt35Toggle).toHaveAttribute("data-state", "unchecked");
});

it("should prioritize pendingUpdates over server state", () => {
const pendingUpdates = new Map<string, ModelStatusUpdate>();
// Override both models with opposite values
pendingUpdates.set("OpenAI:gpt-4", {
provider: "OpenAI",
model_id: "gpt-4",
enabled: false, // server has true
});
pendingUpdates.set("OpenAI:gpt-3.5-turbo", {
provider: "OpenAI",
model_id: "gpt-3.5-turbo",
enabled: true, // server has false
});

render(
<ModelSelection {...defaultProps} pendingUpdates={pendingUpdates} />,
);

const gpt4Toggle = screen.getByTestId("llm-toggle-gpt-4");
expect(gpt4Toggle).toHaveAttribute("data-state", "unchecked");

const gpt35Toggle = screen.getByTestId("llm-toggle-gpt-3.5-turbo");
expect(gpt35Toggle).toHaveAttribute("data-state", "checked");
});

it("should work without pendingUpdates prop (undefined)", () => {
render(<ModelSelection {...defaultProps} pendingUpdates={undefined} />);

// Should fall back to server state
const gpt4Toggle = screen.getByTestId("llm-toggle-gpt-4");
expect(gpt4Toggle).toHaveAttribute("data-state", "checked");
});
});
});
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import { useQueryClient } from "@tanstack/react-query";
import { useEffect, useMemo, useRef, useState } from "react";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import {
NO_API_KEY_PROVIDERS,
PROVIDER_VARIABLE_MAPPING,
VARIABLE_CATEGORY,
} from "@/constants/providerConstants";
import { useUpdateEnabledModels } from "@/controllers/API/queries/models/use-update-enabled-models";
import {
ModelStatusUpdate,
useUpdateEnabledModels,
} from "@/controllers/API/queries/models/use-update-enabled-models";
import {
useGetGlobalVariables,
usePatchGlobalVariables,
Expand All @@ -17,6 +20,7 @@ import { useDebounce } from "@/hooks/use-debounce";
import ProviderList from "@/modals/modelProviderModal/components/ProviderList";
import { Provider } from "@/modals/modelProviderModal/components/types";
import useAlertStore from "@/stores/alertStore";
import { ResponseErrorDetailAPI } from "@/types/api";
import { cn } from "@/utils/utils";
import ModelSelection from "./ModelSelection";

Expand All @@ -34,6 +38,9 @@ const ModelProvidersContent = ({
);
const [apiKey, setApiKey] = useState("");
const [validationFailed, setValidationFailed] = useState(false);
const [pendingUpdates, setPendingUpdates] = useState<
Map<string, ModelStatusUpdate>
>(new Map());
// Track if API key change came from user typing (vs programmatic reset)
// Used to prevent auto-save from triggering when we clear the input after success
const isUserInputRef = useRef(false);
Expand Down Expand Up @@ -82,28 +89,60 @@ const ModelProvidersContent = ({
}
}, [apiKey, debouncedConfigureProvider]);

// Update enabled models when toggled
const handleModelToggle = (modelName: string, enabled: boolean) => {
if (!selectedProvider?.provider) return;
// Flush pending model updates to the server in a single batch
const flushPendingUpdates = useCallback(() => {
if (pendingUpdates.size === 0) return;

const updates = Array.from(pendingUpdates.values());
updateEnabledModels(
{
updates: [
{
provider: selectedProvider.provider,
model_id: modelName,
enabled,
},
],
},
{ updates },
{
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ["useGetEnabledModels"] });
},
},
);
setPendingUpdates(new Map());
}, [pendingUpdates, updateEnabledModels, queryClient]);

// Accumulate model toggle changes locally for batching
const handleModelToggle = (modelName: string, enabled: boolean) => {
if (!selectedProvider?.provider) return;

setPendingUpdates((prev) => {
const next = new Map(prev);
const key = `${selectedProvider.provider}:${modelName}`;
next.set(key, {
provider: selectedProvider.provider,
model_id: modelName,
enabled,
});
return next;
});
};

// Flush pending updates when provider changes
const prevProviderRef = useRef<string | null>(null);
useEffect(() => {
if (
prevProviderRef.current !== null &&
prevProviderRef.current !== selectedProvider?.provider
) {
flushPendingUpdates();
}
prevProviderRef.current = selectedProvider?.provider ?? null;
}, [selectedProvider?.provider, flushPendingUpdates]);

// Flush pending updates on unmount (modal close)
useEffect(() => {
return () => {
if (pendingUpdates.size > 0) {
const updates = Array.from(pendingUpdates.values());
updateEnabledModels({ updates });
}
};
}, [pendingUpdates, updateEnabledModels]);

// Toggle provider selection - clicking same provider deselects it
const handleProviderSelect = (provider: Provider) => {
setSelectedProvider((prev) =>
Expand Down Expand Up @@ -148,7 +187,7 @@ const ModelProvidersContent = ({
);
};

const onError = (error: any) => {
const onError = (error: ResponseErrorDetailAPI) => {
setErrorData({
title: "Error Activating Provider",
list: [
Expand Down Expand Up @@ -205,7 +244,7 @@ const ModelProvidersContent = ({
);
};

const onError = (error: any) => {
const onError = (error: ResponseErrorDetailAPI) => {
setValidationFailed(true);
setErrorData({
title: existingVariable
Expand Down Expand Up @@ -345,6 +384,7 @@ const ModelProvidersContent = ({
onModelToggle={handleModelToggle}
providerName={selectedProvider?.provider}
isEnabledModel={selectedProvider?.is_enabled}
pendingUpdates={pendingUpdates}
/>
</div>
</div>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ForwardedIconComponent from "@/components/common/genericIconComponent";
import { Switch } from "@/components/ui/switch";
import { useGetEnabledModels } from "@/controllers/API/queries/models/use-get-enabled-models";
import { ModelStatusUpdate } from "@/controllers/API/queries/models/use-update-enabled-models";

import { Model } from "@/modals/modelProviderModal/components/types";
import { cn } from "@/utils/utils";
Expand All @@ -11,6 +12,7 @@ export interface ModelProviderSelectionProps {
modelType: "llm" | "embeddings" | "all";
providerName?: string;
isEnabledModel?: boolean;
pendingUpdates?: Map<string, ModelStatusUpdate>;
}

interface ModelRowProps {
Expand Down Expand Up @@ -61,10 +63,16 @@ const ModelSelection = ({
onModelToggle,
providerName,
isEnabledModel,
pendingUpdates,
}: ModelProviderSelectionProps) => {
const { data: enabledModelsData } = useGetEnabledModels();

const isModelEnabled = (modelName: string): boolean => {
// Check pending updates first for optimistic UI
const key = `${providerName}:${modelName}`;
if (pendingUpdates?.has(key)) {
return pendingUpdates.get(key)!.enabled;
}
if (!providerName || !enabledModelsData?.enabled_models) return false;
return enabledModelsData.enabled_models[providerName]?.[modelName] ?? false;
};
Expand Down
Loading