Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Lwt.Syntax and avoid some >>= fun () patterns #197

Merged
merged 10 commits into from
Oct 16, 2024
2 changes: 1 addition & 1 deletion build-with.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ $builder build -t qubes-mirage-firewall .
echo Building Firewall...
$builder run --rm -i -v `pwd`:/tmp/orb-build:Z qubes-mirage-firewall
echo "SHA2 of build: $(sha256sum ./dist/qubes-firewall.xen)"
echo "SHA2 last known: 4b1f743bf4540bc8a9366cf8f23a78316e4f2d477af77962e50618753c4adf10"
echo "SHA2 last known: 2392386d9056b17a648f26b0c5d1c72b93f8a197964c670b2b45e71707727317"
echo "(hashes should match for released versions)"
16 changes: 8 additions & 8 deletions client_eth.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ let src = Logs.Src.create "client_eth" ~doc:"Ethernet networks for NetVM clients
module Log = (val Logs.src_log src : Logs.LOG)

type t = {
mutable iface_of_ip : client_link IpMap.t;
mutable iface_of_ip : client_link Ipaddr.V4.Map.t;
changed : unit Lwt_condition.t; (* Fires when [iface_of_ip] changes. *)
my_ip : Ipaddr.V4.t; (* The IP that clients are given as their default gateway. *)
}
Expand All @@ -21,33 +21,33 @@ type host =
let create config =
let changed = Lwt_condition.create () in
let my_ip = config.Dao.our_ip in
Lwt.return { iface_of_ip = IpMap.empty; my_ip; changed }
Lwt.return { iface_of_ip = Ipaddr.V4.Map.empty; my_ip; changed }

let client_gw t = t.my_ip

let add_client t iface =
let ip = iface#other_ip in
let rec aux () =
match IpMap.find ip t.iface_of_ip with
match Ipaddr.V4.Map.find_opt ip t.iface_of_ip with
| Some old ->
(* Wait for old client to disappear before adding one with the same IP address.
Otherwise, its [remove_client] call will remove the new client instead. *)
Log.info (fun f -> f ~header:iface#log_header "Waiting for old client %s to go away before accepting new one" old#log_header);
Lwt_condition.wait t.changed >>= aux
| None ->
t.iface_of_ip <- t.iface_of_ip |> IpMap.add ip iface;
t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.add ip iface;
Lwt_condition.broadcast t.changed ();
Lwt.return_unit
in
aux ()

let remove_client t iface =
let ip = iface#other_ip in
assert (IpMap.mem ip t.iface_of_ip);
t.iface_of_ip <- t.iface_of_ip |> IpMap.remove ip;
assert (Ipaddr.V4.Map.mem ip t.iface_of_ip);
t.iface_of_ip <- t.iface_of_ip |> Ipaddr.V4.Map.remove ip;
Lwt_condition.broadcast t.changed ()

let lookup t ip = IpMap.find ip t.iface_of_ip
let lookup t ip = Ipaddr.V4.Map.find_opt ip t.iface_of_ip

let classify t ip =
match ip with
Expand Down Expand Up @@ -79,7 +79,7 @@ module ARP = struct
(* We're now treating client networks as point-to-point links,
so we no longer respond on behalf of other clients. *)
(*
else match IpMap.find ip t.net.iface_of_ip with
else match Ipaddr.V4.Map.find_opt ip t.net.iface_of_ip with
| Some client_iface -> Some client_iface#other_mac
| None -> None
*)
Expand Down
67 changes: 32 additions & 35 deletions dao.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,43 +65,40 @@ let read_rules rules client_ip =
number = 0;})]

let vifs client domid =
let open Lwt.Syntax in
match int_of_string_opt domid with
| None -> Log.err (fun f -> f "Invalid domid %S" domid); Lwt.return []
| Some domid ->
let path = Printf.sprintf "backend/vif/%d" domid in
Xen_os.Xs.immediate client (fun handle ->
directory ~handle path >>=
Lwt_list.filter_map_p (fun device_id ->
match int_of_string_opt device_id with
| None -> Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid); Lwt.return_none
| Some device_id ->
let vif = { ClientVif.domid; device_id } in
Lwt.try_bind
(fun () -> Xen_os.Xs.read handle (Printf.sprintf "%s/%d/ip" path device_id))
(fun client_ip ->
let client_ip' = match String.split_on_char ' ' client_ip with
| [] -> Log.err (fun m -> m "unexpected empty list"); ""
| [ ip ] -> ip
| ip::rest ->
Log.warn (fun m -> m "ignoring IPs %s from %a, we support one IP per client"
(String.concat " " rest) ClientVif.pp vif);
ip
in
match Ipaddr.V4.of_string client_ip' with
| Ok ip -> Lwt.return (Some (vif, ip))
| Error `Msg msg ->
Log.err (fun f -> f "Error parsing IP address of %a from %s: %s"
ClientVif.pp vif client_ip msg);
Lwt.return None
)
(function
| Xs_protocol.Enoent _ -> Lwt.return None
| ex ->
Log.err (fun f -> f "Error getting IP address of %a: %s"
ClientVif.pp vif (Printexc.to_string ex));
Lwt.return None
)
))
let path = Fmt.str "backend/vif/%d" domid in
let vifs_of_domain handle =
let* devices = directory ~handle path in
let ip_of_vif device_id = match int_of_string_opt device_id with
| None ->
Log.err (fun f -> f "Invalid device ID %S for domid %d" device_id domid);
Lwt.return_none
| Some device_id ->
let vif = { ClientVif.domid; device_id } in
let get_client_ip () =
let* str = Xen_os.Xs.read handle (Fmt.str "%s/%d/ip" path device_id) in
let client_ip = List.hd (String.split_on_char ' ' str) in
dinosaure marked this conversation as resolved.
Show resolved Hide resolved
(* NOTE(dinosaure): it's safe to use [List.hd] here,
[String.split_on_char] can not return an empty list. *)
Lwt.return_some (vif, Ipaddr.V4.of_string_exn client_ip)
in
Lwt.catch get_client_ip @@ function
| Xs_protocol.Enoent _ -> Lwt.return_none
| Ipaddr.Parse_error (msg, client_ip) ->
Log.err (fun f -> f "Error parsing IP address of %a from %s: %s"
ClientVif.pp vif client_ip msg);
Lwt.return_none
| exn ->
Log.err (fun f -> f "Error getting IP address of %a: %s"
ClientVif.pp vif (Printexc.to_string exn));
Lwt.return_none
in
Lwt_list.filter_map_p ip_of_vif devices
in
Xen_os.Xs.immediate client vifs_of_domain

let watch_clients fn =
Xen_os.Xs.make () >>= fun xs ->
Expand All @@ -116,7 +113,7 @@ let watch_clients fn =
end >>= fun items ->
Xen_os.Xs.make () >>= fun xs ->
Lwt_list.map_p (vifs xs) items >>= fun items ->
fn (List.concat items |> VifMap.of_list);
fn (List.concat items |> VifMap.of_list) >>= fun () ->
(* Wait for further updates *)
Lwt.fail Xs_protocol.Eagain
)
Expand Down
2 changes: 1 addition & 1 deletion dao.mli
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module VifMap : sig
val find : key -> 'a t -> 'a option
end

val watch_clients : (Ipaddr.V4.t VifMap.t -> unit) -> 'a Lwt.t
val watch_clients : (Ipaddr.V4.t VifMap.t -> unit Lwt.t) -> 'a Lwt.t
(** [watch_clients fn] calls [fn clients] with the list of backend clients
in XenStore, and again each time XenStore updates. *)

Expand Down
77 changes: 42 additions & 35 deletions dispatcher.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ struct
module I = Static_ipv4.Make (R) (Clock) (UplinkEth) (Arp)
module U = Udp.Make (I) (R)

let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty

class client_iface eth ~domid ~gateway_ip ~client_ip client_mac : client_link
=
let log_header = Fmt.str "dom%d:%a" domid Ipaddr.V4.pp client_ip in
Expand Down Expand Up @@ -344,11 +342,12 @@ struct

(** Connect to a new client's interface and listen for incoming frames and firewall rule changes. *)
let add_vif get_ts { Dao.ClientVif.domid; device_id } dns_client dns_servers
~client_ip ~router ~cleanup_tasks qubesDB =
Netback.make ~domid ~device_id >>= fun backend ->
~client_ip ~router ~cleanup_tasks qubesDB () =
let open Lwt.Syntax in
let* backend = Netback.make ~domid ~device_id in
Log.info (fun f ->
f "Client %d (IP: %s) ready" domid (Ipaddr.V4.to_string client_ip));
ClientEth.connect backend >>= fun eth ->
let* eth = ClientEth.connect backend in
let client_mac = Netback.frontend_mac backend in
let client_eth = router.clients in
let gateway_ip = Client_eth.client_gw client_eth in
Expand Down Expand Up @@ -404,46 +403,54 @@ struct
(function Lwt.Canceled -> Lwt.return_unit | e -> Lwt.fail e)
in
Cleanup.on_cleanup cleanup_tasks (fun () -> Lwt.cancel listener);
Lwt.pick [ qubesdb_updater; listener ]
(* NOTE(dinosaure): [qubes_updater] and [listener] can be forgotten, our [cleanup_task]
will cancel them if the client is disconnected. *)
Lwt.async (fun () -> Lwt.pick [ qubesdb_updater; listener ]);
Lwt.return_unit

(** A new client VM has been found in XenStore. Find its interface and connect to it. *)
let add_client get_ts dns_client dns_servers ~router vif client_ip qubesDB =
let open Lwt.Syntax in
let cleanup_tasks = Cleanup.create () in
Log.info (fun f ->
f "add client vif %a with IP %a" Dao.ClientVif.pp vif Ipaddr.V4.pp
client_ip);
Lwt.async (fun () ->
Lwt.catch
(fun () ->
add_vif get_ts vif dns_client dns_servers ~client_ip ~router
~cleanup_tasks qubesDB)
(fun ex ->
Log.warn (fun f ->
f "Error with client %a: %s" Dao.ClientVif.pp vif
(Printexc.to_string ex));
Lwt.return_unit));
cleanup_tasks
let* () =
Lwt.catch (add_vif get_ts vif dns_client dns_servers ~client_ip ~router
~cleanup_tasks qubesDB)
@@ fun exn ->
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally struggle with @@ and find parenthesis much easier to read (since I'm always unsure how strong the @@ binding strength is.

Log.warn (fun f ->
f "Error with client %a: %s" Dao.ClientVif.pp vif
(Printexc.to_string exn));
Lwt.return_unit
in
Lwt.return cleanup_tasks

(** Watch XenStore for notifications of new clients. *)
let wait_clients get_ts dns_client dns_servers qubesDB router =
Dao.watch_clients (fun new_set ->
(* Check for removed clients *)
!clients
|> Dao.VifMap.iter (fun key cleanup ->
if not (Dao.VifMap.mem key new_set) then (
clients := !clients |> Dao.VifMap.remove key;
Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key);
Cleanup.cleanup cleanup));
(* Check for added clients *)
new_set
|> Dao.VifMap.iter (fun key ip_addr ->
if not (Dao.VifMap.mem key !clients) then (
let cleanup =
add_client get_ts dns_client dns_servers ~router key ip_addr
qubesDB
in
Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key);
clients := !clients |> Dao.VifMap.add key cleanup)))
let open Lwt.Syntax in
let clients : Cleanup.t Dao.VifMap.t ref = ref Dao.VifMap.empty in
Dao.watch_clients @@ fun new_set ->
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above about @@.

(* Check for removed clients *)
let clean_up_clients key cleanup =
if not (Dao.VifMap.mem key new_set) then begin
clients := !clients |> Dao.VifMap.remove key;
Log.info (fun f -> f "client %a has gone" Dao.ClientVif.pp key);
Cleanup.cleanup cleanup
end
in
Dao.VifMap.iter clean_up_clients !clients;
(* Check for added clients *)
let rec go seq = match Seq.uncons seq with
| None -> Lwt.return_unit
| Some ((key, ipaddr), seq) when not (Dao.VifMap.mem key !clients) ->
let* cleanup = add_client get_ts dns_client dns_servers ~router key ipaddr qubesDB in
Log.debug (fun f -> f "client %a arrived" Dao.ClientVif.pp key);
clients := Dao.VifMap.add key cleanup !clients;
go seq
| Some (_, seq) -> go seq
in
go (Dao.VifMap.to_seq new_set)

let send_dns_client_query t ~src_port ~dst ~dst_port buf =
match t.uplink with
Expand Down
8 changes: 0 additions & 8 deletions fw_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@

(** General utility functions. *)

module IpMap = struct
include Map.Make(Ipaddr.V4)
let find x map =
try Some (find x map)
with Not_found -> None
| _ -> Logs.err( fun f -> f "uncaught exception in find...%!"); None
end

(** An Ethernet interface. *)
class type interface = object
method my_mac : Macaddr.t
Expand Down
11 changes: 4 additions & 7 deletions unikernel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,12 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time :

(* Main unikernel entry point (called from auto-generated main.ml). *)
let start _random _clock _time =
let open Lwt.Syntax in
let start_time = Clock.elapsed_ns () in
(* Start qrexec agent and QubesDB agent in parallel *)
let qrexec = RExec.connect ~domid:0 () in
let qubesDB = DB.connect ~domid:0 () in

(* Wait for clients to connect *)
qrexec >>= fun qrexec ->
let* qrexec = RExec.connect ~domid:0 () in
let agent_listener = RExec.listen qrexec Command.handler in
qubesDB >>= fun qubesDB ->
let* qubesDB = DB.connect ~domid:0 () in
let startup_time =
let (-) = Int64.sub in
let time_in_ns = Clock.elapsed_ns () - start_time in
Expand Down Expand Up @@ -93,7 +90,7 @@ module Main (R : Mirage_crypto_rng_mirage.S)(Clock : Mirage_clock.MCLOCK)(Time :
Dao.print_network_config config ;

(* Set up client-side networking *)
Client_eth.create config >>= fun clients ->
let* clients = Client_eth.create config in

(* Set up routing between networks and hosts *)
let router = Dispatcher.create
Expand Down
Loading