Skip to content

Commit

Permalink
Add support for udp multicast
Browse files Browse the repository at this point in the history
Signed-off-by: Paul Guyot <[email protected]>
  • Loading branch information
pguyot committed Jan 23, 2025
1 parent 279c44a commit a613c12
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `externalterm_to_term_with_roots` to efficiently preserve roots when allocating memory for external terms.
- Added `erl_epmd` client implementation to epmd using `socket` module
- Added support for socket asynchronous API for `recv`, `recvfrom` and `accept`.
- Added support for UDP multicast with socket API.

### Changed

Expand Down
8 changes: 6 additions & 2 deletions libs/estdlib/src/socket.erl
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,12 @@
}.
-type in_addr() :: {0..255, 0..255, 0..255, 0..255}.
-type port_number() :: 0..65535.
-type ip_mreq() :: #{multiaddr := in_addr(), interface := in_addr()}.

-type socket_option() ::
{socket, reuseaddr | linger | type}
| {otp, recvbuf}.
| {otp, recvbuf}
| {ip, add_membership}.

-export_type([
socket/0,
Expand All @@ -80,7 +82,8 @@
sockaddr_in/0,
in_addr/0,
port_number/0,
socket_option/0
socket_option/0,
ip_mreq/0
]).

-define(DEFAULT_BACKLOG, 4).
Expand Down Expand Up @@ -647,6 +650,7 @@ getopt(_Socket, _SocketOption) ->
%% <tr><td>`{socket, reuseaddr}'</td><td>`boolean()'</td></tr>
%% <tr><td>`{socket, linger}'</td><td>`#{onoff => boolean(), linger => non_neg_integer()}'</td></tr>
%% <tr><td>`{otp, recvbuf}'</td><td>`non_neg_integer()'</td></tr>
%% <tr><td>`{ip, add_membership}'</td><td>`ip_mreq()'</td></tr>
%% </table>
%%
%% Example:
Expand Down
71 changes: 60 additions & 11 deletions src/libAtomVM/otp_socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ static const char *const port_atom = ATOM_STR("\x4", "port");
static const char *const rcvbuf_atom = ATOM_STR("\x6", "rcvbuf");
static const char *const reuseaddr_atom = ATOM_STR("\x9", "reuseaddr");
static const char *const type_atom = ATOM_STR("\x4", "type");
static const char *const add_membership_atom = ATOM_STR("\xE", "add_membership");

#define CLOSED_FD 0

Expand Down Expand Up @@ -221,12 +222,14 @@ enum otp_socket_setopt_level
{
OtpSocketInvalidSetoptLevel = 0,
OtpSocketSetoptLevelSocket,
OtpSocketSetoptLevelOTP
OtpSocketSetoptLevelOTP,
OtpSocketSetoptLevelIP
};

static const AtomStringIntPair otp_socket_setopt_level_table[] = {
{ ATOM_STR("\x6", "socket"), OtpSocketSetoptLevelSocket },
{ ATOM_STR("\x3", "otp"), OtpSocketSetoptLevelOTP },
{ ATOM_STR("\x2", "ip"), OtpSocketSetoptLevelIP },
SELECT_INT_DEFAULT(OtpSocketInvalidSetoptLevel)
};

Expand Down Expand Up @@ -604,7 +607,7 @@ static term nif_socket_open(Context *ctx, int argc, term argv[])
}

term socket_term = term_alloc_tuple(2, &ctx->heap);
uint64_t ref_ticks = globalcontext_get_ref_ticks(ctx->global);
uint64_t ref_ticks = globalcontext_get_ref_ticks(global);
rsrc_obj->socket_ref_ticks = ref_ticks;
term ref = term_from_ref_ticks(ref_ticks, &ctx->heap);
term_put_tuple_element(socket_term, 0, obj);
Expand Down Expand Up @@ -1261,8 +1264,8 @@ static term nif_socket_setopt(Context *ctx, int argc, term argv[])
return OK_ATOM;
#endif
} else if (globalcontext_is_term_equal_to_atom_string(global, opt, linger_atom)) {
term onoff = interop_kv_get_value(value, onoff_atom, ctx->global);
term linger = interop_kv_get_value(value, linger_atom, ctx->global);
term onoff = interop_kv_get_value(value, onoff_atom, global);
term linger = interop_kv_get_value(value, linger_atom, global);
VALIDATE_VALUE(linger, term_is_integer);

#if OTP_SOCKET_BSD
Expand Down Expand Up @@ -1323,6 +1326,52 @@ static term nif_socket_setopt(Context *ctx, int argc, term argv[])
}
}

#if OTP_SOCKET_BSD
case OtpSocketSetoptLevelIP: {
term opt = term_get_tuple_element(level_tuple, 1);
if (globalcontext_is_term_equal_to_atom_string(global, opt, add_membership_atom)) {
// socket:setopt(Socket, {ip, add_membership_atom}, Req :: ip_mreq())

if (UNLIKELY(!term_is_map(value))) {
TRACE("socket:setopt: ip add_membership_atom value must be a map");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(globalcontext_make_atom(global, invalid_value_atom), ctx);
}

term multiaddr = interop_kv_get_value(value, ATOM_STR("\x9", "multiaddr"), global);
if (UNLIKELY(!term_is_tuple(multiaddr) || term_get_tuple_arity(multiaddr) != 4)) {
TRACE("socket:setopt: ip add_membership_atom multiaddr value must be an IP addr");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(globalcontext_make_atom(global, invalid_value_atom), ctx);
}

term interface = interop_kv_get_value(value, ATOM_STR("\x9", "interface"), global);
if (UNLIKELY(!term_is_tuple(interface) || term_get_tuple_arity(interface) != 4)) {
TRACE("socket:setopt: ip add_membership_atom interface value must be an IP addr");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(globalcontext_make_atom(global, invalid_value_atom), ctx);
}

struct ip_mreq option_value;
option_value.imr_multiaddr.s_addr = htonl(inet_addr4_to_uint32(multiaddr));
option_value.imr_interface.s_addr = htonl(inet_addr4_to_uint32(interface));

int res = setsockopt(rsrc_obj->fd, IPPROTO_IP, IP_ADD_MEMBERSHIP, &option_value, sizeof(option_value));

SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
if (UNLIKELY(res != 0)) {
return make_errno_tuple(ctx);
} else {
return OK_ATOM;
}
} else {
TRACE("socket:setopt: Unsupported ip option");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(globalcontext_make_atom(global, invalid_option_atom), ctx);
}
}
#endif

default: {
TRACE("socket:setopt: Unsupported level");
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
Expand Down Expand Up @@ -1538,9 +1587,9 @@ static term nif_socket_bind(Context *ctx, int argc, term argv[])
ip_addr_set_loopback(false, &ip_addr);
#endif
} else if (term_is_map(sockaddr)) {
term port = interop_kv_get_value_default(sockaddr, port_atom, term_from_int(0), ctx->global);
term port = interop_kv_get_value_default(sockaddr, port_atom, term_from_int(0), global);
port_u16 = term_to_int(port);
term addr = interop_kv_get_value(sockaddr, addr_atom, ctx->global);
term addr = interop_kv_get_value(sockaddr, addr_atom, global);
if (globalcontext_is_term_equal_to_atom_string(global, addr, any_atom)) {
#if OTP_SOCKET_BSD
serveraddr.sin_addr.s_addr = htonl(INADDR_ANY);
Expand Down Expand Up @@ -1764,7 +1813,7 @@ static term nif_socket_accept(Context *ctx, int argc, term argv[])
}

term socket_term = term_alloc_tuple(2, &ctx->heap);
uint64_t ref_ticks = globalcontext_get_ref_ticks(ctx->global);
uint64_t ref_ticks = globalcontext_get_ref_ticks(global);
conn_rsrc_obj->socket_ref_ticks = ref_ticks;
term ref = term_from_ref_ticks(ref_ticks, &ctx->heap);
term_put_tuple_element(socket_term, 0, new_resource);
Expand Down Expand Up @@ -1808,7 +1857,7 @@ static term nif_socket_accept(Context *ctx, int argc, term argv[])
// return EAGAIN
LWIP_END();
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
return make_error_tuple(posix_errno_to_term(EAGAIN, ctx->global), ctx);
return make_error_tuple(posix_errno_to_term(EAGAIN, global), ctx);
}
LWIP_END();
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
Expand Down Expand Up @@ -2430,7 +2479,7 @@ static term nif_socket_send_internal(Context *ctx, int argc, term argv[], bool i
RAISE_ERROR(OUT_OF_MEMORY_ATOM);
}

term rest = term_maybe_create_sub_binary(data, sent_data, rest_len, &ctx->heap, ctx->global);
term rest = term_maybe_create_sub_binary(data, sent_data, rest_len, &ctx->heap, global);
return port_create_tuple2(ctx, OK_ATOM, rest);

} else if (sent_data == 0) {
Expand Down Expand Up @@ -2535,8 +2584,8 @@ static term nif_socket_connect(Context *ctx, int argc, term argv[])

SMP_RWLOCK_RDLOCK(rsrc_obj->socket_lock);
term sockaddr = argv[1];
term port = interop_kv_get_value_default(sockaddr, port_atom, term_from_int(0), ctx->global);
term addr = interop_kv_get_value(sockaddr, addr_atom, ctx->global);
term port = interop_kv_get_value_default(sockaddr, port_atom, term_from_int(0), global);
term addr = interop_kv_get_value(sockaddr, addr_atom, global);
if (term_is_invalid_term(addr)) {
SMP_RWLOCK_UNLOCK(rsrc_obj->socket_lock);
RAISE_ERROR(BADARG_ATOM);
Expand Down
34 changes: 34 additions & 0 deletions tests/libs/estdlib/test_udp_socket.erl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ test() ->
ok = test_timeout(),
ok = test_nowait(),
ok = test_setopt_getopt(),
ok = test_multicast(),
ok.

-define(PACKET_SIZE, 7).
Expand Down Expand Up @@ -290,3 +291,36 @@ test_setopt_getopt() ->
{error, closed} = socket:getopt(Socket, {socket, type}),
{error, closed} = socket:setopt(Socket, {socket, reuseaddr}, true),
ok.

test_multicast() ->
{ok, SocketRecv} = socket:open(inet, dgram, udp),
SocketRecvAddr = #{
family => inet, addr => {0, 0, 0, 0}, port => 8042
},
ok = socket:setopt(SocketRecv, {socket, reuseaddr}, true),
ok = socket:bind(SocketRecv, SocketRecvAddr),
ok = socket:setopt(SocketRecv, {ip, add_membership}, #{
multiaddr => {224, 0, 0, 42}, interface => {0, 0, 0, 0}
}),

{ok, SocketSender} = socket:open(inet, dgram, udp),
ok = socket:sendto(SocketSender, <<"42">>, #{
family => inet, addr => {224, 0, 0, 42}, port => 8042
}),
{ok, SocketSenderAddr} = socket:sockname(SocketSender),
SocketSenderAddrPort = maps:get(port, SocketSenderAddr),

{ok, {SocketSenderAddrFrom, <<"42">>}} = socket:recvfrom(SocketRecv, 2, 500),
{error, timeout} = socket:recvfrom(SocketRecv, 2, 0),
SocketSenderAddrPort = maps:get(port, SocketSenderAddrFrom),

ok = socket:sendto(SocketRecv, <<"43">>, #{
family => inet, addr => {224, 0, 0, 42}, port => 8042
}),
{ok, {SocketRecvAddrFrom, <<"43">>}} = socket:recvfrom(SocketRecv, 2, 500),
{error, timeout} = socket:recvfrom(SocketRecv, 2, 0),
8042 = maps:get(port, SocketRecvAddrFrom),

ok = socket:close(SocketRecv),
ok = socket:close(SocketSender),
ok.

0 comments on commit a613c12

Please sign in to comment.