Skip to content

Commit

Permalink
Add a Endpoint.destroy which closes active connections
Browse files Browse the repository at this point in the history
Previously there was no way to locate the connections associated with
an endpoint to shut them down. This patch adds a map of TCP `id` to
`unit Lwt.u` and a function `Endpoint.destroy` which triggers the
disconnection of all the active connections.

Related to moby#260

Signed-off-by: David Scott <[email protected]>
  • Loading branch information
djs55 committed Aug 17, 2017
1 parent 47a530f commit 36b23db
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
7 changes: 5 additions & 2 deletions src/hostnet/hostnet_dns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ struct
| Ok buffer ->
Udp.write ~src_port:53 ~dst:src ~dst_port:src_port udp buffer

let handle_tcp ~t =
let handle_tcp ~t ~close =
(* FIXME: need to record the upstream request *)
let listeners _ =
Log.debug (fun f -> f "DNS TCP handshake complete");
Expand Down Expand Up @@ -384,7 +384,10 @@ struct
Lwt.async queries;
loop ()
in
loop ()
Lwt.pick [
loop ();
close
]
in
Some f
in
Expand Down
2 changes: 1 addition & 1 deletion src/hostnet/hostnet_dns.mli
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ sig
t:t -> udp:Udp.t -> src:Ipaddr.V4.t -> dst:Ipaddr.V4.t -> src_port:int ->
Cstruct.t -> (unit, Udp.error) result Lwt.t

val handle_tcp: t:t -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t
val handle_tcp: t:t -> close:(unit Lwt.t) -> (int -> (Tcp.flow -> unit Lwt.t) option) Lwt.t

val destroy: t -> unit Lwt.t
end
44 changes: 35 additions & 9 deletions src/hostnet/slirp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ struct
clock: Clock.t;
mutable pending: Tcp.Id.Set.t;
mutable last_active_time: float;
(* Tasks that will be signalled if the endpoint is destroyed *)
mutable on_destroy: unit Lwt.u Tcp.Id.Map.t;
}
(** A generic TCP/IP endpoint *)

Expand All @@ -284,12 +286,17 @@ struct

let pending = Tcp.Id.Set.empty in
let last_active_time = Unix.gettimeofday () in
let on_destroy = Tcp.Id.Map.empty in
let tcp_stack =
{ recorder; netif; ethif; arp; ipv4; icmpv4; udp4; tcp4; pending;
last_active_time; clock }
last_active_time; clock; on_destroy }
in
Lwt.return tcp_stack

let destroy t =
Tcp.Id.Map.iter (fun _ u -> Lwt.wakeup_later u ()) t.on_destroy;
t.on_destroy <- Tcp.Id.Map.empty

let intercept_tcp_syn t ~id ~syn on_syn_callback (buf: Cstruct.t) =
if syn then begin
if Tcp.Id.Set.mem id t.pending then begin
Expand All @@ -300,9 +307,14 @@ struct
Lwt.return_unit
end else begin
t.pending <- Tcp.Id.Set.add id t.pending;
(* Add a task to the "on_destroy" list which will be signalled if
the Endpoint is disconnected from the switch and we should close
connections. *)
let close, close_request = Lwt.task () in
t.on_destroy <- Tcp.Id.Map.add id close_request t.on_destroy;
Lwt.finalize
(fun () ->
on_syn_callback ()
on_syn_callback close
>>= fun listeners ->
let src = Stack_tcp_wire.dst id in
let dst = Stack_tcp_wire.src id in
Expand All @@ -324,7 +336,7 @@ struct
Mirage_flow_lwt.Proxy(Clock)(Stack_tcp)(Host.Sockets.Stream.Tcp)

let input_tcp t ~id ~syn (ip, port) (buf: Cstruct.t) =
intercept_tcp_syn t ~id ~syn (fun () ->
intercept_tcp_syn t ~id ~syn (fun close ->
Host.Sockets.Stream.Tcp.connect (ip, port)
>>= function
| Error (`Msg m) ->
Expand All @@ -346,9 +358,21 @@ struct
Lwt.return_unit
| Some socket ->
Lwt.finalize (fun () ->
Proxy.proxy t.clock flow socket
Lwt.pick [
Lwt.map
(function Error e -> Error (`Proxy e) | Ok x -> Ok x)
(Proxy.proxy t.clock flow socket);
Lwt.map
(fun () -> Error `Close)
close
]
>>= function
| Error e ->
| Error (`Close) ->
Log.info (fun f ->
f "%s proxy closed due to switch port disconnection"
(Tcp.Flow.to_string tcp));
Lwt.return_unit
| Error (`Proxy e) ->
Log.debug (fun f ->
f "%s proxy failed with %a"
(Tcp.Flow.to_string tcp) Proxy.pp_error e);
Expand All @@ -359,6 +383,7 @@ struct
Log.debug (fun f ->
f "closing flow %s" (string_of_id tcp.Tcp.Flow.id));
tcp.Tcp.Flow.socket <- None;
t.on_destroy <- Tcp.Id.Map.remove id t.on_destroy;
Tcp.Flow.remove tcp.Tcp.Flow.id;
Host.Sockets.Stream.Tcp.close socket
)
Expand Down Expand Up @@ -484,9 +509,9 @@ struct
let id =
Stack_tcp_wire.v ~src_port:53 ~dst:src ~src:dst ~dst_port:src_port
in
Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun () ->
Endpoint.intercept_tcp_syn t.endpoint ~id ~syn (fun close ->
!dns >>= fun t ->
Dns_forwarder.handle_tcp ~t
Dns_forwarder.handle_tcp ~t ~close
) raw
>|= ok

Expand Down Expand Up @@ -808,10 +833,11 @@ struct
let now = Unix.gettimeofday () in
let old_ips = IPMap.fold (fun ip endpoint acc ->
let age = now -. endpoint.Endpoint.last_active_time in
if age > (float_of_int port_max_idle_time) then ip :: acc else acc
if age > (float_of_int port_max_idle_time) then (ip, endpoint) :: acc else acc
) t.endpoints [] in
List.iter (fun ip ->
List.iter (fun (ip, endpoint) ->
Switch.remove t.switch ip;
Endpoint.destroy endpoint;
t.endpoints <- IPMap.remove ip t.endpoints
) old_ips;
Lwt.return_unit
Expand Down

0 comments on commit 36b23db

Please sign in to comment.