Skip to content

Commit 76b7d0d

Browse files
committed
Communication protocol for Dive: Messages
- Messages (Request/Response) were added for the socket connection. - Helper functions for messages were added.
1 parent 8916923 commit 76b7d0d

File tree

4 files changed

+543
-2
lines changed

4 files changed

+543
-2
lines changed

network/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ cmake_minimum_required(VERSION 3.10 FATAL_ERROR)
1818

1919
project(network)
2020

21-
set(NETWORK_SRCS socket_connection.cc)
21+
set(NETWORK_SRCS socket_connection.cc messages.cc)
2222

23-
set(NETWORK_HDRS socket_connection.h)
23+
set(NETWORK_HDRS socket_connection.h serializable.h messages.h)
2424

2525
add_library(network SHARED ${NETWORK_SRCS} ${NETWORK_HDRS})
2626

network/messages.cc

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
/*
2+
Copyright 2025 Google Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
#include "messages.h"
18+
19+
#ifdef WIN32
20+
# include <winsock2.h>
21+
#else
22+
# include <netinet/in.h>
23+
#endif
24+
25+
namespace Network
26+
{
27+
28+
void WriteUint32ToBuffer(uint32_t value, Buffer& dest)
29+
{
30+
uint32_t net_val = htonl(value);
31+
const uint8_t* p_val = reinterpret_cast<const uint8_t*>(&net_val);
32+
dest.insert(dest.end(), p_val, p_val + sizeof(uint32_t));
33+
}
34+
35+
void WriteStringToBuffer(const std::string& str, Buffer& dest)
36+
{
37+
WriteUint32ToBuffer(static_cast<uint32_t>(str.length()), dest);
38+
dest.insert(dest.end(), str.begin(), str.end());
39+
}
40+
41+
absl::StatusOr<uint32_t> ReadUint32FromBuffer(const Buffer& src, size_t& offset)
42+
{
43+
if (src.size() < offset + sizeof(uint32_t))
44+
{
45+
return absl::InvalidArgumentError("Buffer too small to read a uint32_t.");
46+
}
47+
uint32_t net_val;
48+
std::memcpy(&net_val, src.data() + offset, sizeof(uint32_t));
49+
offset += sizeof(uint32_t);
50+
return ntohl(net_val);
51+
}
52+
53+
absl::StatusOr<std::string> ReadStringFromBuffer(const Buffer& src, size_t& offset)
54+
{
55+
absl::StatusOr<uint32_t> len_or = ReadUint32FromBuffer(src, offset);
56+
if (!len_or.ok())
57+
{
58+
return len_or.status();
59+
}
60+
61+
uint32_t len = len_or.value();
62+
if (src.size() < offset + len)
63+
{
64+
return absl::InvalidArgumentError("Buffer too small for declared string length.");
65+
}
66+
std::string result(reinterpret_cast<const char*>(src.data() + offset), len);
67+
offset += len;
68+
return result;
69+
}
70+
71+
absl::Status HandShakeMessage::Serialize(Buffer& dest) const
72+
{
73+
dest.clear();
74+
WriteUint32ToBuffer(major_version, dest);
75+
WriteUint32ToBuffer(minor_version, dest);
76+
return absl::OkStatus();
77+
}
78+
79+
absl::Status HandShakeMessage::Deserialize(const Buffer& src)
80+
{
81+
size_t offset = 0;
82+
absl::StatusOr<uint32_t> major_or = ReadUint32FromBuffer(src, offset);
83+
if (!major_or.ok())
84+
{
85+
return major_or.status();
86+
}
87+
major_version = major_or.value();
88+
89+
absl::StatusOr<uint32_t> minor_or = ReadUint32FromBuffer(src, offset);
90+
if (!minor_or.ok())
91+
{
92+
return minor_or.status();
93+
}
94+
minor_version = minor_or.value();
95+
96+
if (offset != src.size())
97+
{
98+
return absl::InvalidArgumentError("Handshake message has unexpected trailing data.");
99+
}
100+
return absl::OkStatus();
101+
}
102+
103+
absl::Status StringMessage::Serialize(Buffer& dest) const
104+
{
105+
dest.clear();
106+
WriteStringToBuffer(str, dest);
107+
return absl::OkStatus();
108+
}
109+
110+
absl::Status StringMessage::Deserialize(const Buffer& src)
111+
{
112+
size_t offset = 0;
113+
absl::StatusOr<std::string> str_or = ReadStringFromBuffer(src, offset);
114+
if (!str_or.ok())
115+
{
116+
return str_or.status();
117+
}
118+
119+
str = std::move(str_or.value());
120+
if (offset != src.size())
121+
{
122+
return absl::InvalidArgumentError("String message has unexpected trailing data.");
123+
}
124+
return absl::OkStatus();
125+
}
126+
127+
absl::Status DownloadFileResponse::Serialize(Buffer& dest) const
128+
{
129+
dest.push_back(static_cast<uint8_t>(found));
130+
WriteStringToBuffer(error_reason, dest);
131+
WriteStringToBuffer(file_path, dest);
132+
WriteStringToBuffer(file_size_str, dest);
133+
134+
return absl::OkStatus();
135+
}
136+
137+
absl::Status DownloadFileResponse::Deserialize(const Buffer& src)
138+
{
139+
size_t offset = 0;
140+
141+
// Deserialize the 'found' boolean.
142+
if (src.size() < offset + sizeof(uint8_t))
143+
{
144+
return absl::InvalidArgumentError("Buffer too small for 'found' field.");
145+
}
146+
found = (src[offset] != 0);
147+
offset += sizeof(uint8_t);
148+
149+
// Deserialize the strings using the StatusOr-returning helper
150+
absl::StatusOr<std::string> error_reason_or = ReadStringFromBuffer(src, offset);
151+
if (!error_reason_or.ok())
152+
{
153+
return error_reason_or.status(); // Forward the error
154+
}
155+
error_reason = std::move(error_reason_or.value());
156+
157+
absl::StatusOr<std::string> file_path_or = ReadStringFromBuffer(src, offset);
158+
if (!file_path_or.ok())
159+
{
160+
return file_path_or.status();
161+
}
162+
file_path = std::move(file_path_or.value());
163+
164+
absl::StatusOr<std::string> file_size_or = ReadStringFromBuffer(src, offset);
165+
if (!file_size_or.ok())
166+
{
167+
return file_size_or.status();
168+
}
169+
file_size_str = std::move(file_size_or.value());
170+
171+
// Final check for trailing data.
172+
if (offset != src.size())
173+
{
174+
return absl::InvalidArgumentError("Message has unexpected trailing data.");
175+
}
176+
177+
return absl::OkStatus();
178+
}
179+
180+
absl::Status ReceiveFull(SocketConnection* conn, uint8_t* buffer, size_t size)
181+
{
182+
if (!conn)
183+
{
184+
return absl::InvalidArgumentError("Provided SocketConnection is null.");
185+
}
186+
size_t total_received = 0;
187+
while (total_received < size)
188+
{
189+
absl::StatusOr<size_t> received_or = conn->Recv(buffer + total_received,
190+
size - total_received);
191+
if (!received_or.ok())
192+
{
193+
return received_or.status();
194+
}
195+
total_received += received_or.value();
196+
}
197+
return absl::OkStatus();
198+
}
199+
200+
absl::Status SendFull(SocketConnection* conn, const uint8_t* buffer, size_t size)
201+
{
202+
if (!conn)
203+
{
204+
return absl::InvalidArgumentError("Provided SocketConnection is null.");
205+
}
206+
return conn->Send(buffer, size);
207+
}
208+
209+
absl::StatusOr<std::unique_ptr<ISerializable>> ReceiveMessage(SocketConnection* conn)
210+
{
211+
if (!conn)
212+
{
213+
return absl::InvalidArgumentError("Provided SocketConnection is null.");
214+
}
215+
216+
const size_t header_size = sizeof(uint32_t) * 2;
217+
uint8_t header_buffer[header_size];
218+
219+
// Receive the message header.
220+
absl::Status status = ReceiveFull(conn, header_buffer, header_size);
221+
if (!status.ok())
222+
{
223+
return status;
224+
}
225+
226+
// Parse header.
227+
uint32_t net_type, net_length;
228+
std::memcpy(&net_type, header_buffer, sizeof(uint32_t));
229+
std::memcpy(&net_length, header_buffer + sizeof(uint32_t), sizeof(uint32_t));
230+
uint32_t type = ntohl(net_type);
231+
uint32_t payload_length = ntohl(net_length);
232+
233+
if (payload_length > kMaxPayloadSize)
234+
{
235+
conn->Close();
236+
return absl::InvalidArgumentError(
237+
absl::StrCat("Payload size ", payload_length, " exceeds limit."));
238+
}
239+
240+
// Receive the message payload.
241+
Buffer payload_buffer(payload_length);
242+
status = ReceiveFull(conn, payload_buffer.data(), payload_length);
243+
if (!status.ok())
244+
{
245+
return status;
246+
}
247+
248+
// Create and deserialize the message object.
249+
std::unique_ptr<ISerializable> message;
250+
switch (static_cast<MessageType>(type))
251+
{
252+
case MessageType::HANDSHAKE_REQUEST:
253+
message = std::make_unique<HandShakeRequest>();
254+
break;
255+
case MessageType::HANDSHAKE_RESPONSE:
256+
message = std::make_unique<HandShakeResponse>();
257+
break;
258+
case MessageType::PING_MESSAGE:
259+
message = std::make_unique<PingMessage>();
260+
break;
261+
case MessageType::PONG_MESSAGE:
262+
message = std::make_unique<PongMessage>();
263+
break;
264+
case MessageType::CAPTURE_REQUEST:
265+
message = std::make_unique<CaptureRequest>();
266+
break;
267+
case MessageType::CAPTURE_RESPONSE:
268+
message = std::make_unique<CaptureResponse>();
269+
break;
270+
case MessageType::DOWNLOAD_FILE_REQUEST:
271+
message = std::make_unique<DownloadFileRequest>();
272+
break;
273+
case MessageType::DOWNLOAD_FILE_RESPONSE:
274+
message = std::make_unique<DownloadFileResponse>();
275+
break;
276+
default:
277+
conn->Close();
278+
return absl::InvalidArgumentError(absl::StrCat("Unknown message type: ", type));
279+
}
280+
281+
status = message->Deserialize(payload_buffer);
282+
if (!status.ok())
283+
{
284+
conn->Close();
285+
return status;
286+
}
287+
288+
return message;
289+
}
290+
291+
absl::Status SendMessage(SocketConnection* conn, const ISerializable& message)
292+
{
293+
if (!conn)
294+
{
295+
return absl::InvalidArgumentError("Provided SocketConnection is null.");
296+
}
297+
298+
// Serialize the message payload.
299+
Buffer payload_buffer;
300+
absl::Status status = message.Serialize(payload_buffer);
301+
if (!status.ok())
302+
{
303+
return status;
304+
}
305+
if (payload_buffer.size() > kMaxPayloadSize)
306+
{
307+
return absl::InvalidArgumentError("Serialized payload size exceeds limit.");
308+
}
309+
310+
// Construct and send the header.
311+
uint32_t net_type = htonl(message.GetMessageType());
312+
uint32_t net_payload_length = htonl(static_cast<uint32_t>(payload_buffer.size()));
313+
const size_t header_size = sizeof(net_type) + sizeof(net_payload_length);
314+
uint8_t header_buffer[header_size];
315+
std::memcpy(header_buffer, &net_type, sizeof(uint32_t));
316+
std::memcpy(header_buffer + sizeof(uint32_t), &net_payload_length, sizeof(uint32_t));
317+
318+
status = SendFull(conn, header_buffer, header_size);
319+
if (!status.ok())
320+
{
321+
return status;
322+
}
323+
324+
// Send the payload.
325+
status = SendFull(conn, payload_buffer.data(), payload_buffer.size());
326+
if (!status.ok())
327+
{
328+
return status;
329+
}
330+
331+
return absl::OkStatus();
332+
}
333+
334+
} // namespace Network

0 commit comments

Comments
 (0)