From 8a61c6b69414f43a6cd62ac03e83005678356aa3 Mon Sep 17 00:00:00 2001 From: Shmulik Ladkani Date: Wed, 4 Jun 2025 12:13:49 +0300 Subject: [PATCH] Allow restrictions based on custom Authorization Header matching logic In commit e1205b72b8 ("Allow restrictions based on Authorization header"), we added restriction match type that tests the Auth Header of the websocket upgrade request matches a regex. We'd like to augment that approach for cases where the wstunnel server needs to perform some computation on the presented Auth Header value, or perform some custom matching logic. Examples would be computing a hash, performing jwt validation, comparing to a dynamic set of values, etc. There could be numerous different custom match implementations. So instead of suggesting additional tailored MatchConfig types, lets allow the admin to pass a lua script that implements his custom match logic. The script is processed and a global function named 'auth_validate' is invoked, given the Auth Header value as a parameter. The function must return true/false; true iff access is granted. example lua script: -- local match = string.match function auth_validate(auth) if match(auth, "^Basic aGk6dGhlcmU=$") then return true end return false end -- Any failure loading the script or invoking the lua function is considered an authentication failure. --- Cargo.lock | 98 ++++++++++++++++++++++++++++- restrictions.yaml | 10 ++- wstunnel/Cargo.toml | 2 + wstunnel/src/restrictions/types.rs | 1 + wstunnel/src/tunnel/server/utils.rs | 38 +++++++++++ 5 files changed, 145 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5dad8a28..b6fb14d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -284,10 +284,10 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash", + "rustc-hash 1.1.0", "shlex", "syn", - "which", + "which 4.4.2", ] [[package]] @@ -361,6 +361,16 @@ dependencies = [ "serde_with", ] +[[package]] +name = "bstr" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "bumpalo" version = "3.17.0" @@ -774,6 +784,12 @@ dependencies = [ "syn", ] +[[package]] +name = "env_home" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f84e12ccf0a7ddc17a6c41c93326024c42920d7ee630d04950e6926645c0fe" + [[package]] name = "equivalent" version = "1.0.2" @@ -1720,6 +1736,25 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "lua-src" +version = "547.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1edaf29e3517b49b8b746701e5648ccb5785cde1c119062cbabbc5d5cd115e42" +dependencies = [ + "cc", +] + +[[package]] +name = "luajit-src" +version = "210.5.12+a4f56a4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3a8e7962a5368d5f264d045a5a255e90f9aa3fc1941ae15a8d2940d42cac671" +dependencies = [ + "cc", + "which 7.0.3", +] + [[package]] name = "matchers" version = "0.1.0" @@ -1771,6 +1806,34 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mlua" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1f5f8fbebc7db5f671671134b9321c4b9aa9adeafccfd9a8c020ae45c6a35d0" +dependencies = [ + "bstr", + "either", + "mlua-sys", + "num-traits", + "parking_lot", + "rustc-hash 2.1.1", + "rustversion", +] + +[[package]] +name = "mlua-sys" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "380c1f7e2099cafcf40e51d3a9f20a346977587aa4d012eae1f043149a728a93" +dependencies = [ + "cc", + "cfg-if", + "lua-src", + "luajit-src", + "pkg-config", +] + [[package]] name = "moka" version = "0.12.10" @@ -2029,6 +2092,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "portable-atomic" version = "1.11.0" @@ -2313,6 +2382,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -3411,6 +3486,18 @@ dependencies = [ "rustix 0.38.44", ] +[[package]] +name = "which" +version = "7.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" +dependencies = [ + "either", + "env_home", + "rustix 1.0.7", + "winsafe", +] + [[package]] name = "widestring" version = "1.2.0" @@ -3862,6 +3949,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "winsafe" +version = "0.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" + [[package]] name = "wit-bindgen-rt" version = "0.39.0" @@ -3904,6 +3997,7 @@ dependencies = [ "ipnet", "jsonwebtoken", "log", + "mlua", "nix", "notify", "parking_lot", diff --git a/restrictions.yaml b/restrictions.yaml index 7074944a..af273525 100644 --- a/restrictions.yaml +++ b/restrictions.yaml @@ -10,9 +10,15 @@ restrictions: # The regex does a match, so if you want to match exactly you need to bound the pattern with ^ $ # I.e: "tesotron" is going to match "XXXtesotronXXX", but "^tesotron$" is going to match only "tesotron" - !PathPrefix "^.*$" - # This match applies only if it succeeds to match the Authentication Header with the given regex. - # If present, Authentication Header must exists and must match the regex. + # This match applies only if it succeeds to match the Authorization Header with the given regex. + # If present, Authorization Header must exists and must match the regex. # - !Authorization "^[Bb]earer +actual_bearer_token_to_match$" + # This match applies if the Authorization Header is allowed according to a custom lua script. + # The lua script must contain a global function named 'auth_validate' that gets the Auth Header value + # as a string parameter, and returns true iff the custom logic decides the value is authorized. + # Any failure to load/execute the lua script is logged, and considered authorization failure. + # If the match is present but no Auth Header exists, it's considered authorization failure. + # - !AuthorizationScript "/etc/wstunnel/authrorize.lua" # The only other possible match type for now is !Any, that match everything/any request # - !Any diff --git a/wstunnel/Cargo.toml b/wstunnel/Cargo.toml index 5948d28b..0ab47daa 100644 --- a/wstunnel/Cargo.toml +++ b/wstunnel/Cargo.toml @@ -56,6 +56,8 @@ rcgen = { version = "0.13.2", default-features = false, features = [] } hickory-resolver = { version = "0.25.2", default-features = false, features = ["system-config", "tokio", "rustls-platform-verifier"] } aws-lc-rs = { version = "*", optional = true } +mlua = { version = "0.10.5", features = ["lua54", "vendored"] } + [target.'cfg(not(target_family = "unix"))'.dependencies] crossterm = { version = "0.29.0" } tokio-util = { version = "0.7.15", features = ["io"] } diff --git a/wstunnel/src/restrictions/types.rs b/wstunnel/src/restrictions/types.rs index a55ce227..4232e2b2 100644 --- a/wstunnel/src/restrictions/types.rs +++ b/wstunnel/src/restrictions/types.rs @@ -25,6 +25,7 @@ pub enum MatchConfig { PathPrefix(Regex), #[serde(with = "serde_regex")] Authorization(Regex), + AuthorizationScript(String), } #[derive(Debug, Clone, Deserialize)] diff --git a/wstunnel/src/tunnel/server/utils.rs b/wstunnel/src/tunnel/server/utils.rs index c4dc7d43..50acdb02 100644 --- a/wstunnel/src/tunnel/server/utils.rs +++ b/wstunnel/src/tunnel/server/utils.rs @@ -12,6 +12,7 @@ use hyper::body::{Body, Incoming}; use hyper::header::{AUTHORIZATION, COOKIE, HeaderValue, SEC_WEBSOCKET_PROTOCOL}; use hyper::{Request, Response, StatusCode, http}; use jsonwebtoken::TokenData; +use mlua::prelude::{Lua, LuaFunction}; use std::net::IpAddr; use tracing::{error, info, warn}; use url::Host; @@ -116,10 +117,47 @@ impl RestrictionConfig { MatchConfig::Any => true, MatchConfig::PathPrefix(path) => path.is_match(path_prefix), MatchConfig::Authorization(auth) => authorization_header_val.is_some_and(|val| auth.is_match(val)), + MatchConfig::AuthorizationScript(script_name) => { + authorization_header_val.is_some_and(|val| auth_header_matcher(script_name, val)) + } }) } } +fn auth_header_matcher(script_name: &String, auth_val: &str) -> bool { + let validate_fn_name = "auth_validate"; + let lua = Lua::new(); + + let script = match std::fs::read_to_string(script_name) { + Ok(s) => s, + Err(e) => { + error!("Failed to read {}: {}", script_name, e); + return false; + } + }; + + if let Err(e) = lua.load(&script).exec() { + error!("Failed to load lua script {}: {}", script_name, e); + return false; + } + + let validate_fn: LuaFunction = match lua.globals().get(validate_fn_name) { + Ok(func) => func, + Err(e) => { + error!("Failed to find '{}' lua function: {}", validate_fn_name, e); + return false; + } + }; + + match validate_fn.call(auth_val) { + Ok(result) => result, + Err(e) => { + error!("Failed calling '{}' lua function: {}", validate_fn_name, e); + false + } + } +} + impl AllowReverseTunnelConfig { #[inline] fn is_allowed(&self, remote: &RemoteAddr) -> bool {