Skip to content

Commit c9c8525

Browse files
committed
load model in background
1 parent 3fa86dd commit c9c8525

File tree

5 files changed

+81
-33
lines changed

5 files changed

+81
-33
lines changed

src/App.jsx

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import SendIcon from "./components/icons/SendIcon";
55
import StopIcon from "./components/icons/StopIcon";
66
import GitHubIcon from "./components/icons/GitHubIcon";
77
import ModelSelector, { AVAILABLE_MODELS } from "./components/ModelSelector";
8-
import LoadingModal from "./components/LoadingModal";
98
import ModelSelectionModal from "./components/ModelSelectionModal";
9+
import InlineProgress from "./components/InlineProgress";
1010

1111
const IS_WEBGPU_AVAILABLE = !!navigator.gpu;
1212
const STICKY_SCROLL_THRESHOLD = 120;
@@ -65,23 +65,33 @@ function App() {
6565
// Inputs and outputs
6666
const [input, setInput] = useState("");
6767
const [messages, setMessages] = useState([]);
68+
const [queuedMessage, setQueuedMessage] = useState(null); // For storing message when model is loading
6869
const [tps, setTps] = useState(null);
6970
const [numTokens, setNumTokens] = useState(null);
7071

7172
function onEnter(message) {
73+
// Prevent queueing multiple messages
74+
if (status === "loading" && queuedMessage) {
75+
return;
76+
}
77+
78+
// Always add the message to the chat immediately for visibility
7279
setMessages((prev) => [...prev, { role: "user", content: message }]);
7380
setTps(null);
7481
setInput("");
7582

76-
// Load model if not already loaded
7783
if (status !== "ready") {
78-
setStatus("loading");
79-
setIsRunning(false); // Don't set running yet, wait for model to load
80-
worker.current.postMessage({
81-
type: "load",
82-
model_id: selectedModel
83-
});
84+
// Model not ready - queue the message and start loading if not already loading
85+
setQueuedMessage(message);
86+
if (status !== "loading") {
87+
setStatus("loading");
88+
worker.current.postMessage({
89+
type: "load",
90+
model_id: selectedModel
91+
});
92+
}
8493
} else {
94+
// Model ready - start generation immediately
8595
setIsRunning(true);
8696
}
8797
}
@@ -94,14 +104,16 @@ function App() {
94104

95105
function handleModelChange(modelId) {
96106
if (modelId === selectedModel) return;
97-
if (isRunning) return; // Prevent model switching during text generation
107+
if (isRunning || status === "loading") return; // Prevent model switching during text generation or loading
98108

99109
setSelectedModel(modelId);
100110
setStoredModel(modelId); // Save to localStorage
101111
setStatus("loading");
102-
// Don't clear messages - keep chat history
112+
// Don't clear messages - keep chat history, but clear queued message since we're switching models
113+
setQueuedMessage(null); // Clear any queued message when switching models
103114
setProgressItems([]);
104115

116+
// Start loading new model
105117
worker.current.postMessage({
106118
type: "load",
107119
model_id: modelId
@@ -178,9 +190,9 @@ function App() {
178190
case "ready":
179191
// Pipeline ready: the worker is ready to accept messages.
180192
setStatus("ready");
181-
// If we have messages and were waiting for the model to load, start generation
182-
// Only set running to true if we have user messages waiting to be processed
183-
if (messages.length > 0 && messages[messages.length - 1].role === "user") {
193+
// If we have a queued message, start generation
194+
if (queuedMessage) {
195+
setQueuedMessage(null);
184196
setIsRunning(true);
185197
}
186198
break;
@@ -325,14 +337,15 @@ function App() {
325337
onClick={() => {
326338
worker.current.postMessage({ type: "reset" });
327339
setMessages([]);
340+
setQueuedMessage(null); // Clear queued message
328341
}}
329342
>
330343
New Chat
331344
</button>
332345
<ModelSelector
333346
selectedModel={selectedModel}
334347
onModelChange={handleModelChange}
335-
disabled={status === "loading" || isRunning}
348+
disabled={isRunning || status === "loading"} // Disable during generation or loading
336349
/>
337350
</div>
338351

@@ -357,14 +370,6 @@ function App() {
357370
/>
358371
)}
359372

360-
{/* Loading Modal */}
361-
{status === "loading" && (
362-
<LoadingModal
363-
loadingMessage={loadingMessage}
364-
progressItems={progressItems}
365-
/>
366-
)}
367-
368373
{(status === "ready" || status === "loading" || status === null) && (
369374
<div
370375
ref={chatContainerRef}
@@ -417,15 +422,28 @@ function App() {
417422
</div>
418423
)}
419424

425+
{/* Inline Progress Display */}
426+
<InlineProgress
427+
loadingMessage={loadingMessage}
428+
progressItems={progressItems}
429+
isVisible={status === "loading"}
430+
/>
431+
420432
<div className="mt-2 border dark:bg-gray-700 rounded-lg w-[800px] max-w-[80%] max-h-[200px] mx-auto relative mb-3 flex">
421433
<textarea
422434
ref={textareaRef}
423435
className="scrollbar-thin w-full dark:bg-gray-700 px-3 py-4 rounded-lg bg-transparent border-none outline-none text-gray-800 disabled:text-gray-400 dark:text-gray-200 placeholder-gray-500 dark:placeholder-gray-400 disabled:placeholder-gray-200 resize-none disabled:cursor-not-allowed"
424-
placeholder="Type your message..."
436+
placeholder={
437+
status === "loading" && queuedMessage
438+
? "Message queued, loading model..."
439+
: status === "loading"
440+
? "Type your message (will be queued)..."
441+
: "Type your message..."
442+
}
425443
type="text"
426444
rows={1}
427445
value={input}
428-
disabled={status === "loading" || isRunning}
446+
disabled={isRunning || (status === "loading" && queuedMessage)} // Disable when running or when a message is already queued
429447
title={status === "ready" ? "Model is ready" : status === "loading" ? "Loading model..." : "Send a message to load the model"}
430448
autoComplete="off"
431449
autoCorrect="off"
@@ -435,7 +453,8 @@ function App() {
435453
onKeyDown={(e) => {
436454
if (
437455
input.length > 0 &&
438-
!isRunning &&
456+
!isRunning && // Prevent during generation
457+
!(status === "loading" && queuedMessage) && // Prevent when message already queued
439458
e.key === "Enter" &&
440459
!e.shiftKey
441460
) {
@@ -449,7 +468,7 @@ function App() {
449468
<div className="cursor-pointer" onClick={onInterrupt}>
450469
<StopIcon className="h-8 w-8 p-1 rounded-md text-gray-800 dark:text-gray-100 absolute right-3 bottom-3" />
451470
</div>
452-
) : input.length > 0 ? (
471+
) : input.length > 0 && !(status === "loading" && queuedMessage) ? (
453472
<div className="cursor-pointer" onClick={() => onEnter(input)}>
454473
<SendIcon
455474
className={`h-8 w-8 p-1 bg-gray-800 dark:bg-gray-100 text-white dark:text-black rounded-md absolute right-3 bottom-3`}

src/components/Chat.jsx

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,7 @@ export default function Chat({ messages, isRunning, loading = false, selectedMod
234234
loading ? "opacity-50 pointer-events-none" : "opacity-100"
235235
} ${empty ? "flex flex-col items-center justify-end" : "space-y-4"}`}
236236
>
237-
{empty ? (
238-
<div className="text-xl">Ready!</div>
239-
) : (
240-
messages.map((msg, i) => (
237+
{messages.map((msg, i) => (
241238
<div key={`message-${i}`} className={`flex ${msg.role === "user" ? "justify-end" : "justify-start"}`}>
242239
{msg.role === "assistant" ? (
243240
<div className="relative group w-full max-w-none">
@@ -278,8 +275,7 @@ export default function Chat({ messages, isRunning, loading = false, selectedMod
278275
</div>
279276
)}
280277
</div>
281-
))
282-
)}
278+
))}
283279
</div>
284280
);
285281
}

src/components/InlineProgress.jsx

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import Progress from './Progress';
2+
3+
export default function InlineProgress({ loadingMessage, progressItems, isVisible }) {
4+
if (!isVisible) return null;
5+
6+
return (
7+
<div className="w-[800px] max-w-[80%] mx-auto mb-3 p-4 bg-gray-50 dark:bg-gray-800 rounded-lg border border-gray-200 dark:border-gray-700">
8+
<div className="mb-3">
9+
<p className="text-sm text-gray-600 dark:text-gray-400 text-center">
10+
{loadingMessage || 'Preparing model...'}
11+
</p>
12+
</div>
13+
14+
<div className="space-y-2">
15+
{progressItems.map(({ file, progress, total }, i) => (
16+
<Progress
17+
key={i}
18+
text={file}
19+
percentage={progress}
20+
total={total}
21+
/>
22+
))}
23+
</div>
24+
25+
{progressItems.length === 0 && (
26+
<div className="flex justify-center">
27+
<div className="animate-spin rounded-full h-6 w-6 border-b-2 border-blue-500"></div>
28+
</div>
29+
)}
30+
</div>
31+
);
32+
}

src/components/ModelSelectionModal.jsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ export default function ModelSelectionModal({ onModelSelect, onClose }) {
5959
<svg className="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24">
6060
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M13 10V3L4 14h7v7l9-11h-7z" />
6161
</svg>
62-
Load Model
62+
Continue
6363
<div className="absolute inset-0 rounded-lg bg-gradient-to-r from-blue-600 to-blue-700 opacity-0 group-hover:opacity-20 transition-opacity duration-200"></div>
6464
</button>
6565
</div>

src/worker.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ async function load(model_id) {
175175
// Run model with dummy input to compile shaders
176176
const inputs = tokenizer("a");
177177
await model.generate({ ...inputs, max_new_tokens: 1 });
178+
178179
self.postMessage({ status: "ready" });
179180
} catch (error) {
180181
// Check if this is an unsupported model error

0 commit comments

Comments
 (0)