Skip to content

Commit

Permalink
WIP Code interpreter tool call.
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Treat <[email protected]>
  • Loading branch information
manyoso committed Nov 7, 2024
1 parent 6895a48 commit 55875eb
Show file tree
Hide file tree
Showing 16 changed files with 726 additions and 14 deletions.
4 changes: 4 additions & 0 deletions gpt4all-chat/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ qt_add_executable(chat
src/chatllm.cpp src/chatllm.h
src/chatmodel.h
src/chatviewtextprocessor.cpp src/chatviewtextprocessor.h
src/codeinterpreter.cpp src/codeinterpreter.h
src/database.cpp src/database.h
src/download.cpp src/download.h
src/embllm.cpp src/embllm.h
Expand All @@ -199,6 +200,9 @@ qt_add_executable(chat
src/mysettings.cpp src/mysettings.h
src/network.cpp src/network.h
src/server.cpp src/server.h
src/tool.cpp src/tool.h
src/toolcallparser.cpp src/toolcallparser.h
src/toolmodel.cpp src/toolmodel.h
src/xlsxtomd.cpp src/xlsxtomd.h
${CHAT_EXE_RESOURCES}
${MACOS_SOURCES}
Expand Down
18 changes: 11 additions & 7 deletions gpt4all-chat/qml/ChatItemView.qml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ GridLayout {
Layout.alignment: Qt.AlignVCenter | Qt.AlignRight
Layout.preferredWidth: 32
Layout.preferredHeight: 32
Layout.topMargin: model.index > 0 ? 25 : 0

Image {
id: logo
sourceSize: Qt.size(32, 32)
Expand Down Expand Up @@ -50,6 +52,8 @@ GridLayout {
Layout.column: 1
Layout.fillWidth: true
Layout.preferredHeight: 38
Layout.topMargin: model.index > 0 ? 25 : 0

RowLayout {
spacing: 5
anchors.left: parent.left
Expand Down Expand Up @@ -240,7 +244,7 @@ GridLayout {
Component.onCompleted: {
resetChatViewTextProcessor();
chatModel.valueChanged.connect(function(i, value) {
if (index === i)
if (model.index === i)
textProcessor.setValue(value);
}
);
Expand Down Expand Up @@ -274,9 +278,9 @@ GridLayout {
if (thumbsDownState && !thumbsUpState && !responseHasChanged)
return

chatModel.updateNewResponse(index, response)
chatModel.updateThumbsUpState(index, false)
chatModel.updateThumbsDownState(index, true)
chatModel.updateNewResponse(model.index, response)
chatModel.updateThumbsUpState(model.index, false)
chatModel.updateThumbsDownState(model.index, true)
Network.sendConversation(currentChat.id, getConversationJson());
}
}
Expand Down Expand Up @@ -305,9 +309,9 @@ GridLayout {
if (thumbsUpState && !thumbsDownState)
return

chatModel.updateNewResponse(index, "")
chatModel.updateThumbsUpState(index, true)
chatModel.updateThumbsDownState(index, false)
chatModel.updateNewResponse(model.index, "")
chatModel.updateThumbsUpState(model.index, true)
chatModel.updateThumbsDownState(model.index, false)
Network.sendConversation(currentChat.id, getConversationJson());
}
}
Expand Down
3 changes: 2 additions & 1 deletion gpt4all-chat/qml/ChatView.qml
Original file line number Diff line number Diff line change
Expand Up @@ -790,7 +790,6 @@ Rectangle {
Layout.leftMargin: 50
Layout.rightMargin: 50
Layout.alignment: Qt.AlignHCenter
spacing: 25
model: chatModel
cacheBuffer: Math.max(0, listView.contentHeight)

Expand All @@ -804,6 +803,8 @@ Rectangle {

delegate: ChatItemView {
width: listView.contentItem.width - 15
visible: name !== "ToolResponse: "
height: visible ? implicitHeight : 0
}

function scrollToEnd() {
Expand Down
63 changes: 62 additions & 1 deletion gpt4all-chat/src/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
#include "chatlistmodel.h"
#include "network.h"
#include "server.h"
#include "tool.h"
#include "toolcallparser.h"
#include "toolmodel.h"

#include <QBuffer>
#include <QDataStream>
#include <QDebug>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>
#include <QLatin1String>
#include <QMap>
#include <QString>
Expand All @@ -16,6 +22,8 @@

#include <utility>

using namespace ToolEnums;

Chat::Chat(QObject *parent)
: QObject(parent)
, m_id(Network::globalInstance()->generateUniqueId())
Expand Down Expand Up @@ -222,6 +230,42 @@ void Chat::generatingQuestions()
emit responseStateChanged();
}

/**
 * Parses a model-emitted tool call (JSON object with "name" and "parameters"
 * keys), looks the tool up in the global ToolModel, and runs it.
 *
 * @param toolCall    Raw JSON text of the tool call produced by the model.
 * @param errorString Set to a diagnostic message on failure; untouched on success.
 * @return The tool's textual response, or a null QString on any error.
 */
QString executeToolCall(const QString &toolCall, QString &errorString)
{
    // Fixed budget so a misbehaving tool cannot hang the response pipeline.
    static constexpr int kToolTimeoutMsecs = 10000;

    QJsonParseError err;
    const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err);

    if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) {
        errorString = QString("ERROR: The tool call had null or invalid json %1").arg(toolCall);
        return QString();
    }

    const QJsonObject rootObject = toolCallDoc.object();
    if (!rootObject.contains("name") || !rootObject.contains("parameters")) {
        errorString = QString("ERROR: The tool call did not have required name and parameters objects %1").arg(toolCall);
        return QString();
    }

    // Read from the validated root object (not the document) so the lookup is
    // consistent with the containment checks above.
    const QString tool = rootObject["name"].toString();
    const QJsonObject args = rootObject["parameters"].toObject();

    Tool *toolInstance = ToolModel::globalInstance()->get(tool);
    if (!toolInstance) {
        errorString = QString("ERROR: Could not find the tool for %1").arg(toolCall);
        return QString();
    }

    // FIXME: Honor the confirmation mode feature

    const QString response = toolInstance->run(args, kToolTimeoutMsecs);
    if (toolInstance->error() != Error::NoError) {
        errorString = QString("ERROR: Tool call produced error: %1").arg(toolInstance->errorString());
        return QString();
    }

    return response;
}

void Chat::responseStopped(qint64 promptResponseMs)
{
m_tokenSpeed = QString();
Expand All @@ -231,8 +275,25 @@ void Chat::responseStopped(qint64 promptResponseMs)
m_responseState = Chat::ResponseStopped;
emit responseInProgressChanged();
emit responseStateChanged();
if (m_generatedName.isEmpty())

const int index = m_chatModel->count() - 1;
ChatItem item = m_chatModel->get(index);
ToolCallParser parser;
parser.update(item.value);
if (item.type() == ChatItem::Type::Response && parser.state() == ToolCallParser::Complete) {
const QString toolCall = parser.toolCall();
QString errorString;
const QString toolResponse = executeToolCall(toolCall, errorString);
qDebug() << toolCall << toolResponse;
resetResponseState();
m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
m_chatModel->appendToolResponse(toolResponse);
m_chatModel->appendResponse();
emit promptRequested(m_collections); // triggers a new response
return;
} else if (m_generatedName.isEmpty()) {
emit generateNameRequested();
}

Network::globalInstance()->trackChatEvent("response_complete", {
{"first", m_firstResponse},
Expand Down
27 changes: 23 additions & 4 deletions gpt4all-chat/src/chatllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include "localdocs.h"
#include "mysettings.h"
#include "network.h"
#include "tool.h"
#include "toolmodel.h"
#include "toolcallparser.h"

#include <fmt/format.h>

Expand Down Expand Up @@ -48,6 +51,7 @@
#include <vector>

using namespace Qt::Literals::StringLiterals;
using namespace ToolEnums;
namespace ranges = std::ranges;

//#define DEBUG
Expand Down Expand Up @@ -739,9 +743,18 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const ChatItem> items) const
for (auto &item : items)
messages.emplace_back(makeMap(item));

jinja2::ValuesList toolList;
const int toolCount = ToolModel::globalInstance()->count();
for (int i = 0; i < toolCount; ++i) {
Tool *t = ToolModel::globalInstance()->get(i);
if (t->usageMode() == UsageMode::Enabled)
toolList.push_back(t->jinjaValue());
}

jinja2::ValuesMap params {
{ "messages", std::move(messages) },
{ "add_generation_prompt", true },
{ "toolList", toolList },
};
if (auto token = model->bosToken())
params.emplace("bos_token", std::move(*token));
Expand Down Expand Up @@ -844,14 +857,18 @@ auto ChatLLM::promptInternal(
return !m_stopGenerating;
};

auto handleResponse = [this, &result](LLModel::Token token, std::string_view piece) -> bool {
ToolCallParser toolCallParser;
auto handleResponse = [this, &result, &toolCallParser](LLModel::Token token, std::string_view piece) -> bool {
Q_UNUSED(token)
result.responseTokens++;
m_timer->inc();

toolCallParser.update(QString::fromStdString(piece.data()));
result.response.append(piece.data(), piece.size());
auto respStr = QString::fromUtf8(result.response);
emit responseChanged(removeLeadingWhitespace(respStr));
return !m_stopGenerating;
const bool foundToolCall = toolCallParser.state() == ToolCallParser::Complete;
return !foundToolCall && !m_stopGenerating;
};

QElapsedTimer totalTime;
Expand All @@ -871,13 +888,15 @@ auto ChatLLM::promptInternal(
m_timer->stop();
qint64 elapsed = totalTime.elapsed();

const bool foundToolCall = toolCallParser.state() == ToolCallParser::Complete;

// trim trailing whitespace
auto respStr = QString::fromUtf8(result.response);
if (!respStr.isEmpty() && std::as_const(respStr).back().isSpace())
if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || foundToolCall))
emit responseChanged(respStr.trimmed());

bool doQuestions = false;
if (!m_isServer && chatItems) {
if (!m_isServer && chatItems && !foundToolCall) {
switch (mySettings->suggestionMode()) {
case SuggestionMode::On: doQuestions = true; break;
case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break;
Expand Down
25 changes: 24 additions & 1 deletion gpt4all-chat/src/chatmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct ChatItem
Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState)

public:
enum class Type { System, Prompt, Response };
enum class Type { System, Prompt, Response, ToolResponse };

// tags for constructing ChatItems
struct prompt_tag_t { explicit prompt_tag_t() = default; };
Expand All @@ -93,6 +93,8 @@ struct ChatItem
static inline constexpr response_tag_t response_tag = response_tag_t();
struct system_tag_t { explicit system_tag_t() = default; };
static inline constexpr system_tag_t system_tag = system_tag_t();
struct tool_response_tag_t { explicit tool_response_tag_t() = default; };
static inline constexpr tool_response_tag_t tool_response_tag = tool_response_tag_t();

// FIXME(jared): This should not be necessary. QML should see null or undefined if it
// tries to access something invalid.
Expand All @@ -108,6 +110,9 @@ struct ChatItem
ChatItem(response_tag_t, bool currentResponse = true)
: name(u"Response: "_s), currentResponse(currentResponse) {}

ChatItem(tool_response_tag_t)
: name(u"ToolResponse: "_s) {}

Type type() const
{
if (name == u"System: "_s)
Expand All @@ -116,6 +121,8 @@ struct ChatItem
return Type::Prompt;
if (name == u"Response: "_s)
return Type::Response;
if (name == u"ToolResponse: "_s)
return Type::ToolResponse;
throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name));
}

Expand Down Expand Up @@ -313,6 +320,22 @@ class ChatModel : public QAbstractListModel
emit hasErrorChanged(false);
}

void appendToolResponse(const QString &value)
{
m_mutex.lock();
const int count = m_chatItems.count();
m_mutex.unlock();
ChatItem item(ChatItem::tool_response_tag);
item.value = value;
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(item);
}
endInsertRows();
emit countChanged();
}

Q_INVOKABLE void clear()
{
{
Expand Down
Loading

0 comments on commit 55875eb

Please sign in to comment.