Commit 3076e0b

Only show GPU when we're actually using it.
1 parent 1fa67a5 commit 3076e0b

6 files changed, +29 -3 lines

gpt4all-backend/llamamodel.cpp (+10)
@@ -337,6 +337,16 @@ bool LLamaModel::hasGPUDevice()
 #endif
 }
 
+bool LLamaModel::usingGPUDevice()
+{
+#if defined(GGML_USE_KOMPUTE)
+    return ggml_vk_using_vulkan();
+#elif defined(GGML_USE_METAL)
+    return true;
+#endif
+    return false;
+}
+
 #if defined(_WIN32)
 #define DLL_EXPORT __declspec(dllexport)
 #else
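The new query is resolved at compile time: a Kompute build asks the Vulkan backend whether it is actually in use, a Metal build always answers yes, and everything else falls through to false. A minimal standalone sketch of the same preprocessor dispatch, with a stubbed-out backend query standing in for the real ggml_vk_using_vulkan() (the stub is an assumption for illustration, not part of this commit):

#include <cstdio>

// Stand-in for ggml_vk_using_vulkan(); the real call reports whether the
// Kompute/Vulkan backend was initialized for the loaded model.
static bool vulkan_backend_in_use() { return false; }

bool usingGPUDevice()
{
#if defined(GGML_USE_KOMPUTE)
    return vulkan_backend_in_use();  // Vulkan build: ask the backend at runtime
#elif defined(GGML_USE_METAL)
    return true;                     // Metal build: inference always runs on the GPU
#endif
    return false;                    // CPU-only build
}

int main()
{
    std::printf("using GPU device: %s\n", usingGPUDevice() ? "yes" : "no");
    return 0;
}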

gpt4all-backend/llamamodel_impl.h (+1)
@@ -30,6 +30,7 @@ class LLamaModel : public LLModel {
     bool initializeGPUDevice(const GPUDevice &device) override;
     bool initializeGPUDevice(int device) override;
     bool hasGPUDevice() override;
+    bool usingGPUDevice() override;
 
 private:
     LLamaPrivate *d_ptr;

gpt4all-backend/llmodel.h (+1)
@@ -100,6 +100,7 @@ class LLModel {
     virtual bool initializeGPUDevice(const GPUDevice &/*device*/) { return false; }
     virtual bool initializeGPUDevice(int /*device*/) { return false; }
     virtual bool hasGPUDevice() { return false; }
+    virtual bool usingGPUDevice() { return false; }
 
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of
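In the base interface the query defaults to false, so CPU-only backends need no change and only GPU-capable implementations override it. A small sketch of that pattern with hypothetical class names (not the real gpt4all hierarchy):

#include <iostream>

class LLModelLike {                        // hypothetical base, mirroring LLModel's default
public:
    virtual ~LLModelLike() = default;
    virtual bool usingGPUDevice() { return false; }  // safe default: not on the GPU
};

class MetalBackend : public LLModelLike {  // hypothetical GPU-capable backend
public:
    bool usingGPUDevice() override { return true; }
};

int main()
{
    LLModelLike cpuOnly;
    MetalBackend metal;
    std::cout << cpuOnly.usingGPUDevice() << ' ' << metal.usingGPUDevice() << '\n';  // prints: 0 1
    return 0;
}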

gpt4all-backend/replit.cpp (+11 -3)
@@ -163,7 +163,7 @@ struct mpt_hparams {
     int32_t n_embd = 0; //max_seq_len
     int32_t n_head = 0; // n_heads
     int32_t n_layer = 0; //n_layers
-    int32_t ftype = 0; 
+    int32_t ftype = 0;
 };
 
 struct replit_layer {
@@ -220,7 +220,7 @@ static bool kv_cache_init(
     params.mem_size = cache.buf.size;
     params.mem_buffer = cache.buf.addr;
     params.no_alloc = false;
-    
+
     cache.ctx = ggml_init(params);
     if (!cache.ctx) {
         fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
@@ -503,7 +503,7 @@ bool replit_model_load(const std::string & fname, std::istream &fin, replit_mode
     }
 
     GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "data", data_ptr, data_size, max_size));
-    GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "kv", ggml_get_mem_buffer(model.kv_self.ctx), 
+    GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "kv", ggml_get_mem_buffer(model.kv_self.ctx),
         ggml_get_mem_size(model.kv_self.ctx), 0));
     GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "eval", model.eval_buf.addr, model.eval_buf.size, 0));
     GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "scr0", model.scr0_buf.addr, model.scr0_buf.size, 0));
@@ -975,6 +975,14 @@ const std::vector<LLModel::Token> &Replit::endTokens() const
     return fres;
 }
 
+bool Replit::usingGPUDevice()
+{
+#if defined(GGML_USE_METAL)
+    return true;
+#endif
+    return false;
+}
+
 #if defined(_WIN32)
 #define DLL_EXPORT __declspec(dllexport)
 #else

gpt4all-backend/replit_impl.h (+1)
@@ -27,6 +27,7 @@ class Replit : public LLModel {
     size_t restoreState(const uint8_t *src) override;
     void setThreadCount(int32_t n_threads) override;
    int32_t threadCount() const override;
+    bool usingGPUDevice() override;
 
 private:
     ReplitPrivate *d_ptr;

gpt4all-chat/chatllm.cpp (+5)
@@ -302,6 +302,11 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         m_llModelInfo = LLModelInfo();
         emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
     } else {
+        // We might have had to fallback to CPU after load if the model is not possible to accelerate
+        // for instance if the quantization method is not supported on Vulkan yet
+        if (actualDevice != "CPU" && !m_llModelInfo.model->usingGPUDevice())
+            emit reportDevice("CPU");
+
         switch (m_llModelInfo.model->implementation().modelType()[0]) {
         case 'L': m_llModelType = LLModelType::LLAMA_; break;
         case 'G': m_llModelType = LLModelType::GPTJ_; break;
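On the chat side, the device the user requested and the device the backend ended up with can differ: an unsupported quantization, for example, makes the Vulkan path fall back to CPU during load. The fix is to trust the model's own report after load rather than the request. A self-contained sketch of that check, with plain functions standing in for the Qt signal and the loaded model handle (both stand-ins are assumptions, not gpt4all-chat API):

#include <iostream>
#include <string>

struct LoadedModel {                       // stand-in for m_llModelInfo.model
    bool gpuInUse = false;                 // e.g. Vulkan rejected the quantization
    bool usingGPUDevice() const { return gpuInUse; }
};

void reportDevice(const std::string &device)  // stand-in for emit reportDevice(...)
{
    std::cout << "device in use: " << device << '\n';
}

void reportActualDevice(const LoadedModel &model, const std::string &requestedDevice)
{
    // The backend may have silently fallen back to CPU after load, so correct
    // the UI from the model's own state instead of the requested device.
    if (requestedDevice != "CPU" && !model.usingGPUDevice())
        reportDevice("CPU");
    else
        reportDevice(requestedDevice);
}

int main()
{
    reportActualDevice(LoadedModel{}, "NVIDIA GeForce RTX 3060");  // prints: device in use: CPU
    return 0;
}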
