From 1918a5df1d4bc001e5ad9535d04113c7a715d319 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Wed, 18 Dec 2024 15:53:24 -0500 Subject: [PATCH] Serialize and deserialize subitems. Signed-off-by: Adam Treat --- gpt4all-chat/src/chatlistmodel.cpp | 3 +- gpt4all-chat/src/chatmodel.cpp | 128 +++++++++++++++++++++++++++-- gpt4all-chat/src/chatmodel.h | 19 ++++- gpt4all-chat/src/tool.cpp | 31 +++++++ gpt4all-chat/src/tool.h | 4 + 5 files changed, 173 insertions(+), 12 deletions(-) diff --git a/gpt4all-chat/src/chatlistmodel.cpp b/gpt4all-chat/src/chatlistmodel.cpp index aa6d3d9540b2..fd9d450925a2 100644 --- a/gpt4all-chat/src/chatlistmodel.cpp +++ b/gpt4all-chat/src/chatlistmodel.cpp @@ -20,8 +20,7 @@ #include static constexpr quint32 CHAT_FORMAT_MAGIC = 0xF5D553CC; -static constexpr qint32 CHAT_FORMAT_VERSION = 11; -// FIXME: need to bump format for new chatmodel chatitem tree +static constexpr qint32 CHAT_FORMAT_VERSION = 12; class MyChatListModel: public ChatListModel { }; Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance) diff --git a/gpt4all-chat/src/chatmodel.cpp b/gpt4all-chat/src/chatmodel.cpp index 7b955ba07e72..54af48d21f38 100644 --- a/gpt4all-chat/src/chatmodel.cpp +++ b/gpt4all-chat/src/chatmodel.cpp @@ -14,12 +14,48 @@ QList ChatItem::consolidateSources(const QList &sources) return consolidatedSources; } +void ChatItem::serializeResponse(QDataStream &stream, int version) +{ + stream << value; +} + +void ChatItem::serializeToolCall(QDataStream &stream, int version) +{ + stream << value; + toolCallInfo.serialize(stream, version); +} + +void ChatItem::serializeToolResponse(QDataStream &stream, int version) +{ + stream << value; +} + +void ChatItem::serializeText(QDataStream &stream, int version) +{ + stream << value; +} + +void ChatItem::serializeSubItems(QDataStream &stream, int version) +{ + stream << name; + switch (auto typ = type()) { + using enum ChatItem::Type; + case Response: { serializeResponse(stream, version); break; } + case ToolCall: { serializeToolCall(stream, version); break; } + case ToolResponse: { serializeToolResponse(stream, version); break; } + case Text: { serializeText(stream, version); break; } + case System: + case Prompt: + throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ))); + } + + stream << qsizetype(subItems.size()); + for (ChatItem *item :subItems) + item->serializeSubItems(stream, version); +} + void ChatItem::serialize(QDataStream &stream, int version) { - // FIXME: This 'id' should be eliminated the next time we bump serialization version. - // (Jared) This was apparently never used. - int id = 0; - stream << id; stream << name; stream << value; stream << newResponse; @@ -88,13 +124,78 @@ void ChatItem::serialize(QDataStream &stream, int version) stream << a.content; } } + + if (version >= 12) { + stream << qsizetype(subItems.size()); + for (ChatItem *item :subItems) + item->serializeSubItems(stream, version); + } +} + +bool ChatItem::deserializeToolCall(QDataStream &stream, int version) +{ + stream >> value; + return toolCallInfo.deserialize(stream, version);; +} + +bool ChatItem::deserializeToolResponse(QDataStream &stream, int version) +{ + stream >> value; + return true; +} + +bool ChatItem::deserializeText(QDataStream &stream, int version) +{ + stream >> value; + return true; +} + +bool ChatItem::deserializeResponse(QDataStream &stream, int version) +{ + stream >> value; + return true; +} + +bool ChatItem::deserializeSubItems(QDataStream &stream, int version) +{ + stream >> name; + try { + type(); // check name + } catch (const std::exception &e) { + qWarning() << "ChatModel ERROR:" << e.what(); + return false; + } + switch (auto typ = type()) { + using enum ChatItem::Type; + case Response: { deserializeResponse(stream, version); break; } + case ToolCall: { deserializeToolCall(stream, version); break; } + case ToolResponse: { deserializeToolResponse(stream, version); break; } + case Text: { deserializeText(stream, version); break; } + case System: + case Prompt: + throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ))); + } + + qsizetype count; + stream >> count; + for (int i = 0; i < count; ++i) { + ChatItem *c = new ChatItem(this); + if (!c->deserializeSubItems(stream, version)) { + delete c; + return false; + } + subItems.push_back(c); + } + + return true; } bool ChatItem::deserialize(QDataStream &stream, int version) { - // FIXME: see comment in serialization about id - int id; - stream >> id; + if (version < 12) { + int id; + stream >> id; + } stream >> name; try { type(); // check name @@ -227,5 +328,18 @@ bool ChatItem::deserialize(QDataStream &stream, int version) } promptAttachments = attachments; } + + if (version >= 12) { + qsizetype count; + stream >> count; + for (int i = 0; i < count; ++i) { + ChatItem *c = new ChatItem(this); + if (!c->deserializeSubItems(stream, version)) { + delete c; + return false; + } + subItems.push_back(c); + } + } return true; } diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h index 7bef69d8754f..3615801d7fa5 100644 --- a/gpt4all-chat/src/chatmodel.h +++ b/gpt4all-chat/src/chatmodel.h @@ -376,7 +376,19 @@ class ChatItem : public QObject static QList consolidateSources(const QList &sources); + void serializeResponse(QDataStream &stream, int version); + void serializeToolCall(QDataStream &stream, int version); + void serializeToolResponse(QDataStream &stream, int version); + void serializeText(QDataStream &stream, int version); + void serializeSubItems(QDataStream &stream, int version); // recursive void serialize(QDataStream &stream, int version); + + + bool deserializeResponse(QDataStream &stream, int version); + bool deserializeToolCall(QDataStream &stream, int version); + bool deserializeToolResponse(QDataStream &stream, int version); + bool deserializeText(QDataStream &stream, int version); + bool deserializeSubItems(QDataStream &stream, int version); // recursive bool deserialize(QDataStream &stream, int version); Q_SIGNALS: @@ -868,7 +880,6 @@ class ChatModel : public QAbstractListModel Q_ASSERT(!split.second.isEmpty()); ChatItem *toolCallItem = new ChatItem(this, ChatItem::tool_call_tag, split.second); toolCallItem->isCurrentResponse = true; - // toolCallItem.toolCallInfo = toolCallInfo; newResponse->subItems.push_back(toolCallItem); // Add new response and reset our value @@ -997,7 +1008,6 @@ class ChatModel : public QAbstractListModel bool deserialize(QDataStream &stream, int version) { - // FIXME: need to deserialize new chatitem tree clear(); // reset to known state int size; @@ -1006,7 +1016,10 @@ class ChatModel : public QAbstractListModel QList chatItems; for (int i = 0; i < size; ++i) { ChatItem *c = new ChatItem(this); - c->deserialize(stream, version); + if (!c->deserialize(stream, version)) { + delete c; + return false; + } if (version < 11 && c->type() == ChatItem::Type::Response) { // move sources from the response to their last prompt if (lastPromptIndex >= 0) { diff --git a/gpt4all-chat/src/tool.cpp b/gpt4all-chat/src/tool.cpp index 0e8689ec44ba..74975d2830c8 100644 --- a/gpt4all-chat/src/tool.cpp +++ b/gpt4all-chat/src/tool.cpp @@ -41,3 +41,34 @@ jinja2::Value Tool::jinjaValue() const }; return params; } + +void ToolCallInfo::serialize(QDataStream &stream, int version) +{ + stream << name; + stream << params.size(); + for (auto param : params) { + stream << param.name; + stream << param.type; + stream << param.value; + } + stream << result; + stream << error; + stream << errorString; +} + +bool ToolCallInfo::deserialize(QDataStream &stream, int version) +{ + stream >> name; + qsizetype count; + stream >> count; + for (int i = 0; i < count; ++i) { + ToolParam p; + stream >> p.name; + stream >> p.type; + stream >> p.value; + } + stream >> result; + stream >> error; + stream >> errorString; + return true; +} diff --git a/gpt4all-chat/src/tool.h b/gpt4all-chat/src/tool.h index 62e42cbbd762..08c058eb5e66 100644 --- a/gpt4all-chat/src/tool.h +++ b/gpt4all-chat/src/tool.h @@ -60,6 +60,10 @@ struct ToolCallInfo QString result; ToolEnums::Error error = ToolEnums::Error::NoError; QString errorString; + + void serialize(QDataStream &stream, int version); + bool deserialize(QDataStream &stream, int version); + bool operator==(const ToolCallInfo& other) const { return name == other.name && result == other.result && params == other.params