Skip to content

Commit

Permalink
added support for chat.completion to return sse response
Browse files Browse the repository at this point in the history
Signed-off-by: raja <[email protected]>
  • Loading branch information
raja-jamwal committed Sep 8, 2024
1 parent 1aae4ff commit 05f125d
Showing 1 changed file with 82 additions and 183 deletions.
265 changes: 82 additions & 183 deletions gpt4all-chat/src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
#include <QtLogging>

#include <iostream>
#include <string>
#include <type_traits>
#include <utility>

using namespace Qt::Literals::StringLiterals;
Expand Down Expand Up @@ -207,26 +205,29 @@ void Server::start()

QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &request, bool isChat)
{
// We've been asked to do a completion...
// Parse JSON request
QJsonParseError err;
const QJsonDocument document = QJsonDocument::fromJson(request.body(), &err);
if (err.error || !document.isObject()) {
std::cerr << "ERROR: invalid json in completions body" << std::endl;
std::cerr << "ERROR: invalid JSON in completions body" << std::endl;
return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
}

#if defined(DEBUG)
printf("/v1/completions %s\n", qPrintable(document.toJson(QJsonDocument::Indented)));
fflush(stdout);
#endif

const QJsonObject body = document.object();
if (!body.contains("model")) { // required
std::cerr << "ERROR: completions contains no model" << std::endl;
if (!body.contains("model")) {
std::cerr << "ERROR: completions contain no model" << std::endl;
return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
}

QJsonArray messages;
if (isChat) {
if (!body.contains("messages")) {
std::cerr << "ERROR: chat completions contains no messages" << std::endl;
std::cerr << "ERROR: chat completions contain no messages" << std::endl;
return QHttpServerResponse(QHttpServerResponder::StatusCode::NoContent);
}
messages = body["messages"].toArray();
Expand All @@ -236,16 +237,12 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
ModelInfo modelInfo = ModelList::globalInstance()->defaultModelInfo();
const QList<ModelInfo> modelList = ModelList::globalInstance()->selectableModelList();
for (const ModelInfo &info : modelList) {
Q_ASSERT(info.installed);
if (!info.installed)
continue;
if (modelRequested == info.name() || modelRequested == info.filename()) {
if (info.installed && (modelRequested == info.name() || modelRequested == info.filename())) {
modelInfo = info;
break;
}
}

// We only support one prompt for now
QList<QString> prompts;
if (body.contains("prompt")) {
QJsonValue promptValue = body["prompt"];
Expand All @@ -256,217 +253,119 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
for (const QJsonValue &v : array)
prompts.append(v.toString());
}
} else
} else {
prompts.append(" ");

int max_tokens = 16;
if (body.contains("max_tokens"))
max_tokens = body["max_tokens"].toInt();

float temperature = 1.f;
if (body.contains("temperature"))
temperature = body["temperature"].toDouble();

float top_p = 1.f;
if (body.contains("top_p"))
top_p = body["top_p"].toDouble();

float min_p = 0.f;
if (body.contains("min_p"))
min_p = body["min_p"].toDouble();

int n = 1;
if (body.contains("n"))
n = body["n"].toInt();

int logprobs = -1; // supposed to be null by default??
if (body.contains("logprobs"))
logprobs = body["logprobs"].toInt();

bool echo = false;
if (body.contains("echo"))
echo = body["echo"].toBool();

// We currently don't support any of the following...
#if 0
// FIXME: Need configurable reverse prompts
QList<QString> stop;
if (body.contains("stop")) {
QJsonValue stopValue = body["stop"];
if (stopValue.isString())
stop.append(stopValue.toString());
else {
QJsonArray array = stopValue.toArray();
for (QJsonValue v : array)
stop.append(v.toString());
}
}

// FIXME: QHttpServer doesn't support server-sent events
bool stream = false;
if (body.contains("stream"))
stream = body["stream"].toBool();

// FIXME: What does this do?
QString suffix;
if (body.contains("suffix"))
suffix = body["suffix"].toString();

// FIXME: We don't support
float presence_penalty = 0.f;
if (body.contains("presence_penalty"))
top_p = body["presence_penalty"].toDouble();

// FIXME: We don't support
float frequency_penalty = 0.f;
if (body.contains("frequency_penalty"))
top_p = body["frequency_penalty"].toDouble();

// FIXME: We don't support
int best_of = 1;
if (body.contains("best_of"))
logprobs = body["best_of"].toInt();

// FIXME: We don't need
QString user;
if (body.contains("user"))
suffix = body["user"].toString();
#endif
int max_tokens = body.value("max_tokens").toInt(16);
float temperature = body.value("temperature").toDouble(1.0);
float top_p = body.value("top_p").toDouble(1.0);
float min_p = body.value("min_p").toDouble(0.0);
int n = body.value("n").toInt(1);
bool echo = body.value("echo").toBool(false);

QString actualPrompt = prompts.first();

// if we're a chat completion we have messages which means we need to prepend these to the prompt
if (!messages.isEmpty()) {
QList<QString> chats;
for (int i = 0; i < messages.count(); ++i) {
QJsonValue v = messages.at(i);
// FIXME: Deal with system messages correctly
QString role = v.toObject()["role"].toString();
if (role != "user")
continue;
QString content = v.toObject()["content"].toString();
for (int i = 0; i < messages.count(); ++i) {
QString content = messages.at(i).toObject()["content"].toString();
if (!content.endsWith("\n") && i < messages.count() - 1)
content += "\n";
chats.append(content);
}
actualPrompt.prepend(chats.join("\n"));
}

// adds prompt/response items to GUI
emit requestServerNewPromptResponsePair(actualPrompt); // blocks

// load the new model if necessary
setShouldBeLoaded(true);

if (modelInfo.filename().isEmpty()) {
std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;
return QHttpServerResponse(QHttpServerResponder::StatusCode::BadRequest);
}

// NB: this resets the context, regardless of whether this model is already loaded
if (!loadModel(modelInfo)) {
} else if (!loadModel(modelInfo)) {
std::cerr << "ERROR: couldn't load model " << modelInfo.name().toStdString() << std::endl;
return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
}

const QString promptTemplate = modelInfo.promptTemplate();
const float top_k = modelInfo.topK();
const int n_batch = modelInfo.promptBatchSize();
const float repeat_penalty = modelInfo.repeatPenalty();
const int repeat_last_n = modelInfo.repeatPenaltyTokens();
resetContext();

QByteArray responseData;
QTextStream stream(&responseData, QIODevice::WriteOnly);

QString randomId = "chatcmpl-" + QUuid::createUuid().toString(QUuid::WithoutBraces).replace("-", "");

int promptTokens = 0;
int responseTokens = 0;
QList<QPair<QString, QList<ResultInfo>>> responses;
for (int i = 0; i < n; ++i) {
if (!promptInternal(
m_collections,
actualPrompt,
promptTemplate,
max_tokens /*n_predict*/,
top_k,
top_p,
min_p,
temperature,
n_batch,
repeat_penalty,
repeat_last_n)) {
if (!promptInternal(m_collections,
actualPrompt,
modelInfo.promptTemplate(),
max_tokens /*n_predict*/,
modelInfo.topK(),
top_p,
min_p,
temperature,
modelInfo.promptBatchSize(),
modelInfo.repeatPenalty(),
modelInfo.repeatPenaltyTokens())) {

std::cerr << "ERROR: couldn't prompt model " << modelInfo.name().toStdString() << std::endl;
return QHttpServerResponse(QHttpServerResponder::StatusCode::InternalServerError);
}
QString echoedPrompt = actualPrompt;
if (!echoedPrompt.endsWith("\n"))
echoedPrompt += "\n";
responses.append(qMakePair((echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response(), m_databaseResults));
if (!promptTokens)
promptTokens += m_promptTokens;
responseTokens += m_promptResponseTokens - m_promptTokens;
if (i != n - 1)
resetResponse();
}

QJsonObject responseObject;
responseObject.insert("id", "foobarbaz");
responseObject.insert("object", "text_completion");
responseObject.insert("created", QDateTime::currentSecsSinceEpoch());
responseObject.insert("model", modelInfo.name());
QString result = (echo ? u"%1\n"_s.arg(actualPrompt) : QString()) + response();

QJsonArray choices;
for (const QString &token : result.split(' ')) {
QJsonObject delta;
delta.insert("content", token + " ");

if (isChat) {
int index = 0;
for (const auto &r : responses) {
QString result = r.first;
QList<ResultInfo> infos = r.second;
QJsonObject choice;
choice.insert("index", index++);
choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
QJsonObject message;
message.insert("role", "assistant");
message.insert("content", result);
choice.insert("message", message);
if (MySettings::globalInstance()->localDocsShowReferences()) {
QJsonArray references;
for (const auto &ref : infos)
references.append(resultToJson(ref));
choice.insert("references", references);
}
choices.append(choice);
}
} else {
int index = 0;
for (const auto &r : responses) {
QString result = r.first;
QList<ResultInfo> infos = r.second;
QJsonObject choice;
choice.insert("text", result);
choice.insert("index", index++);
choice.insert("logprobs", QJsonValue::Null); // We don't support
choice.insert("finish_reason", responseTokens == max_tokens ? "length" : "stop");
if (MySettings::globalInstance()->localDocsShowReferences()) {
QJsonArray references;
for (const auto &ref : infos)
references.append(resultToJson(ref));
choice.insert("references", references);
}
choices.append(choice);
choice.insert("index", i);
choice.insert("delta", delta);

QJsonObject responseChunk;
responseChunk.insert("id", randomId);
responseChunk.insert("object", "chat.completion.chunk");
responseChunk.insert("created", QDateTime::currentSecsSinceEpoch());
responseChunk.insert("model", modelInfo.name());
responseChunk.insert("choices", QJsonArray{choice});

stream << "data: " << QJsonDocument(responseChunk).toJson(QJsonDocument::Compact) << "\n\n";
stream.flush();
}

if (i != n - 1)
resetResponse();
}

responseObject.insert("choices", choices);
// Final empty delta to signify the end of the stream
QJsonObject delta;
delta.insert("content", QJsonValue::Null);

QJsonObject usage;
usage.insert("prompt_tokens", int(promptTokens));
usage.insert("completion_tokens", int(responseTokens));
usage.insert("total_tokens", int(promptTokens + responseTokens));
responseObject.insert("usage", usage);
QJsonObject choice;
choice.insert("index", 0);
choice.insert("delta", delta);
choice.insert("finish_reason", "stop");

#if defined(DEBUG)
QJsonDocument newDoc(responseObject);
printf("/v1/completions %s\n", qPrintable(newDoc.toJson(QJsonDocument::Indented)));
fflush(stdout);
#endif
QJsonObject finalChunk;
finalChunk.insert("id", randomId);
finalChunk.insert("object", "chat.completion.chunk");
finalChunk.insert("created", QDateTime::currentSecsSinceEpoch());
finalChunk.insert("model", modelInfo.name());
finalChunk.insert("choices", QJsonArray{choice});

stream << "data: " << QJsonDocument(finalChunk).toJson(QJsonDocument::Compact) << "\n\n";
stream << "data: [DONE]\n\n";
stream.flush();

// Log the entire response data
qDebug() << "Full SSE Response:\n" << responseData;

// Create the response
QHttpServerResponse response(responseData, QHttpServerResponder::StatusCode::Ok);
response.setHeader("Content-Type", "text/event-stream");
response.setHeader("Cache-Control", "no-cache");
response.setHeader("Connection", "keep-alive");

return QHttpServerResponse(responseObject);
return response;
}

0 comments on commit 05f125d

Please sign in to comment.