@@ -257,6 +257,8 @@ struct
257
257
clock : Clock .t ;
258
258
mutable pending : Tcp.Id.Set .t ;
259
259
mutable last_active_time : float ;
260
+ (* Tasks that will be signalled if the endpoint is destroyed *)
261
+ mutable on_destroy : unit Lwt .u Tcp.Id.Map .t ;
260
262
}
261
263
(* * A generic TCP/IP endpoint *)
262
264
@@ -279,12 +281,17 @@ struct
279
281
280
282
let pending = Tcp.Id.Set. empty in
281
283
let last_active_time = Unix. gettimeofday () in
284
+ let on_destroy = Tcp.Id.Map. empty in
282
285
let tcp_stack =
283
286
{ recorder; netif; ethif; arp; ipv4; icmpv4; udp4; tcp4; pending;
284
- last_active_time; clock }
287
+ last_active_time; clock; on_destroy }
285
288
in
286
289
Lwt. return tcp_stack
287
290
291
+ let destroy t =
292
+ Tcp.Id.Map. iter (fun _ u -> Lwt. wakeup_later u () ) t.on_destroy;
293
+ t.on_destroy < - Tcp.Id.Map. empty
294
+
288
295
let intercept_tcp_syn t ~id ~syn on_syn_callback (buf : Cstruct.t ) =
289
296
if syn then begin
290
297
if Tcp.Id.Set. mem id t.pending then begin
@@ -295,9 +302,14 @@ struct
295
302
Lwt. return_unit
296
303
end else begin
297
304
t.pending < - Tcp.Id.Set. add id t.pending;
305
+ (* Add a task to the "on_destroy" list which will be signalled if
306
+ the Endpoint is disconnected from the switch and we should close
307
+ connections. *)
308
+ let close, close_request = Lwt. task () in
309
+ t.on_destroy < - Tcp.Id.Map. add id close_request t.on_destroy;
298
310
Lwt. finalize
299
311
(fun () ->
300
- on_syn_callback ()
312
+ on_syn_callback close
301
313
>> = fun listeners ->
302
314
let src = Stack_tcp_wire. dst id in
303
315
let dst = Stack_tcp_wire. src id in
@@ -319,7 +331,7 @@ struct
319
331
Mirage_flow_lwt. Proxy (Clock )(Stack_tcp )(Host.Sockets.Stream. Tcp )
320
332
321
333
let input_tcp t ~id ~syn (ip , port ) (buf : Cstruct.t ) =
322
- intercept_tcp_syn t ~id ~syn (fun () ->
334
+ intercept_tcp_syn t ~id ~syn (fun close ->
323
335
Host.Sockets.Stream.Tcp. connect (ip, port)
324
336
>> = function
325
337
| Error (`Msg m ) ->
@@ -341,9 +353,21 @@ struct
341
353
Lwt. return_unit
342
354
| Some socket ->
343
355
Lwt. finalize (fun () ->
344
- Proxy. proxy t.clock flow socket
356
+ Lwt. pick [
357
+ Lwt. map
358
+ (function Error e -> Error (`Proxy e) | Ok x -> Ok x)
359
+ (Proxy. proxy t.clock flow socket);
360
+ Lwt. map
361
+ (fun () -> Error `Close )
362
+ close
363
+ ]
345
364
>> = function
346
- | Error e ->
365
+ | Error (`Close) ->
366
+ Log. info (fun f ->
367
+ f " %s proxy closed due to switch port disconnection"
368
+ (Tcp.Flow. to_string tcp));
369
+ Lwt. return_unit
370
+ | Error (`Proxy e ) ->
347
371
Log. debug (fun f ->
348
372
f " %s proxy failed with %a"
349
373
(Tcp.Flow. to_string tcp) Proxy. pp_error e);
@@ -354,6 +378,7 @@ struct
354
378
Log. debug (fun f ->
355
379
f " closing flow %s" (string_of_id tcp.Tcp.Flow. id));
356
380
tcp.Tcp.Flow. socket < - None ;
381
+ t.on_destroy < - Tcp.Id.Map. remove id t.on_destroy;
357
382
Tcp.Flow. remove tcp.Tcp.Flow. id;
358
383
Host.Sockets.Stream.Tcp. close socket
359
384
)
@@ -479,9 +504,9 @@ struct
479
504
let id =
480
505
Stack_tcp_wire. v ~src_port: 53 ~dst: src ~src: dst ~dst_port: src_port
481
506
in
482
- Endpoint. intercept_tcp_syn t.endpoint ~id ~syn (fun () ->
507
+ Endpoint. intercept_tcp_syn t.endpoint ~id ~syn (fun close ->
483
508
! dns >> = fun t ->
484
- Dns_forwarder. handle_tcp ~t
509
+ Dns_forwarder. handle_tcp ~t ~close
485
510
) raw
486
511
> |= ok
487
512
@@ -801,10 +826,11 @@ struct
801
826
let now = Unix. gettimeofday () in
802
827
let old_ips = IPMap. fold (fun ip endpoint acc ->
803
828
let age = now -. endpoint.Endpoint. last_active_time in
804
- if age > 300.0 then ip :: acc else acc
829
+ if age > 300.0 then (ip, endpoint) :: acc else acc
805
830
) t.endpoints [] in
806
- List. iter (fun ip ->
831
+ List. iter (fun ( ip , endpoint ) ->
807
832
Switch. remove t.switch ip;
833
+ Endpoint. destroy endpoint;
808
834
t.endpoints < - IPMap. remove ip t.endpoints
809
835
) old_ips;
810
836
Lwt. return_unit
0 commit comments