Skip to content

Commit 99d38ca

Browse files
committed
[DO NOT MERGE] Add a Endpoint.destroy which closes active connections
Previously there was no way to locate the connections associated with an endpoint to shut them down. This patch adds a map of TCP `id` to `unit Lwt.u` and a function `Endpoint.destroy` which triggers the disconnection of all the active connections. Related to moby#260 Signed-off-by: David Scott <[email protected]>
1 parent 151ec4d commit 99d38ca

File tree

3 files changed

+41
-12
lines changed

3 files changed

+41
-12
lines changed

src/hostnet/hostnet_dns.ml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ struct
356356
| Ok buffer ->
357357
Udp.write ~src_port:53 ~dst:src ~dst_port:src_port udp buffer
358358

359-
let handle_tcp ~t =
359+
let handle_tcp ~t ~close =
360360
(* FIXME: need to record the upstream request *)
361361
let listeners _ =
362362
Log.debug (fun f -> f "DNS TCP handshake complete");
@@ -384,7 +384,10 @@ struct
384384
Lwt.async queries;
385385
loop ()
386386
in
387-
loop ()
387+
Lwt.pick [
388+
loop ();
389+
close
390+
]
388391
in
389392
Some f
390393
in

src/hostnet/hostnet_dns.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ sig
4040
t:t -> udp:Udp.t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> src_port:int ->
4141
Cstruct.t -> (unit, Udp.error) result Lwt.t
4242

43-
val handle_tcp: t:t -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t
43+
val handle_tcp: t:t -> close:(unit Lwt.t) -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t
4444

4545
val destroy: t -> unit Lwt.t
4646
end

src/hostnet/slirp.ml

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,8 @@ struct
257257
clock: Clock.t;
258258
mutable pending: Tcp.Id.Set.t;
259259
mutable last_active_time: float;
260+
(* Tasks that will be signalled if the endpoint is destroyed *)
261+
mutable on_destroy: unit Lwt.u Tcp.Id.Map.t;
260262
}
261263
(** A generic TCP/IP endpoint *)
262264

@@ -279,12 +281,17 @@ struct
279281

280282
let pending = Tcp.Id.Set.empty in
281283
let last_active_time = Unix.gettimeofday () in
284+
let on_destroy = Tcp.Id.Map.empty in
282285
let tcp_stack =
283286
{ recorder; netif; ethif; arp; ipv4; icmpv4; udp4; tcp4; pending;
284-
last_active_time; clock }
287+
last_active_time; clock; on_destroy }
285288
in
286289
Lwt.return tcp_stack
287290

291+
let destroy t =
292+
Tcp.Id.Map.iter (fun _ u -> Lwt.wakeup_later u ()) t.on_destroy;
293+
t.on_destroy <- Tcp.Id.Map.empty
294+
288295
let intercept_tcp_syn t ~id ~syn on_syn_callback (buf: Cstruct.t) =
289296
if syn then begin
290297
if Tcp.Id.Set.mem id t.pending then begin
@@ -295,9 +302,14 @@ struct
295302
Lwt.return_unit
296303
end else begin
297304
t.pending <- Tcp.Id.Set.add id t.pending;
305+
(* Add a task to the "on_destroy" list which will be signalled if
306+
the Endpoint is disconnected from the switch and we should close
307+
connections. *)
308+
let close, close_request = Lwt.task () in
309+
t.on_destroy <- Tcp.Id.Map.add id close_request t.on_destroy;
298310
Lwt.finalize
299311
(fun () ->
300-
on_syn_callback ()
312+
on_syn_callback close
301313
>>= fun listeners ->
302314
let src = Stack_tcp_wire.dst id in
303315
let dst = Stack_tcp_wire.src id in
@@ -319,7 +331,7 @@ struct
319331
Mirage_flow_lwt.Proxy(Clock)(Stack_tcp)(Host.Sockets.Stream.Tcp)
320332

321333
let input_tcp t ~id ~syn (ip, port) (buf: Cstruct.t) =
322-
intercept_tcp_syn t ~id ~syn (fun () ->
334+
intercept_tcp_syn t ~id ~syn (fun close ->
323335
Host.Sockets.Stream.Tcp.connect (ip, port)
324336
>>= function
325337
| Error (`Msg m) ->
@@ -341,9 +353,21 @@ struct
341353
Lwt.return_unit
342354
| Some socket ->
343355
Lwt.finalize (fun () ->
344-
Proxy.proxy t.clock flow socket
356+
Lwt.pick [
357+
Lwt.map
358+
(function Error e -> Error (`Proxy e) | Ok x -> Ok x)
359+
(Proxy.proxy t.clock flow socket);
360+
Lwt.map
361+
(fun () -> Error `Close)
362+
close
363+
]
345364
>>= function
346-
| Error e ->
365+
| Error (`Close) ->
366+
Log.info (fun f ->
367+
f "%s proxy closed due to switch port disconnection"
368+
(Tcp.Flow.to_string tcp));
369+
Lwt.return_unit
370+
| Error (`Proxy e) ->
347371
Log.debug (fun f ->
348372
f "%s proxy failed with %a"
349373
(Tcp.Flow.to_string tcp) Proxy.pp_error e);
@@ -354,6 +378,7 @@ struct
354378
Log.debug (fun f ->
355379
f "closing flow %s" (string_of_id tcp.Tcp.Flow.id));
356380
tcp.Tcp.Flow.socket <- None;
381+
t.on_destroy <- Tcp.Id.Map.remove id t.on_destroy;
357382
Tcp.Flow.remove tcp.Tcp.Flow.id;
358383
Host.Sockets.Stream.Tcp.close socket
359384
)
@@ -479,9 +504,9 @@ struct
479504
let id =
480505
Stack_tcp_wire.v ~src_port:53 ~dst:src ~src:dst ~dst_port:src_port
481506
in
482-
Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun () ->
507+
Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun close ->
483508
!dns >>= fun t ->
484-
Dns_forwarder.handle_tcp ~t
509+
Dns_forwarder.handle_tcp ~t ~close
485510
) raw
486511
>|= ok
487512

@@ -801,10 +826,11 @@ struct
801826
let now = Unix.gettimeofday () in
802827
let old_ips = IPMap.fold (fun ip endpoint acc ->
803828
let age = now -. endpoint.Endpoint.last_active_time in
804-
if age > 300.0 then ip :: acc else acc
829+
if age > 300.0 then (ip, endpoint) :: acc else acc
805830
) t.endpoints [] in
806-
List.iter (fun ip ->
831+
List.iter (fun (ip, endpoint) ->
807832
Switch.remove t.switch ip;
833+
Endpoint.destroy endpoint;
808834
t.endpoints <- IPMap.remove ip t.endpoints
809835
) old_ips;
810836
Lwt.return_unit

0 commit comments

Comments
 (0)