From 55875eb2a0399a6420076e875768f9bb30e860c8 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Thu, 7 Nov 2024 10:37:04 -0500 Subject: [PATCH] WIP Code interpreter tool call. Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 4 + gpt4all-chat/qml/ChatItemView.qml | 18 ++-- gpt4all-chat/qml/ChatView.qml | 3 +- gpt4all-chat/src/chat.cpp | 63 +++++++++++++- gpt4all-chat/src/chatllm.cpp | 27 +++++- gpt4all-chat/src/chatmodel.h | 25 +++++- gpt4all-chat/src/codeinterpreter.cpp | 100 +++++++++++++++++++++ gpt4all-chat/src/codeinterpreter.h | 55 ++++++++++++ gpt4all-chat/src/jinja_helpers.cpp | 3 + gpt4all-chat/src/main.cpp | 3 + gpt4all-chat/src/tool.cpp | 32 +++++++ gpt4all-chat/src/tool.h | 125 +++++++++++++++++++++++++++ gpt4all-chat/src/toolcallparser.cpp | 86 ++++++++++++++++++ gpt4all-chat/src/toolcallparser.h | 34 ++++++++ gpt4all-chat/src/toolmodel.cpp | 41 +++++++++ gpt4all-chat/src/toolmodel.h | 121 ++++++++++++++++++++++++++ 16 files changed, 726 insertions(+), 14 deletions(-) create mode 100644 gpt4all-chat/src/codeinterpreter.cpp create mode 100644 gpt4all-chat/src/codeinterpreter.h create mode 100644 gpt4all-chat/src/tool.cpp create mode 100644 gpt4all-chat/src/tool.h create mode 100644 gpt4all-chat/src/toolcallparser.cpp create mode 100644 gpt4all-chat/src/toolcallparser.h create mode 100644 gpt4all-chat/src/toolmodel.cpp create mode 100644 gpt4all-chat/src/toolmodel.h diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 2199cb238020..65494875a4af 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -187,6 +187,7 @@ qt_add_executable(chat src/chatllm.cpp src/chatllm.h src/chatmodel.h src/chatviewtextprocessor.cpp src/chatviewtextprocessor.h + src/codeinterpreter.cpp src/codeinterpreter.h src/database.cpp src/database.h src/download.cpp src/download.h src/embllm.cpp src/embllm.h @@ -199,6 +200,9 @@ qt_add_executable(chat src/mysettings.cpp src/mysettings.h src/network.cpp src/network.h src/server.cpp src/server.h + src/tool.cpp src/tool.h + src/toolcallparser.cpp src/toolcallparser.h + src/toolmodel.cpp src/toolmodel.h src/xlsxtomd.cpp src/xlsxtomd.h ${CHAT_EXE_RESOURCES} ${MACOS_SOURCES} diff --git a/gpt4all-chat/qml/ChatItemView.qml b/gpt4all-chat/qml/ChatItemView.qml index 3beb3311ddb6..5858d474e749 100644 --- a/gpt4all-chat/qml/ChatItemView.qml +++ b/gpt4all-chat/qml/ChatItemView.qml @@ -18,6 +18,8 @@ GridLayout { Layout.alignment: Qt.AlignVCenter | Qt.AlignRight Layout.preferredWidth: 32 Layout.preferredHeight: 32 + Layout.topMargin: model.index > 0 ? 25 : 0 + Image { id: logo sourceSize: Qt.size(32, 32) @@ -50,6 +52,8 @@ GridLayout { Layout.column: 1 Layout.fillWidth: true Layout.preferredHeight: 38 + Layout.topMargin: model.index > 0 ? 25 : 0 + RowLayout { spacing: 5 anchors.left: parent.left @@ -240,7 +244,7 @@ GridLayout { Component.onCompleted: { resetChatViewTextProcessor(); chatModel.valueChanged.connect(function(i, value) { - if (index === i) + if (model.index === i) textProcessor.setValue(value); } ); @@ -274,9 +278,9 @@ GridLayout { if (thumbsDownState && !thumbsUpState && !responseHasChanged) return - chatModel.updateNewResponse(index, response) - chatModel.updateThumbsUpState(index, false) - chatModel.updateThumbsDownState(index, true) + chatModel.updateNewResponse(model.index, response) + chatModel.updateThumbsUpState(model.index, false) + chatModel.updateThumbsDownState(model.index, true) Network.sendConversation(currentChat.id, getConversationJson()); } } @@ -305,9 +309,9 @@ GridLayout { if (thumbsUpState && !thumbsDownState) return - chatModel.updateNewResponse(index, "") - chatModel.updateThumbsUpState(index, true) - chatModel.updateThumbsDownState(index, false) + chatModel.updateNewResponse(model.index, "") + chatModel.updateThumbsUpState(model.index, true) + chatModel.updateThumbsDownState(model.index, false) Network.sendConversation(currentChat.id, getConversationJson()); } } diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index 1212cd2ffcce..4379619e5921 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -790,7 +790,6 @@ Rectangle { Layout.leftMargin: 50 Layout.rightMargin: 50 Layout.alignment: Qt.AlignHCenter - spacing: 25 model: chatModel cacheBuffer: Math.max(0, listView.contentHeight) @@ -804,6 +803,8 @@ Rectangle { delegate: ChatItemView { width: listView.contentItem.width - 15 + visible: name !== "ToolResponse: " + height: visible ? implicitHeight : 0 } function scrollToEnd() { diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp index feb3c61d3d4b..99a216baffe5 100644 --- a/gpt4all-chat/src/chat.cpp +++ b/gpt4all-chat/src/chat.cpp @@ -3,10 +3,16 @@ #include "chatlistmodel.h" #include "network.h" #include "server.h" +#include "tool.h" +#include "toolcallparser.h" +#include "toolmodel.h" #include #include #include +#include +#include +#include #include #include #include @@ -16,6 +22,8 @@ #include +using namespace ToolEnums; + Chat::Chat(QObject *parent) : QObject(parent) , m_id(Network::globalInstance()->generateUniqueId()) @@ -222,6 +230,42 @@ void Chat::generatingQuestions() emit responseStateChanged(); } +QString executeToolCall(const QString &toolCall, QString &errorString) +{ + QJsonParseError err; + const QJsonDocument toolCallDoc = QJsonDocument::fromJson(toolCall.toUtf8(), &err); + + if (toolCallDoc.isNull() || err.error != QJsonParseError::NoError || !toolCallDoc.isObject()) { + errorString = QString("ERROR: The tool call had null or invalid json %1").arg(toolCall); + return QString(); + } + + QJsonObject rootObject = toolCallDoc.object(); + if (!rootObject.contains("name") || !rootObject.contains("parameters")) { + errorString = QString("ERROR: The tool call did not have required name and argument objects %1").arg(toolCall); + return QString(); + } + + const QString tool = toolCallDoc["name"].toString(); + const QJsonObject args = toolCallDoc["parameters"].toObject(); + + Tool *toolInstance = ToolModel::globalInstance()->get(tool); + if (!toolInstance) { + errorString = QString("ERROR: Could not find the tool for %1").arg(toolCall); + return QString(); + } + + // FIXME: Honor the confirmation mode feature + + const QString response = toolInstance->run(args, 10000 /*msecs to timeout*/); + if (toolInstance->error() != Error::NoError) { + errorString = QString("ERROR: Tool call produced error: %1").arg(toolInstance->errorString()); + return QString(); + } + + return response; +} + void Chat::responseStopped(qint64 promptResponseMs) { m_tokenSpeed = QString(); @@ -231,8 +275,25 @@ void Chat::responseStopped(qint64 promptResponseMs) m_responseState = Chat::ResponseStopped; emit responseInProgressChanged(); emit responseStateChanged(); - if (m_generatedName.isEmpty()) + + const int index = m_chatModel->count() - 1; + ChatItem item = m_chatModel->get(index); + ToolCallParser parser; + parser.update(item.value); + if (item.type() == ChatItem::Type::Response && parser.state() == ToolCallParser::Complete) { + const QString toolCall = parser.toolCall(); + QString errorString; + const QString toolResponse = executeToolCall(toolCall, errorString); + qDebug() << toolCall << toolResponse; + resetResponseState(); + m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false); + m_chatModel->appendToolResponse(toolResponse); + m_chatModel->appendResponse(); + emit promptRequested(m_collections); // triggers a new response + return; + } else if (m_generatedName.isEmpty()) { emit generateNameRequested(); + } Network::globalInstance()->trackChatEvent("response_complete", { {"first", m_firstResponse}, diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index 0a5aebbc6e78..0d2201595c57 100644 --- a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -7,6 +7,9 @@ #include "localdocs.h" #include "mysettings.h" #include "network.h" +#include "tool.h" +#include "toolmodel.h" +#include "toolcallparser.h" #include @@ -48,6 +51,7 @@ #include using namespace Qt::Literals::StringLiterals; +using namespace ToolEnums; namespace ranges = std::ranges; //#define DEBUG @@ -739,9 +743,18 @@ std::string ChatLLM::applyJinjaTemplate(std::span items) const for (auto &item : items) messages.emplace_back(makeMap(item)); + jinja2::ValuesList toolList; + const int toolCount = ToolModel::globalInstance()->count(); + for (int i = 0; i < toolCount; ++i) { + Tool *t = ToolModel::globalInstance()->get(i); + if (t->usageMode() == UsageMode::Enabled) + toolList.push_back(t->jinjaValue()); + } + jinja2::ValuesMap params { { "messages", std::move(messages) }, { "add_generation_prompt", true }, + { "toolList", toolList }, }; if (auto token = model->bosToken()) params.emplace("bos_token", std::move(*token)); @@ -844,14 +857,18 @@ auto ChatLLM::promptInternal( return !m_stopGenerating; }; - auto handleResponse = [this, &result](LLModel::Token token, std::string_view piece) -> bool { + ToolCallParser toolCallParser; + auto handleResponse = [this, &result, &toolCallParser](LLModel::Token token, std::string_view piece) -> bool { Q_UNUSED(token) result.responseTokens++; m_timer->inc(); + + toolCallParser.update(QString::fromStdString(piece.data())); result.response.append(piece.data(), piece.size()); auto respStr = QString::fromUtf8(result.response); emit responseChanged(removeLeadingWhitespace(respStr)); - return !m_stopGenerating; + const bool foundToolCall = toolCallParser.state() == ToolCallParser::Complete; + return !foundToolCall && !m_stopGenerating; }; QElapsedTimer totalTime; @@ -871,13 +888,15 @@ auto ChatLLM::promptInternal( m_timer->stop(); qint64 elapsed = totalTime.elapsed(); + const bool foundToolCall = toolCallParser.state() == ToolCallParser::Complete; + // trim trailing whitespace auto respStr = QString::fromUtf8(result.response); - if (!respStr.isEmpty() && std::as_const(respStr).back().isSpace()) + if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || foundToolCall)) emit responseChanged(respStr.trimmed()); bool doQuestions = false; - if (!m_isServer && chatItems) { + if (!m_isServer && chatItems && !foundToolCall) { switch (mySettings->suggestionMode()) { case SuggestionMode::On: doQuestions = true; break; case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break; diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h index 45520d455b40..0cd741d4c7da 100644 --- a/gpt4all-chat/src/chatmodel.h +++ b/gpt4all-chat/src/chatmodel.h @@ -84,7 +84,7 @@ struct ChatItem Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) public: - enum class Type { System, Prompt, Response }; + enum class Type { System, Prompt, Response, ToolResponse }; // tags for constructing ChatItems struct prompt_tag_t { explicit prompt_tag_t() = default; }; @@ -93,6 +93,8 @@ struct ChatItem static inline constexpr response_tag_t response_tag = response_tag_t(); struct system_tag_t { explicit system_tag_t() = default; }; static inline constexpr system_tag_t system_tag = system_tag_t(); + struct tool_response_tag_t { explicit tool_response_tag_t() = default; }; + static inline constexpr tool_response_tag_t tool_response_tag = tool_response_tag_t(); // FIXME(jared): This should not be necessary. QML should see null or undefined if it // tries to access something invalid. @@ -108,6 +110,9 @@ struct ChatItem ChatItem(response_tag_t, bool currentResponse = true) : name(u"Response: "_s), currentResponse(currentResponse) {} + ChatItem(tool_response_tag_t) + : name(u"ToolResponse: "_s) {} + Type type() const { if (name == u"System: "_s) @@ -116,6 +121,8 @@ struct ChatItem return Type::Prompt; if (name == u"Response: "_s) return Type::Response; + if (name == u"ToolResponse: "_s) + return Type::ToolResponse; throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name)); } @@ -313,6 +320,22 @@ class ChatModel : public QAbstractListModel emit hasErrorChanged(false); } + void appendToolResponse(const QString &value) + { + m_mutex.lock(); + const int count = m_chatItems.count(); + m_mutex.unlock(); + ChatItem item(ChatItem::tool_response_tag); + item.value = value; + beginInsertRows(QModelIndex(), count, count); + { + QMutexLocker locker(&m_mutex); + m_chatItems.append(item); + } + endInsertRows(); + emit countChanged(); + } + Q_INVOKABLE void clear() { { diff --git a/gpt4all-chat/src/codeinterpreter.cpp b/gpt4all-chat/src/codeinterpreter.cpp new file mode 100644 index 000000000000..b5492db78384 --- /dev/null +++ b/gpt4all-chat/src/codeinterpreter.cpp @@ -0,0 +1,100 @@ +#include "codeinterpreter.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace Qt::Literals::StringLiterals; + +QString CodeInterpreter::run(const QJsonObject ¶meters, qint64 timeout) +{ + // Reset the error state + m_error = ToolEnums::Error::NoError; + m_errorString = QString(); + + QString code = parameters["code"].toString(); + + // Replace escaped characters in json to actual characters in javascript + code.replace("\\n", "\n"); // newline + code.replace("\\t", "\t"); // tab + code.replace("\\r", "\r"); // carriage return + code.replace("\\\"", "\""); // double quote + code.replace("\\\'", "\'"); // single quote + code.replace("\\\\", "\\"); // back slash + code.replace("\\b", "\b"); // backspace + code.replace("\\f", "\f"); // form feed + code.replace("\\v", "\v"); // vertical tab + code.replace("\\/", "/"); // forward slash + + QJSEngine engine; + + JavaScriptConsoleCapture consoleCapture; + QJSValue consoleObject = engine.newQObject(&consoleCapture); + engine.globalObject().setProperty("console", consoleObject); + + QJSValue result = engine.evaluate(code); + QString resultString = result.toString(); + + // NOTE: We purposely do not set the m_error or m_errorString which for the code interpreter since + // we *want* the model to see the response is an error so it can hopefully correct itself. The + // error member variables are intended for tools that have error conditions that cannot be corrected. + // For instance, a tool depending upon the network might set these error variables if the network + // is not available. + if (result.isError()) { + resultString = QString("Uncaught exception at line") + + result.property("lineNumber").toString() + + ":" + result.toString(); + } + + QJsonObject jsonObject; + jsonObject.insert("code", code); + jsonObject.insert("result", resultString); + jsonObject.insert("output", consoleCapture.output); + QJsonDocument doc(jsonObject); + Q_ASSERT(!doc.isNull() && doc.isObject()); + return doc.toJson(QJsonDocument::Compact); +} + +QJsonObject CodeInterpreter::paramSchema() const +{ + static const QString paramSchema = R"({ + "code": { + "type": "string", + "description": "The javascript code to run", + "required": true + } + })"; + + static const QJsonDocument params = QJsonDocument::fromJson(paramSchema.toUtf8()); + Q_ASSERT(!params.isNull() && params.isObject()); + return params.object(); +} + +QJsonObject CodeInterpreter::exampleParams() const +{ + static const QString example = R"( + function isPrime(num) { + if (num <= 1) return false; + if (num === 2) return true; + if (num % 2 === 0) return false; + for (let i = 3; i <= Math.sqrt(num); i += 2) { + if (num % i === 0) return false; + } + return true; + } + const number = 7; + console.log(`${number} is prime: ${isPrime(number)}`); + )"; + + QJsonObject jsonObject; + jsonObject.insert("code", example); + return jsonObject; +} diff --git a/gpt4all-chat/src/codeinterpreter.h b/gpt4all-chat/src/codeinterpreter.h new file mode 100644 index 000000000000..fd971db0480b --- /dev/null +++ b/gpt4all-chat/src/codeinterpreter.h @@ -0,0 +1,55 @@ +#ifndef CODEINTERPRETER_H +#define CODEINTERPRETER_H + +#include "tool.h" + +#include +#include + +class JavaScriptConsoleCapture : public QObject +{ + Q_OBJECT +public: + QString output; + Q_INVOKABLE void log(const QString &message) { output.append(message); } + // TODO: Consider adding the following + // console.assert() + // console.debug() + // console.exception() + // console.info() + // console.log() (equivalent to console.debug()) + // console.error() + // console.time() + // console.timeEnd() + // console.trace() + // console.count() + // console.warn() +}; + +class CodeInterpreter : public Tool +{ + Q_OBJECT +public: + explicit CodeInterpreter() : Tool(), m_error(ToolEnums::Error::NoError) {} + virtual ~CodeInterpreter() {} + + QString run(const QJsonObject ¶meters, qint64 timeout = 2000) override; + ToolEnums::Error error() const override { return m_error; } + QString errorString() const override { return m_errorString; } + + QString name() const override { return tr("Code Interpreter"); } + QString description() const override { return tr("Javascript code interpreter"); } + QString function() const override { return "javascript_interpret"; } + ToolEnums::PrivacyScope privacyScope() const override { return ToolEnums::PrivacyScope::Local; } + QJsonObject exampleParams() const override; + QJsonObject paramSchema() const override; + bool isBuiltin() const override { return true; } + ToolEnums::UsageMode usageMode() const override { return ToolEnums::UsageMode::Enabled; } + bool excerpts() const override { return false; } + +private: + ToolEnums::Error m_error; + QString m_errorString; +}; + +#endif // CODEINTERPRETER_H diff --git a/gpt4all-chat/src/jinja_helpers.cpp b/gpt4all-chat/src/jinja_helpers.cpp index 8067434f8fa6..361664883ffb 100644 --- a/gpt4all-chat/src/jinja_helpers.cpp +++ b/gpt4all-chat/src/jinja_helpers.cpp @@ -55,6 +55,7 @@ auto JinjaMessage::keys() const -> const std::unordered_set & using enum ChatItem::Type; case System: case Response: + case ToolResponse: return baseKeys; case Prompt: return userKeys; @@ -75,6 +76,7 @@ bool operator==(const JinjaMessage &a, const JinjaMessage &b) using enum ChatItem::Type; case System: case Response: + case ToolResponse: return true; case Prompt: return ia.sources == ib.sources && ia.promptAttachments == ib.promptAttachments; @@ -89,6 +91,7 @@ const JinjaFieldMap JinjaMessage::s_fields = { case System: return "system"sv; case Prompt: return "user"sv; case Response: return "assistant"sv; + case ToolResponse: return "tool"sv; } Q_UNREACHABLE(); } }, diff --git a/gpt4all-chat/src/main.cpp b/gpt4all-chat/src/main.cpp index d594515671ca..51421d373e3c 100644 --- a/gpt4all-chat/src/main.cpp +++ b/gpt4all-chat/src/main.cpp @@ -7,6 +7,7 @@ #include "modellist.h" #include "mysettings.h" #include "network.h" +#include "toolmodel.h" #include #include @@ -112,6 +113,8 @@ int main(int argc, char *argv[]) qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance()); qmlRegisterSingletonInstance("network", 1, 0, "Network", Network::globalInstance()); qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance()); + qmlRegisterSingletonInstance("toollist", 1, 0, "ToolList", ToolModel::globalInstance()); + qmlRegisterUncreatableMetaObject(ToolEnums::staticMetaObject, "toolenums", 1, 0, "ToolEnums", "Error: only enums"); qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums"); { diff --git a/gpt4all-chat/src/tool.cpp b/gpt4all-chat/src/tool.cpp new file mode 100644 index 000000000000..08097bfdcac1 --- /dev/null +++ b/gpt4all-chat/src/tool.cpp @@ -0,0 +1,32 @@ +#include "tool.h" + +#include + +QJsonObject filterModelGeneratedProperties(const QJsonObject &inputObject) +{ + QJsonObject filteredObject; + for (const QString &key : inputObject.keys()) { + QJsonObject propertyObject = inputObject.value(key).toObject(); + if (!propertyObject.contains("modelGenerated") || propertyObject["modelGenerated"].toBool()) + filteredObject.insert(key, propertyObject); + } + return filteredObject; +} + +jinja2::Value Tool::jinjaValue() const +{ + QJsonDocument doc(filterModelGeneratedProperties(paramSchema())); + QString p(doc.toJson(QJsonDocument::Compact)); + + QJsonDocument exampleDoc(exampleParams()); + QString e(exampleDoc.toJson(QJsonDocument::Compact)); + + jinja2::ValuesMap params { + { "name", name().toStdString() }, + { "description", description().toStdString() }, + { "function", function().toStdString() }, + { "paramSchema", p.toStdString() }, + { "exampleParams", e.toStdString() } + }; + return params; +} diff --git a/gpt4all-chat/src/tool.h b/gpt4all-chat/src/tool.h new file mode 100644 index 000000000000..58f01abd48a3 --- /dev/null +++ b/gpt4all-chat/src/tool.h @@ -0,0 +1,125 @@ +#ifndef TOOL_H +#define TOOL_H + +#include +#include + +#include + +using namespace Qt::Literals::StringLiterals; + +namespace ToolEnums { + Q_NAMESPACE + enum class Error + { + NoError = 0, + TimeoutError = 2, + UnknownError = 499, + }; + Q_ENUM_NS(Error) + + enum class UsageMode + { + Disabled = 0, // Completely disabled + Enabled = 1, // Enabled and the model decides whether to run + ForceUsage = 2, // Attempt to force usage of the tool rather than let the LLM decide. + // NOTE: Not always possible. + }; + Q_ENUM_NS(UsageMode) + + enum class ConfirmationMode + { + NoConfirmation = 0, // No confirmation required + AskBeforeRunning = 1, // User is queried on every execution + AskBeforeRunningRecursive = 2, // User is queried if the tool is invoked in a recursive tool call + }; + Q_ENUM_NS(ConfirmationMode) + + // Ordered in increasing levels of privacy + enum class PrivacyScope + { + None = 0, // Tool call data does not have any privacy scope + LocalOrg = 1, // Tool call data does not leave the local organization + Local = 2 // Tool call data does not leave the machine + }; + Q_ENUM_NS(PrivacyScope) +} + +class Tool : public QObject +{ + Q_OBJECT + Q_PROPERTY(QString name READ name CONSTANT) + Q_PROPERTY(QString description READ description CONSTANT) + Q_PROPERTY(QString function READ function CONSTANT) + Q_PROPERTY(ToolEnums::PrivacyScope privacyScope READ privacyScope CONSTANT) + Q_PROPERTY(QJsonObject paramSchema READ paramSchema CONSTANT) + Q_PROPERTY(QJsonObject exampleParams READ exampleParams CONSTANT) + Q_PROPERTY(QUrl url READ url CONSTANT) + Q_PROPERTY(bool isBuiltin READ isBuiltin CONSTANT) + Q_PROPERTY(ToolEnums::UsageMode usageMode READ usageMode NOTIFY usageModeChanged) + Q_PROPERTY(ToolEnums::ConfirmationMode confirmationMode READ confirmationMode NOTIFY confirmationModeChanged) + Q_PROPERTY(bool excerpts READ excerpts CONSTANT) + +public: + Tool() : QObject(nullptr) {} + virtual ~Tool() {} + + virtual QString run(const QJsonObject ¶meters, qint64 timeout = 2000) = 0; + virtual ToolEnums::Error error() const { return ToolEnums::Error::NoError; } + virtual QString errorString() const { return QString(); } + + // [Required] Human readable name of the tool. + virtual QString name() const = 0; + + // [Required] Human readable description of the tool. + virtual QString description() const = 0; + + // [Required] Must be unique. Name of the function to invoke. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + virtual QString function() const = 0; + + // [Required] The privacy scope. + virtual ToolEnums::PrivacyScope privacyScope() const = 0; + + // [Optional] Json schema describing the tool's parameters. An empty object specifies no parameters. + // https://json-schema.org/understanding-json-schema/ + // https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-tools + // https://github.com/ollama/ollama/blob/main/docs/api.md#chat-request-with-tools + // FIXME: This should be validated against json schema + virtual QJsonObject paramSchema() const { return QJsonObject(); } + + // [Optional] An example of the parameters for this tool call. NOTE: This should only include parameters + // that the model is responsible for generating. + virtual QJsonObject exampleParams() const { return QJsonObject(); } + + // [Optional] The local file or remote resource use to invoke the tool. + virtual QUrl url() const { return QUrl(); } + + // [Optional] Whether the tool is built-in. + virtual bool isBuiltin() const { return false; } + + // [Optional] The usage mode. + virtual ToolEnums::UsageMode usageMode() const { return ToolEnums::UsageMode::Disabled; } + + // [Optional] The confirmation mode. + virtual ToolEnums::ConfirmationMode confirmationMode() const { return ToolEnums::ConfirmationMode::NoConfirmation; } + + // [Optional] Whether json result produces source excerpts. + virtual bool excerpts() const { return false; } + + bool operator==(const Tool &other) const + { + return function() == other.function(); + } + bool operator!=(const Tool &other) const + { + return !(*this == other); + } + + jinja2::Value jinjaValue() const; + +Q_SIGNALS: + void usageModeChanged(); + void confirmationModeChanged(); +}; + +#endif // TOOL_H diff --git a/gpt4all-chat/src/toolcallparser.cpp b/gpt4all-chat/src/toolcallparser.cpp new file mode 100644 index 000000000000..2e77bc2f49ba --- /dev/null +++ b/gpt4all-chat/src/toolcallparser.cpp @@ -0,0 +1,86 @@ +#include "toolcallparser.h" + +#include + +ToolCallParser::ToolCallParser() +{ + reset(); +} + +void ToolCallParser::reset() +{ + m_expected = QChar('<'); + m_expectedIndex = 0; + m_state = None; + m_buffer.clear(); + m_toolCall.clear(); + m_endTagBuffer.clear(); + m_startIndex = -1; +} + +// This method is called with an arbitrary string and a current state. This method should take the +// current state into account and then parse through the update character by character to arrive at +// the new state. +void ToolCallParser::update(const QString &update) +{ + Q_ASSERT(m_state != Complete); + if (m_state == Complete) { + qWarning() << "ERROR: ToolCallParser::update already found a complete toolcall!"; + return; + } + + static const QString toolCallStart = ""; + static const QString toolCallEnd = ""; + + for (size_t i = 0; i < update.size(); ++i) { + const QChar c = update[i]; + const bool foundMatch = m_expected.isNull() || c == m_expected; + if (!foundMatch) { + reset(); + continue; + } + + m_buffer.append(c); + switch (m_state) { + case None: + { + m_expectedIndex = 1; + m_expected = u't'; + m_state = InStart; + m_startIndex = i; + break; + } + case InStart: + { + if (m_expectedIndex == toolCallStart.size() - 1) { + m_expectedIndex = 0; + m_expected = QChar(); + m_state = Partial; + } else { + ++m_expectedIndex; + m_expected = toolCallStart.at(m_expectedIndex); + } + break; + } + case Partial: + { + m_toolCall.append(c); + m_endTagBuffer.append(c); + if (m_endTagBuffer.size() > toolCallEnd.size()) + m_endTagBuffer.remove(0, 1); + if (m_endTagBuffer == toolCallEnd) { + m_toolCall.chop(toolCallEnd.size()); + m_state = Complete; + m_endTagBuffer.clear(); + m_buffer.clear(); + } + } + case Complete: + { + // Already complete, do nothing further + break; + } + } + } +} + diff --git a/gpt4all-chat/src/toolcallparser.h b/gpt4all-chat/src/toolcallparser.h new file mode 100644 index 000000000000..b15e1e28f7dd --- /dev/null +++ b/gpt4all-chat/src/toolcallparser.h @@ -0,0 +1,34 @@ +#ifndef TOOLCALLPARSER_H +#define TOOLCALLPARSER_H + +#include + +class ToolCallParser +{ +public: + ToolCallParser(); + void reset(); + void update(const QString &update); + QString toolCall() const { return m_toolCall; } + int startIndex() const { return m_startIndex; } + + enum State { + None, + InStart, + Partial, + Complete, + }; + + State state() const { return m_state; } + +private: + QChar m_expected; + int m_expectedIndex; + State m_state; + QString m_buffer; + QString m_toolCall; + QString m_endTagBuffer; + int m_startIndex; +}; + +#endif // TOOLCALLPARSER_H diff --git a/gpt4all-chat/src/toolmodel.cpp b/gpt4all-chat/src/toolmodel.cpp new file mode 100644 index 000000000000..1e0ed11f4cd0 --- /dev/null +++ b/gpt4all-chat/src/toolmodel.cpp @@ -0,0 +1,41 @@ +#include "toolmodel.h" + +#include +#include +#include + +#include "codeinterpreter.h" + +class MyToolModel: public ToolModel { }; +Q_GLOBAL_STATIC(MyToolModel, toolModelInstance) +ToolModel *ToolModel::globalInstance() +{ + return toolModelInstance(); +} + +ToolModel::ToolModel() + : QAbstractListModel(nullptr) { + + QCoreApplication::instance()->installEventFilter(this); + + Tool* codeInterpreter = new CodeInterpreter; + m_tools.append(codeInterpreter); + m_toolMap.insert(codeInterpreter->function(), codeInterpreter); + connect(codeInterpreter, &Tool::usageModeChanged, this, &ToolModel::privacyScopeChanged); +} + +bool ToolModel::eventFilter(QObject *obj, QEvent *ev) +{ + if (obj == QCoreApplication::instance() && ev->type() == QEvent::LanguageChange) + emit dataChanged(index(0, 0), index(m_tools.size() - 1, 0)); + return false; +} + +ToolEnums::PrivacyScope ToolModel::privacyScope() const +{ + ToolEnums::PrivacyScope scope = ToolEnums::PrivacyScope::Local; // highest scope + for (const Tool *t : m_tools) + if (t->usageMode() != ToolEnums::UsageMode::Disabled) + scope = std::min(scope, t->privacyScope()); + return scope; +} diff --git a/gpt4all-chat/src/toolmodel.h b/gpt4all-chat/src/toolmodel.h new file mode 100644 index 000000000000..bb61b25d9962 --- /dev/null +++ b/gpt4all-chat/src/toolmodel.h @@ -0,0 +1,121 @@ +#ifndef TOOLMODEL_H +#define TOOLMODEL_H + +#include "tool.h" + +#include + +class ToolModel : public QAbstractListModel +{ + Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) + Q_PROPERTY(ToolEnums::PrivacyScope privacyScope READ privacyScope NOTIFY privacyScopeChanged) + +public: + static ToolModel *globalInstance(); + + enum Roles { + NameRole = Qt::UserRole + 1, + DescriptionRole, + FunctionRole, + PrivacyScopeRole, + ParametersRole, + UrlRole, + ApiKeyRole, + KeyRequiredRole, + IsBuiltinRole, + UsageModeRole, + ConfirmationModeRole, + ExcerptsRole, + }; + + int rowCount(const QModelIndex &parent = QModelIndex()) const override + { + Q_UNUSED(parent) + return m_tools.size(); + } + + QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override + { + if (!index.isValid() || index.row() < 0 || index.row() >= m_tools.size()) + return QVariant(); + + const Tool *item = m_tools.at(index.row()); + switch (role) { + case NameRole: + return item->name(); + case DescriptionRole: + return item->description(); + case FunctionRole: + return item->function(); + case PrivacyScopeRole: + return QVariant::fromValue(item->privacyScope()); + case ParametersRole: + return item->paramSchema(); + case UrlRole: + return item->url(); + case IsBuiltinRole: + return item->isBuiltin(); + case UsageModeRole: + return QVariant::fromValue(item->usageMode()); + case ConfirmationModeRole: + return QVariant::fromValue(item->confirmationMode()); + case ExcerptsRole: + return item->excerpts(); + } + + return QVariant(); + } + + QHash roleNames() const override + { + QHash roles; + roles[NameRole] = "name"; + roles[DescriptionRole] = "description"; + roles[FunctionRole] = "function"; + roles[PrivacyScopeRole] = "privacyScope"; + roles[ParametersRole] = "parameters"; + roles[UrlRole] = "url"; + roles[ApiKeyRole] = "apiKey"; + roles[KeyRequiredRole] = "keyRequired"; + roles[IsBuiltinRole] = "isBuiltin"; + roles[UsageModeRole] = "usageMode"; + roles[ConfirmationModeRole] = "confirmationMode"; + roles[ExcerptsRole] = "excerpts"; + return roles; + } + + Q_INVOKABLE Tool* get(int index) const + { + if (index < 0 || index >= m_tools.size()) return nullptr; + return m_tools.at(index); + } + + Q_INVOKABLE Tool *get(const QString &id) const + { + if (!m_toolMap.contains(id)) return nullptr; + return m_toolMap.value(id); + } + + int count() const { return m_tools.size(); } + + // Returns the least private scope of all enabled tools + ToolEnums::PrivacyScope privacyScope() const; + +Q_SIGNALS: + void countChanged(); + void privacyScopeChanged(); + void valueChanged(int index, const QString &value); + +protected: + bool eventFilter(QObject *obj, QEvent *ev) override; + +private: + explicit ToolModel(); + ~ToolModel() {} + friend class MyToolModel; + QList m_tools; + QHash m_toolMap; +}; + +#endif // TOOLMODEL_H