Skip to content

Commit 0ef9bc7

Browse files
committed
Communication protocol for Dive: Messages
- Messages (Request/Response) were added for the socket connection. - Helper functions for messages were added.
1 parent 8916923 commit 0ef9bc7

File tree

4 files changed

+580
-2
lines changed

4 files changed

+580
-2
lines changed

network/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
1818

1919
project(network)
2020

21-
set(NETWORK_SRCS socket_connection.cc)
21+
set(NETWORK_SRCS socket_connection.cc messages.cc)
2222

23-
set(NETWORK_HDRS socket_connection.h)
23+
set(NETWORK_HDRS socket_connection.h serializable.h messages.h)
2424

2525
add_library(network SHARED ${NETWORK_SRCS} ${NETWORK_HDRS})
2626

network/messages.cc

Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
/*
2+
Copyright 2025 Google Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
#include "messages.h"
18+
19+
#ifdef WIN32
20+
# include <winsock2.h>
21+
#else
22+
# include <netinet/in.h>
23+
#endif
24+
25+
constexpr uint32_t kMaxPayloadSize = 16 * 1024 * 1024;
26+
27+
namespace Network
28+
{
29+
30+
void WriteUint32ToBuffer(uint32_t value, Buffer& dest)
31+
{
32+
uint32_t net_val = htonl(value);
33+
const uint8_t* p_val = reinterpret_cast<const uint8_t*>(&net_val);
34+
dest.insert(dest.end(), p_val, p_val + sizeof(uint32_t));
35+
}
36+
37+
void WriteStringToBuffer(const std::string& str, Buffer& dest)
38+
{
39+
WriteUint32ToBuffer(static_cast<uint32_t>(str.length()), dest);
40+
dest.insert(dest.end(), str.begin(), str.end());
41+
}
42+
43+
absl::StatusOr<uint32_t> ReadUint32FromBuffer(const Buffer& src, size_t& offset)
44+
{
45+
if (src.size() < offset + sizeof(uint32_t))
46+
{
47+
return absl::InvalidArgumentError("Buffer too small to read a uint32_t.");
48+
}
49+
uint32_t net_val;
50+
std::memcpy(&net_val, src.data() + offset, sizeof(uint32_t));
51+
offset += sizeof(uint32_t);
52+
return ntohl(net_val);
53+
}
54+
55+
absl::StatusOr<std::string> ReadStringFromBuffer(const Buffer& src, size_t& offset)
56+
{
57+
absl::StatusOr<uint32_t> len_or = ReadUint32FromBuffer(src, offset);
58+
if (!len_or.ok())
59+
{
60+
return len_or.status();
61+
}
62+
63+
uint32_t len = len_or.value();
64+
if (src.size() < offset + len)
65+
{
66+
return absl::InvalidArgumentError("Buffer too small for declared string length.");
67+
}
68+
std::string result(reinterpret_cast<const char*>(src.data() + offset), len);
69+
offset += len;
70+
return result;
71+
}
72+
73+
absl::Status HandShakeMessage::Serialize(Buffer& dest) const
74+
{
75+
dest.clear();
76+
WriteUint32ToBuffer(m_major_version, dest);
77+
WriteUint32ToBuffer(m_minor_version, dest);
78+
return absl::OkStatus();
79+
}
80+
81+
absl::Status HandShakeMessage::Deserialize(const Buffer& src)
82+
{
83+
size_t offset = 0;
84+
absl::StatusOr<uint32_t> major_or = ReadUint32FromBuffer(src, offset);
85+
if (!major_or.ok())
86+
{
87+
return major_or.status();
88+
}
89+
m_major_version = major_or.value();
90+
91+
absl::StatusOr<uint32_t> minor_or = ReadUint32FromBuffer(src, offset);
92+
if (!minor_or.ok())
93+
{
94+
return minor_or.status();
95+
}
96+
m_minor_version = minor_or.value();
97+
98+
if (offset != src.size())
99+
{
100+
return absl::InvalidArgumentError("Handshake message has unexpected trailing data.");
101+
}
102+
return absl::OkStatus();
103+
}
104+
105+
absl::Status StringMessage::Serialize(Buffer& dest) const
106+
{
107+
dest.clear();
108+
WriteStringToBuffer(m_str, dest);
109+
return absl::OkStatus();
110+
}
111+
112+
absl::Status StringMessage::Deserialize(const Buffer& src)
113+
{
114+
size_t offset = 0;
115+
absl::StatusOr<std::string> str_or = ReadStringFromBuffer(src, offset);
116+
if (!str_or.ok())
117+
{
118+
return str_or.status();
119+
}
120+
121+
m_str = std::move(str_or.value());
122+
if (offset != src.size())
123+
{
124+
return absl::InvalidArgumentError("String message has unexpected trailing data.");
125+
}
126+
return absl::OkStatus();
127+
}
128+
129+
absl::Status DownloadFileResponse::Serialize(Buffer& dest) const
130+
{
131+
dest.push_back(static_cast<uint8_t>(m_found));
132+
WriteStringToBuffer(m_error_reason, dest);
133+
WriteStringToBuffer(m_file_path, dest);
134+
WriteStringToBuffer(m_file_size_str, dest);
135+
136+
return absl::OkStatus();
137+
}
138+
139+
absl::Status DownloadFileResponse::Deserialize(const Buffer& src)
140+
{
141+
size_t offset = 0;
142+
143+
// Deserialize the 'found' boolean.
144+
if (src.size() < offset + sizeof(uint8_t))
145+
{
146+
return absl::InvalidArgumentError("Buffer too small for 'found' field.");
147+
}
148+
m_found = (src[offset] != 0);
149+
offset += sizeof(uint8_t);
150+
151+
// Deserialize the strings using the StatusOr-returning helper
152+
absl::StatusOr<std::string> error_reason_or = ReadStringFromBuffer(src, offset);
153+
if (!error_reason_or.ok())
154+
{
155+
return error_reason_or.status(); // Forward the error
156+
}
157+
m_error_reason = std::move(error_reason_or.value());
158+
159+
absl::StatusOr<std::string> file_path_or = ReadStringFromBuffer(src, offset);
160+
if (!file_path_or.ok())
161+
{
162+
return file_path_or.status();
163+
}
164+
m_file_path = std::move(file_path_or.value());
165+
166+
absl::StatusOr<std::string> file_size_or = ReadStringFromBuffer(src, offset);
167+
if (!file_size_or.ok())
168+
{
169+
return file_size_or.status();
170+
}
171+
m_file_size_str = std::move(file_size_or.value());
172+
173+
// Final check for trailing data.
174+
if (offset != src.size())
175+
{
176+
return absl::InvalidArgumentError("Message has unexpected trailing data.");
177+
}
178+
179+
return absl::OkStatus();
180+
}
181+
182+
absl::Status ReceiveBuffer(SocketConnection* conn, uint8_t* buffer, size_t size)
183+
{
184+
if (!conn)
185+
{
186+
return absl::InvalidArgumentError("Provided SocketConnection is null.");
187+
}
188+
size_t total_received = 0;
189+
while (total_received < size)
190+
{
191+
absl::StatusOr<size_t> received_or = conn->Recv(buffer + total_received,
192+
size - total_received);
193+
if (!received_or.ok())
194+
{
195+
return received_or.status();
196+
}
197+
total_received += received_or.value();
198+
}
199+
return absl::OkStatus();
200+
}
201+
202+
absl::Status SendBuffer(SocketConnection* conn, const uint8_t* buffer, size_t size)
203+
{
204+
if (!conn)
205+
{
206+
return absl::InvalidArgumentError("Provided SocketConnection is null.");
207+
}
208+
return conn->Send(buffer, size);
209+
}
210+
211+
absl::StatusOr<std::unique_ptr<ISerializable>> ReceiveMessage(SocketConnection* conn)
212+
{
213+
if (!conn)
214+
{
215+
return absl::InvalidArgumentError("Provided SocketConnection is null.");
216+
}
217+
218+
const size_t header_size = sizeof(uint32_t) * 2;
219+
uint8_t header_buffer[header_size];
220+
221+
// Receive the message header.
222+
absl::Status status = ReceiveBuffer(conn, header_buffer, header_size);
223+
if (!status.ok())
224+
{
225+
return status;
226+
}
227+
228+
// Parse header.
229+
uint32_t net_type, net_length;
230+
std::memcpy(&net_type, header_buffer, sizeof(uint32_t));
231+
std::memcpy(&net_length, header_buffer + sizeof(uint32_t), sizeof(uint32_t));
232+
uint32_t type = ntohl(net_type);
233+
uint32_t payload_length = ntohl(net_length);
234+
235+
if (payload_length > kMaxPayloadSize)
236+
{
237+
conn->Close();
238+
return absl::InvalidArgumentError(
239+
absl::StrCat("Payload size ", payload_length, " exceeds limit."));
240+
}
241+
242+
// Receive the message payload.
243+
Buffer payload_buffer(payload_length);
244+
status = ReceiveBuffer(conn, payload_buffer.data(), payload_length);
245+
if (!status.ok())
246+
{
247+
return status;
248+
}
249+
250+
// Create and deserialize the message object.
251+
std::unique_ptr<ISerializable> message;
252+
switch (static_cast<MessageType>(type))
253+
{
254+
case MessageType::HANDSHAKE_REQUEST:
255+
message = std::make_unique<HandShakeRequest>();
256+
break;
257+
case MessageType::HANDSHAKE_RESPONSE:
258+
message = std::make_unique<HandShakeResponse>();
259+
break;
260+
case MessageType::PING_MESSAGE:
261+
message = std::make_unique<PingMessage>();
262+
break;
263+
case MessageType::PONG_MESSAGE:
264+
message = std::make_unique<PongMessage>();
265+
break;
266+
case MessageType::CAPTURE_REQUEST:
267+
message = std::make_unique<CaptureRequest>();
268+
break;
269+
case MessageType::CAPTURE_RESPONSE:
270+
message = std::make_unique<CaptureResponse>();
271+
break;
272+
case MessageType::DOWNLOAD_FILE_REQUEST:
273+
message = std::make_unique<DownloadFileRequest>();
274+
break;
275+
case MessageType::DOWNLOAD_FILE_RESPONSE:
276+
message = std::make_unique<DownloadFileResponse>();
277+
break;
278+
default:
279+
conn->Close();
280+
return absl::InvalidArgumentError(absl::StrCat("Unknown message type: ", type));
281+
}
282+
283+
status = message->Deserialize(payload_buffer);
284+
if (!status.ok())
285+
{
286+
conn->Close();
287+
return status;
288+
}
289+
290+
return message;
291+
}
292+
293+
absl::Status SendMessage(SocketConnection* conn, const ISerializable& message)
294+
{
295+
if (!conn)
296+
{
297+
return absl::InvalidArgumentError("Provided SocketConnection is null.");
298+
}
299+
300+
// Serialize the message payload.
301+
Buffer payload_buffer;
302+
absl::Status status = message.Serialize(payload_buffer);
303+
if (!status.ok())
304+
{
305+
return status;
306+
}
307+
if (payload_buffer.size() > kMaxPayloadSize)
308+
{
309+
return absl::InvalidArgumentError("Serialized payload size exceeds limit.");
310+
}
311+
312+
// Construct and send the header.
313+
uint32_t net_type = htonl(message.GetMessageType());
314+
uint32_t net_payload_length = htonl(static_cast<uint32_t>(payload_buffer.size()));
315+
const size_t header_size = sizeof(net_type) + sizeof(net_payload_length);
316+
uint8_t header_buffer[header_size];
317+
std::memcpy(header_buffer, &net_type, sizeof(uint32_t));
318+
std::memcpy(header_buffer + sizeof(uint32_t), &net_payload_length, sizeof(uint32_t));
319+
320+
status = SendBuffer(conn, header_buffer, header_size);
321+
if (!status.ok())
322+
{
323+
return status;
324+
}
325+
326+
// Send the payload.
327+
status = SendBuffer(conn, payload_buffer.data(), payload_buffer.size());
328+
if (!status.ok())
329+
{
330+
return status;
331+
}
332+
333+
return absl::OkStatus();
334+
}
335+
336+
} // namespace Network

0 commit comments

Comments
 (0)