Skip to content

Commit

Permalink
Serialize and deserialize subitems.
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Treat <[email protected]>
  • Loading branch information
manyoso committed Dec 18, 2024
1 parent 6e6cdbf commit 1918a5d
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 12 deletions.
3 changes: 1 addition & 2 deletions gpt4all-chat/src/chatlistmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
#include <memory>

static constexpr quint32 CHAT_FORMAT_MAGIC = 0xF5D553CC;
static constexpr qint32 CHAT_FORMAT_VERSION = 11;
// FIXME: need to bump format for new chatmodel chatitem tree
static constexpr qint32 CHAT_FORMAT_VERSION = 12;

class MyChatListModel: public ChatListModel { };
Q_GLOBAL_STATIC(MyChatListModel, chatListModelInstance)
Expand Down
128 changes: 121 additions & 7 deletions gpt4all-chat/src/chatmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,48 @@ QList<ResultInfo> ChatItem::consolidateSources(const QList<ResultInfo> &sources)
return consolidatedSources;
}

void ChatItem::serializeResponse(QDataStream &stream, int version)
{
stream << value;
}

void ChatItem::serializeToolCall(QDataStream &stream, int version)
{
stream << value;
toolCallInfo.serialize(stream, version);
}

void ChatItem::serializeToolResponse(QDataStream &stream, int version)
{
stream << value;
}

void ChatItem::serializeText(QDataStream &stream, int version)
{
stream << value;
}

void ChatItem::serializeSubItems(QDataStream &stream, int version)
{
stream << name;
switch (auto typ = type()) {
using enum ChatItem::Type;
case Response: { serializeResponse(stream, version); break; }
case ToolCall: { serializeToolCall(stream, version); break; }
case ToolResponse: { serializeToolResponse(stream, version); break; }
case Text: { serializeText(stream, version); break; }
case System:
case Prompt:
throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ)));
}

stream << qsizetype(subItems.size());
for (ChatItem *item :subItems)
item->serializeSubItems(stream, version);
}

void ChatItem::serialize(QDataStream &stream, int version)
{
// FIXME: This 'id' should be eliminated the next time we bump serialization version.
// (Jared) This was apparently never used.
int id = 0;
stream << id;
stream << name;
stream << value;
stream << newResponse;
Expand Down Expand Up @@ -88,13 +124,78 @@ void ChatItem::serialize(QDataStream &stream, int version)
stream << a.content;
}
}

if (version >= 12) {
stream << qsizetype(subItems.size());
for (ChatItem *item :subItems)
item->serializeSubItems(stream, version);
}
}

bool ChatItem::deserializeToolCall(QDataStream &stream, int version)
{
stream >> value;
return toolCallInfo.deserialize(stream, version);;
}

bool ChatItem::deserializeToolResponse(QDataStream &stream, int version)
{
stream >> value;
return true;
}

bool ChatItem::deserializeText(QDataStream &stream, int version)
{
stream >> value;
return true;
}

bool ChatItem::deserializeResponse(QDataStream &stream, int version)
{
stream >> value;
return true;
}

bool ChatItem::deserializeSubItems(QDataStream &stream, int version)
{
stream >> name;
try {
type(); // check name
} catch (const std::exception &e) {
qWarning() << "ChatModel ERROR:" << e.what();
return false;
}
switch (auto typ = type()) {
using enum ChatItem::Type;
case Response: { deserializeResponse(stream, version); break; }
case ToolCall: { deserializeToolCall(stream, version); break; }
case ToolResponse: { deserializeToolResponse(stream, version); break; }
case Text: { deserializeText(stream, version); break; }
case System:
case Prompt:
throw std::invalid_argument(fmt::format("cannot serialize subitem type {}", int(typ)));
}

qsizetype count;
stream >> count;
for (int i = 0; i < count; ++i) {
ChatItem *c = new ChatItem(this);
if (!c->deserializeSubItems(stream, version)) {
delete c;
return false;
}
subItems.push_back(c);
}

return true;
}

bool ChatItem::deserialize(QDataStream &stream, int version)
{
// FIXME: see comment in serialization about id
int id;
stream >> id;
if (version < 12) {
int id;
stream >> id;
}
stream >> name;
try {
type(); // check name
Expand Down Expand Up @@ -227,5 +328,18 @@ bool ChatItem::deserialize(QDataStream &stream, int version)
}
promptAttachments = attachments;
}

if (version >= 12) {
qsizetype count;
stream >> count;
for (int i = 0; i < count; ++i) {
ChatItem *c = new ChatItem(this);
if (!c->deserializeSubItems(stream, version)) {
delete c;
return false;
}
subItems.push_back(c);
}
}
return true;
}
19 changes: 16 additions & 3 deletions gpt4all-chat/src/chatmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,19 @@ class ChatItem : public QObject

static QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources);

void serializeResponse(QDataStream &stream, int version);
void serializeToolCall(QDataStream &stream, int version);
void serializeToolResponse(QDataStream &stream, int version);
void serializeText(QDataStream &stream, int version);
void serializeSubItems(QDataStream &stream, int version); // recursive
void serialize(QDataStream &stream, int version);


bool deserializeResponse(QDataStream &stream, int version);
bool deserializeToolCall(QDataStream &stream, int version);
bool deserializeToolResponse(QDataStream &stream, int version);
bool deserializeText(QDataStream &stream, int version);
bool deserializeSubItems(QDataStream &stream, int version); // recursive
bool deserialize(QDataStream &stream, int version);

Q_SIGNALS:
Expand Down Expand Up @@ -868,7 +880,6 @@ class ChatModel : public QAbstractListModel
Q_ASSERT(!split.second.isEmpty());
ChatItem *toolCallItem = new ChatItem(this, ChatItem::tool_call_tag, split.second);
toolCallItem->isCurrentResponse = true;
// toolCallItem.toolCallInfo = toolCallInfo;
newResponse->subItems.push_back(toolCallItem);

// Add new response and reset our value
Expand Down Expand Up @@ -997,7 +1008,6 @@ class ChatModel : public QAbstractListModel

bool deserialize(QDataStream &stream, int version)
{
// FIXME: need to deserialize new chatitem tree
clear(); // reset to known state

int size;
Expand All @@ -1006,7 +1016,10 @@ class ChatModel : public QAbstractListModel
QList<ChatItem*> chatItems;
for (int i = 0; i < size; ++i) {
ChatItem *c = new ChatItem(this);
c->deserialize(stream, version);
if (!c->deserialize(stream, version)) {
delete c;
return false;
}
if (version < 11 && c->type() == ChatItem::Type::Response) {
// move sources from the response to their last prompt
if (lastPromptIndex >= 0) {
Expand Down
31 changes: 31 additions & 0 deletions gpt4all-chat/src/tool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,34 @@ jinja2::Value Tool::jinjaValue() const
};
return params;
}

void ToolCallInfo::serialize(QDataStream &stream, int version)
{
stream << name;
stream << params.size();
for (auto param : params) {
stream << param.name;
stream << param.type;
stream << param.value;
}
stream << result;
stream << error;
stream << errorString;
}

bool ToolCallInfo::deserialize(QDataStream &stream, int version)
{
stream >> name;
qsizetype count;
stream >> count;
for (int i = 0; i < count; ++i) {
ToolParam p;
stream >> p.name;
stream >> p.type;
stream >> p.value;
}
stream >> result;
stream >> error;
stream >> errorString;
return true;
}
4 changes: 4 additions & 0 deletions gpt4all-chat/src/tool.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ struct ToolCallInfo
QString result;
ToolEnums::Error error = ToolEnums::Error::NoError;
QString errorString;

void serialize(QDataStream &stream, int version);
bool deserialize(QDataStream &stream, int version);

bool operator==(const ToolCallInfo& other) const
{
return name == other.name && result == other.result && params == other.params
Expand Down

0 comments on commit 1918a5d

Please sign in to comment.