diff --git a/LICENSE b/LICENSE index d968fa0..1d978f0 100644 --- a/LICENSE +++ b/LICENSE @@ -21,9 +21,10 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --- -Parts of this project, specifically the file cmd/internal/bindmount.go, +Parts of this project, specifically the file cmd/socket-proxy/bindmount.go and +the files in the internal/docker and internal/go-connections folders, contain source code licensed under the Apache License 2.0. See the comments -in that file for details. +in the applicable files for details. The rest of the project is licensed under the MIT License. Apache License diff --git a/README.md b/README.md index 2449118..5581891 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ As an additional benefit, socket-proxy can be used to examine the API calls of t The advantage over other solutions is the very slim container image (from-scratch-image) without any external dependencies (no OS, no packages, just the Go standard library). It is designed with security in mind, so there are secure defaults and an additional security layer (IP address-based access control) compared to most other solutions. -The allowlist is configured for each HTTP method separately using the Go regexp syntax, allowing fine-grained control over the allowed HTTP methods. +The allowlist is configured for each HTTP method separately using the Go regexp syntax, allowing fine-grained control over the allowed HTTP methods. In bridge network mode, each container that uses socket-proxy can be configured with its own allowlist. The source code is available on [GitHub: wollomatic/socket-proxy](https://github.com/wollomatic/socket-proxy) @@ -110,6 +110,27 @@ Bind mount restrictions are applied to relevant Docker API endpoints and work wi **Note**: This feature only restricts bind mounts. Other mount types (volumes, tmpfs, etc.) are not affected by this restriction. +#### Setting up per-container allowlists + +Allowlists for both requests and bind mount restrictions can be specified for particular containers. To do this: + +1. Set `-proxycontainername` or the environment variable `SP_PROXYCONTAINERNAME` to the name of the socket proxy container. +2. Make sure that each container that will use the socket proxy is in a Docker network that the socket proxy container is also in. +3. Use the same regex syntax for request allowlists and for bind mount restrictions that were discussed earlier, but for labels on each container that will use the socket proxy. Each label name will have the prefix of `socket-proxy.allow.`, with `socket-proxy.allow.bindmountfrom` for bind mount restrictions. For example: + +``` compose.yaml +services: + traefik: + # [...] see github.com/wollomatic/traefik-hardened for a full example + networks: + - traefik-servicenet # this is the common traefik network + - docker-proxynet # this should be only restricted to traefik and socket-proxy + labels: + - 'socket-proxy.allow.get=.*' # allow all GET requests to socket-proxy +``` + +When this is used, it is not necessary to specify the container in `-allowfrom` as the presence of the allowlist labels will grant corresponding access. + ### Container health check Health checks are disabled by default. As the socket-proxy container may not be exposed to a public network, a separate health check binary is included in the container image. To activate the health check, the `-allowhealthcheck` parameter or the environment variable `SP_ALLOWHEALTHCHECK=true` must be set. Then, a health check is possible for example with the following docker-compose snippet: @@ -212,6 +233,7 @@ socket-proxy can be configured via command line parameters or via environment va | `-watchdoginterval` | `SP_WATCHDOGINTERVAL` | `0` | Check for socket availability every x seconds (disable checks, if not set or value is 0) | | `-proxysocketendpoint` | `SP_PROXYSOCKETENDPOINT` | (not set) | Proxy to the given unix socket instead of a TCP port | | `-proxysocketendpointfilemode` | `SP_PROXYSOCKETENDPOINTFILEMODE` | `0600` | Explicitly set the file mode for the filtered unix socket endpoint (only useful with `-proxysocketendpoint`) | +| `-proxycontainername` | `SP_PROXYCONTAINERNAME ` | (not set) | Provides the name of the socket proxy container to enable per-container allowlists specified by Docker container labels (not available with `-proxysocketendpoint`) | ### Changelog @@ -240,7 +262,7 @@ socket-proxy can be configured via command line parameters or via environment va ## License This project is licensed under the MIT License – see the [LICENSE](LICENSE) file for details. -Parts of the file `cmd/internal/bindmount.go` are licensed under the Apache 2.0 License. +Parts of the file `cmd/socket-proxy/bindmount.go` and files under the `internal/docker` and `internal/go-connections` folders are licensed under the Apache 2.0 License. See the comments in this file and the LICENSE file for more information. ## Aknowledgements diff --git a/cmd/socket-proxy/bindmount.go b/cmd/socket-proxy/bindmount.go index 5b6ba16..38671c3 100644 --- a/cmd/socket-proxy/bindmount.go +++ b/cmd/socket-proxy/bindmount.go @@ -79,9 +79,9 @@ type ( ) // checkBindMountRestrictions checks if bind mounts in the request are allowed. -func checkBindMountRestrictions(r *http.Request) error { +func checkBindMountRestrictions(allowedBindMounts []string, r *http.Request) error { // Only check if bind mount restrictions are configured - if len(cfg.AllowBindMountFrom) == 0 { + if len(allowedBindMounts) == 0 { return nil } @@ -94,23 +94,23 @@ func checkBindMountRestrictions(r *http.Request) error { switch { case len(pathParts) >= 4 && pathParts[2] == "containers" && pathParts[3] == "create": // Container creation: /vX.xx/containers/create - return checkContainer(r) + return checkContainer(allowedBindMounts, r) case len(pathParts) >= 5 && pathParts[2] == "containers" && pathParts[4] == "update": // Container update: /vX.xx/containers/{id}/update - return checkContainer(r) + return checkContainer(allowedBindMounts, r) case len(pathParts) >= 4 && pathParts[2] == "services" && pathParts[3] == "create": // Service creation: /vX.xx/services/create - return checkService(r) + return checkService(allowedBindMounts, r) case len(pathParts) >= 5 && pathParts[2] == "services" && pathParts[4] == "update": // Service update: /vX.xx/services/{id}/update - return checkService(r) + return checkService(allowedBindMounts, r) default: return nil } } // checkContainer checks bind mounts in container creation requests. -func checkContainer(r *http.Request) error { +func checkContainer(allowedBindMounts []string, r *http.Request) error { body, err := readAndRestoreBody(r) if err != nil { return err @@ -122,11 +122,11 @@ func checkContainer(r *http.Request) error { return nil // Don't block if we can't parse. } - return checkHostConfigBindMounts(req.HostConfig) + return checkHostConfigBindMounts(allowedBindMounts, req.HostConfig) } // checkService checks bind mounts in service creation requests. -func checkService(r *http.Request) error { +func checkService(allowedBindMounts []string, r *http.Request) error { body, err := readAndRestoreBody(r) if err != nil { return err @@ -141,20 +141,23 @@ func checkService(r *http.Request) error { if req.TaskTemplate.ContainerSpec == nil { return nil // No container spec, nothing to check. } - return checkHostConfigBindMounts(&containerHostConfig{ - Mounts: req.TaskTemplate.ContainerSpec.Mounts, - }) + return checkHostConfigBindMounts( + allowedBindMounts, + &containerHostConfig{ + Mounts: req.TaskTemplate.ContainerSpec.Mounts, + }, + ) } // checkHostConfigBindMounts checks bind mounts in HostConfig. -func checkHostConfigBindMounts(hostConfig *containerHostConfig) error { +func checkHostConfigBindMounts(allowedBindMounts []string, hostConfig *containerHostConfig) error { if hostConfig == nil { return nil // No HostConfig, nothing to check } // Check legacy Binds field for _, bind := range hostConfig.Binds { - if err := validateBindMount(bind); err != nil { + if err := validateBindMount(allowedBindMounts, bind); err != nil { return err } } @@ -162,7 +165,7 @@ func checkHostConfigBindMounts(hostConfig *containerHostConfig) error { // Check modern Mounts field for _, mountItem := range hostConfig.Mounts { if mountItem.Type == mountTypeBind { - if err := validateBindMountSource(mountItem.Source); err != nil { + if err := validateBindMountSource(allowedBindMounts, mountItem.Source); err != nil { return err } } @@ -172,23 +175,23 @@ func checkHostConfigBindMounts(hostConfig *containerHostConfig) error { } // validateBindMount validates a bind mount string in the format "source:target:options". -func validateBindMount(bind string) error { +func validateBindMount(allowedBindMounts []string, bind string) error { parts := strings.Split(bind, ":") if len(parts) < 2 { return fmt.Errorf("invalid bind mount format: %s", bind) } - return validateBindMountSource(parts[0]) + return validateBindMountSource(allowedBindMounts, parts[0]) } // validateBindMountSource checks if the source directory is allowed. -func validateBindMountSource(source string) error { +func validateBindMountSource(allowedBindMounts []string, source string) error { // Skip if source is not an absolute path (i.e. bind mount). if !strings.HasPrefix(source, "/") { return nil } source = filepath.Clean(source) // Clean the path to resolve .. and . components. - for _, allowedDir := range cfg.AllowBindMountFrom { + for _, allowedDir := range allowedBindMounts { if allowedDir == "/" || source == allowedDir || strings.HasPrefix(source, allowedDir+"/") { return nil } diff --git a/cmd/socket-proxy/bindmount_test.go b/cmd/socket-proxy/bindmount_test.go index 6fe8d10..d71d74e 100644 --- a/cmd/socket-proxy/bindmount_test.go +++ b/cmd/socket-proxy/bindmount_test.go @@ -5,8 +5,6 @@ import ( "net/http" "runtime" "testing" - - "github.com/wollomatic/socket-proxy/internal/config" ) func skipIfNotUnix(t *testing.T) { @@ -21,9 +19,7 @@ func skipIfNotUnix(t *testing.T) { func TestValidateBindMountSource(t *testing.T) { skipIfNotUnix(t) - cfg = &config.Config{ - AllowBindMountFrom: []string{"/home", "/var/log"}, - } + allowedBindMounts := []string{"/home", "/var/log"} tests := []struct { name string @@ -44,7 +40,7 @@ func TestValidateBindMountSource(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateBindMountSource(tt.source) + err := validateBindMountSource(allowedBindMounts, tt.source) if tt.shouldPass && err != nil { t.Errorf("expected %s to pass, but got error: %v", tt.source, err) } @@ -83,10 +79,7 @@ func TestIsPathAllowed(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg = &config.Config{ - AllowBindMountFrom: []string{tt.allowedDir}, - } - err := validateBindMountSource(tt.path) + err := validateBindMountSource([]string{tt.allowedDir}, tt.path) if (err == nil) != tt.expected { t.Errorf("isPathAllowed(%s, %s) = %v, expected %v", tt.path, tt.allowedDir, err, tt.expected) } @@ -97,9 +90,7 @@ func TestIsPathAllowed(t *testing.T) { func TestValidateBindMount(t *testing.T) { skipIfNotUnix(t) - cfg = &config.Config{ - AllowBindMountFrom: []string{"/home", "/var/log"}, - } + allowedBindMounts := []string{"/home", "/var/log"} tests := []struct { name string @@ -115,7 +106,7 @@ func TestValidateBindMount(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validateBindMount(tt.bind) + err := validateBindMount(allowedBindMounts, tt.bind) if tt.shouldPass && err != nil { t.Errorf("expected %s to pass, but got error: %v", tt.bind, err) } @@ -129,9 +120,7 @@ func TestValidateBindMount(t *testing.T) { func TestCheckBindMountRestrictions(t *testing.T) { skipIfNotUnix(t) - cfg = &config.Config{ - AllowBindMountFrom: []string{"/home"}, - } + allowedBindMounts := []string{"/home"} tests := []struct { name string @@ -212,7 +201,7 @@ func TestCheckBindMountRestrictions(t *testing.T) { t.Fatalf("failed to create request: %v", err) } - err = checkBindMountRestrictions(req) + err = checkBindMountRestrictions(allowedBindMounts, req) if tt.shouldPass && err != nil { t.Errorf("expected request to pass, but got error: %v", err) } diff --git a/cmd/socket-proxy/handlehttprequest.go b/cmd/socket-proxy/handlehttprequest.go index 8d748b4..7bf1092 100644 --- a/cmd/socket-proxy/handlehttprequest.go +++ b/cmd/socket-proxy/handlehttprequest.go @@ -5,25 +5,21 @@ import ( "log/slog" "net" "net/http" + + "github.com/wollomatic/socket-proxy/internal/config" ) // handleHTTPRequest checks if the request is allowed and sends it to the proxy. // Otherwise, it returns a "405 Method Not Allowed" or a "403 Forbidden" error. // In case of an error, it returns a 500 Internal Server Error. func handleHTTPRequest(w http.ResponseWriter, r *http.Request) { - if cfg.ProxySocketEndpoint == "" { // do not perform this check if we proxy to a unix socket - allowedIP, err := isAllowedClient(r.RemoteAddr) - if err != nil { - slog.Warn("cannot get valid IP address for client allowlist check", "reason", err, "method", r.Method, "URL", r.URL, "client", r.RemoteAddr) - } - if !allowedIP { - communicateBlockedRequest(w, r, "forbidden IP", http.StatusForbidden) - return - } + allowList, ok := determineAllowList(r) + if !ok { + communicateBlockedRequest(w, r, "forbidden IP", http.StatusForbidden) + return } - // check if the request is allowed - allowed, exists := cfg.AllowedRequests[r.Method] + allowed, exists := allowList.AllowedRequests[r.Method] if !exists { // method not in map -> not allowed communicateBlockedRequest(w, r, "method not allowed", http.StatusMethodNotAllowed) return @@ -34,7 +30,7 @@ func handleHTTPRequest(w http.ResponseWriter, r *http.Request) { } // check bind mount restrictions - if err := checkBindMountRestrictions(r); err != nil { + if err := checkBindMountRestrictions(allowList.AllowedBindMounts, r); err != nil { communicateBlockedRequest(w, r, "bind mount restriction: "+err.Error(), http.StatusForbidden) return } @@ -44,14 +40,40 @@ func handleHTTPRequest(w http.ResponseWriter, r *http.Request) { socketProxy.ServeHTTP(w, r) // proxy the request } +// return the relevant allowlist +func determineAllowList(r *http.Request) (config.AllowList, bool) { + if cfg.ProxySocketEndpoint == "" { // do not perform this check if we proxy to a unix socket + // Get the client IP address from the remote address string + clientIPStr, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + slog.Warn("cannot get valid IP address from request", "reason", err, "method", r.Method, "URL", r.URL, "client", r.RemoteAddr) + return config.AllowList{}, false + } + + // If applicable, get the non-default allowlist corresponding to the client IP address + if cfg.ProxyContainerName != "" { + allowList, found := cfg.AllowLists.FindByIP(clientIPStr) + if found { + return allowList, true + } + } + + // Check if client is allowed for the default allowlist: + allowedIP, err := isAllowedClient(clientIPStr) + if err != nil { + slog.Warn("cannot get valid IP address for client allowlist check", "reason", err, "method", r.Method, "URL", r.URL, "client", r.RemoteAddr) + } + if !allowedIP { + return config.AllowList{}, false + } + } + + return cfg.AllowLists.Default, true +} + // isAllowedClient checks if the given remote address is allowed to connect to the proxy. // The IP address is extracted from a RemoteAddr string (the part before the colon). -func isAllowedClient(remoteAddr string) (bool, error) { - // Get the client IP address from the remote address string - clientIPStr, _, err := net.SplitHostPort(remoteAddr) - if err != nil { - return false, err - } +func isAllowedClient(clientIPStr string) (bool, error) { // Parse the IP address clientIP := net.ParseIP(clientIPStr) if clientIP == nil { diff --git a/cmd/socket-proxy/main.go b/cmd/socket-proxy/main.go index 77837cc..9fbf9fb 100644 --- a/cmd/socket-proxy/main.go +++ b/cmd/socket-proxy/main.go @@ -56,6 +56,11 @@ func main() { } slog.SetDefault(logger) + // setup non-default allowlists + if cfg.ProxySocketEndpoint == "" && cfg.ProxyContainerName != "" { + go cfg.UpdateAllowLists() + } + // print configuration slog.Info("starting socket-proxy", "version", version, "os", runtime.GOOS, "arch", runtime.GOARCH, "runtime", runtime.Version(), "URL", programURL) if cfg.ProxySocketEndpoint == "" { @@ -71,26 +76,17 @@ func main() { } else { slog.Info("watchdog disabled") } - if len(cfg.AllowBindMountFrom) > 0 { - slog.Info("Docker bind mount restrictions enabled", "allowbindmountfrom", cfg.AllowBindMountFrom) + if len(cfg.ProxyContainerName) > 0 { + slog.Info("Proxy container name provided", "proxycontainername", cfg.ProxyContainerName) } else { - // we only log this on DEBUG level because bind mount restrictions are a very special use case - slog.Debug("no Docker bind mount restrictions") + // we only log this on DEBUG level because providing the socket-proxy container name + // enables the use of labels to specify per-container allowlists + slog.Debug("no proxy container name provided") } + cfg.AllowLists.PrintNetworks() - // print request allowlist - if cfg.LogJSON { - for method, regex := range cfg.AllowedRequests { - slog.Info("configured allowed request", "method", method, "regex", regex) - } - } else { - // don't use slog here, as we want to print the regexes as they are - // see https://github.com/wollomatic/socket-proxy/issues/11 - fmt.Printf("Request allowlist:\n %-8s %s\n", "Method", "Regex") - for method, regex := range cfg.AllowedRequests { - fmt.Printf(" %-8s %s\n", method, regex) - } - } + // print default request allowlist + cfg.AllowLists.PrintDefault(cfg.LogJSON) // check if the socket is available err = checkSocketAvailability(cfg.SocketPath) diff --git a/internal/config/config.go b/internal/config/config.go index 1f4ae42..c6c93e7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,6 +1,7 @@ package config import ( + "context" "errors" "flag" "fmt" @@ -11,10 +12,20 @@ import ( "os" "path/filepath" "regexp" + "slices" "strconv" "strings" + "sync" + "time" + + "github.com/wollomatic/socket-proxy/internal/docker/api/types/container" + "github.com/wollomatic/socket-proxy/internal/docker/api/types/events" + "github.com/wollomatic/socket-proxy/internal/docker/api/types/filters" + "github.com/wollomatic/socket-proxy/internal/docker/client" ) +const allowedDockerLabelPrefix = "socket-proxy.allow." + var ( defaultAllowFrom = "127.0.0.1/32" // allowed IPs to connect to the proxy defaultAllowHealthcheck = false // allow health check requests (HEAD http://localhost:55555/health) @@ -29,10 +40,11 @@ var ( defaultProxySocketEndpoint = "" // empty string means no socket listener, but regular TCP listener defaultProxySocketEndpointFileMode = uint(0o600) // set the file mode of the unix socket endpoint defaultAllowBindMountFrom = "" // empty string means no bind mount restrictions + defaultProxyContainerName = "" // socket-proxy Docker container name (empty string disables container labels for allowlists) ) type Config struct { - AllowedRequests map[string]*regexp.Regexp + AllowLists *AllowListRegistry AllowFrom []string AllowHealthcheck bool LogJSON bool @@ -44,7 +56,20 @@ type Config struct { SocketPath string ProxySocketEndpoint string ProxySocketEndpointFileMode os.FileMode - AllowBindMountFrom []string + ProxyContainerName string +} + +type AllowListRegistry struct { + mutex sync.RWMutex // mutex to control read/write of byIP + networks []string // names of networks in which socket proxy access is allowed for non-default allowlists + Default AllowList // default allowlist + byIP map[string]AllowList // map container IP address to allowlist for that container +} + +type AllowList struct { + ID string // Container ID (empty for the default allowlist) + AllowedRequests map[string]*regexp.Regexp // map of request methods to request path regex patterns (no requests allowed if empty) + AllowedBindMounts []string // list of from portion of allowed bind mounts (all bind mounts allowed if empty) } // used for list of allowed requests @@ -134,6 +159,9 @@ func InitConfig() (*Config, error) { if val, ok := os.LookupEnv("SP_ALLOWBINDMOUNTFROM"); ok && val != "" { defaultAllowBindMountFrom = val } + if val, ok := os.LookupEnv("SP_PROXYCONTAINERNAME"); ok && val != "" { + defaultProxyContainerName = val + } for i := range mr { if val, ok := os.LookupEnv("SP_ALLOW_" + mr[i].method); ok && val != "" { @@ -160,39 +188,41 @@ func InitConfig() (*Config, error) { flag.StringVar(&cfg.ProxySocketEndpoint, "proxysocketendpoint", defaultProxySocketEndpoint, "unix socket endpoint (if set, used instead of the TCP listener)") flag.UintVar(&endpointFileMode, "proxysocketendpointfilemode", defaultProxySocketEndpointFileMode, "set the file mode of the unix socket endpoint") flag.StringVar(&allowBindMountFromString, "allowbindmountfrom", defaultAllowBindMountFrom, "allowed directories for bind mounts (comma-separated)") + flag.StringVar(&cfg.ProxyContainerName, "proxycontainername", defaultProxyContainerName, "socket-proxy Docker container name") for i := range mr { flag.StringVar(&mr[i].regexStringFromParam, "allow"+mr[i].method, "", "regex for "+mr[i].method+" requests (not set means method is not allowed)") } flag.Parse() + // init allowlist registry to configure default allowlist + cfg.AllowLists = &AllowListRegistry{} + // parse comma-separeted allowFromString into allowFrom slice cfg.AllowFrom = strings.Split(allowFromString, ",") - // parse allowBindMountFromString into AllowBindMountFrom slice and validate + // parse allowBindMountFromString into default allowlist AllowedBindMounts slice and validate if allowBindMountFromString != "" { - cfg.AllowBindMountFrom = strings.Split(allowBindMountFromString, ",") - for i, dir := range cfg.AllowBindMountFrom { - if !strings.HasPrefix(dir, "/") { - return nil, fmt.Errorf("bind mount directory must start with /: %q", dir) - } - cfg.AllowBindMountFrom[i] = filepath.Clean(dir) + allowedBindMounts, err := parseAllowedBindMounts(allowBindMountFromString) + if err != nil { + return nil, err } + cfg.AllowLists.Default.AllowedBindMounts = allowedBindMounts } // check listenIP and proxyPort if proxyPort < 1 || proxyPort > 65535 { - return nil, errors.New("port number has to be between 1 and 65535") + return nil, errors.New("port number has to be between 1 and 65535") } ip := net.ParseIP(listenIP) if ip == nil { - return nil, fmt.Errorf("invalid IP \"%s\" for listenip", listenIP) + return nil, fmt.Errorf("invalid IP \"%s\" for listenip", listenIP) } // Properly format address for both IPv4 and IPv6 if ip.To4() == nil { - cfg.ListenAddress = fmt.Sprintf("[%s]:%d", listenIP, proxyPort) + cfg.ListenAddress = fmt.Sprintf("[%s]:%d", listenIP, proxyPort) } else { - cfg.ListenAddress = fmt.Sprintf("%s:%d", listenIP, proxyPort) + cfg.ListenAddress = fmt.Sprintf("%s:%d", listenIP, proxyPort) } // parse defaultLogLevel and setup logging handler depending on defaultLogJSON @@ -214,22 +244,416 @@ func InitConfig() (*Config, error) { } cfg.ProxySocketEndpointFileMode = os.FileMode(uint32(endpointFileMode)) - // compile regexes for allowed requests - cfg.AllowedRequests = make(map[string]*regexp.Regexp) + // compile regexes for default allowed requests + cfg.AllowLists.Default.AllowedRequests = make(map[string]*regexp.Regexp) for _, rx := range mr { if rx.regexStringFromParam != "" { - r, err := regexp.Compile("^" + rx.regexStringFromParam + "$") + r, err := compileRegexp(rx.regexStringFromParam, rx.method, "command line parameter") if err != nil { - return nil, fmt.Errorf("invalid regex \"%s\" for method %s in command line parameter: %w", rx.regexStringFromParam, rx.method, err) + return nil, err } - cfg.AllowedRequests[rx.method] = r + cfg.AllowLists.Default.AllowedRequests[rx.method] = r } else if rx.regexStringFromEnv != "" { - r, err := regexp.Compile("^" + rx.regexStringFromEnv + "$") + r, err := compileRegexp(rx.regexStringFromEnv, rx.method, "env variable") if err != nil { - return nil, fmt.Errorf("invalid regex \"%s\" for method %s in env variable: %w", rx.regexStringFromEnv, rx.method, err) + return nil, err } - cfg.AllowedRequests[rx.method] = r + cfg.AllowLists.Default.AllowedRequests[rx.method] = r + } + } + + // populate list of socket proxy networks if applicable + if cfg.ProxySocketEndpoint == "" && cfg.ProxyContainerName != "" { + var err error + cfg.AllowLists.networks, err = listSocketProxyNetworks(cfg.SocketPath, cfg.ProxyContainerName) + if err != nil { + return nil, err } } + return &cfg, nil } + +// UpdateAllowLists populates the byIP allowlists then keeps them updated +func (cfg *Config) UpdateAllowLists() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + dockerClient, err := client.NewClientWithOpts( + client.WithHost("unix://"+cfg.SocketPath), + client.WithAPIVersionNegotiation(), + ) + if err != nil { + slog.Error("failed to create Docker client", "error", err) + return + } + defer dockerClient.Close() + + err = cfg.AllowLists.initByIP(ctx, dockerClient) + if err != nil { + slog.Error("failed to initialise non-default allowlists", "error", err) + return + } + slog.Debug("initialised non-default allowlists") + + filter := filters.NewArgs() + filter.Add("type", "container") + filter.Add("event", "start") + filter.Add("event", "restart") + filter.Add("event", "die") + eventsChan, errChan := dockerClient.Events(ctx, events.ListOptions{Filters: filter}) + slog.Debug("subscribed to Docker event stream to update allowlists") + + // print non-default request allowlists + cfg.AllowLists.PrintByIP(cfg.LogJSON) + + // handle Docker events to update allowlists + for { + select { + case event, ok := <-eventsChan: + if !ok { + slog.Info("Docker event stream closed") + return + } + slog.Debug("received Docker container event", "action", event.Action, "id", event.Actor.ID[:12]) + addedIPs, removedIPs, updateErr := cfg.AllowLists.updateFromEvent(ctx, dockerClient, event) + if updateErr != nil { + slog.Warn("failed to update allowlists from container event", "error", updateErr) + continue + } + for _, ip := range addedIPs { + cfg.AllowLists.mutex.RLock() + allowList, found := cfg.AllowLists.byIP[ip] + cfg.AllowLists.mutex.RUnlock() + if found { + allowList.Print(ip, cfg.LogJSON) + } + } + for _, ip := range removedIPs { + slog.Info("removed allowlist for container", "id", event.Actor.ID[:12], "ip", ip) + } + case err := <-errChan: + if err != nil { + slog.Error("received error from Docker event stream", "error", err) + return + } + } + } +} + +// PrintNetworks prints the allowed networks +func (allowLists *AllowListRegistry) PrintNetworks() { + if len(allowLists.networks) > 0 { + slog.Info("socket proxy networks detected", "socketproxynetworks", allowLists.networks) + } else { + // we only log this on DEBUG level because the socket proxy networks are used for per-container allowlists + slog.Debug("no socket proxy networks detected") + } +} + +// PrintDefault prints the default allowlist +func (allowLists *AllowListRegistry) PrintDefault(logJSON bool) { + allowLists.Default.Print("", logJSON) +} + +// PrintByIP prints the non-default allowlists +func (allowLists *AllowListRegistry) PrintByIP(logJSON bool) { + allowLists.mutex.RLock() + defer allowLists.mutex.RUnlock() + for ip, allowList := range allowLists.byIP { + allowList.Print(ip, logJSON) + } +} + +// FindByIP returns the allowlist corresponding to the given IP address if found +func (allowLists *AllowListRegistry) FindByIP(ip string) (AllowList, bool) { + allowLists.mutex.RLock() + defer allowLists.mutex.RUnlock() + allowList, found := allowLists.byIP[ip] + return allowList, found +} + +// initialise allowlist registry byIP allowlists +func (allowLists *AllowListRegistry) initByIP(ctx context.Context, dockerClient *client.Client) error { + filter := filters.NewArgs() + for _, network := range allowLists.networks { + filter.Add("network", network) + } + containers, err := dockerClient.ContainerList(ctx, container.ListOptions{Filters: filter}) + if err != nil { + return err + } + + allowLists.mutex.Lock() + defer allowLists.mutex.Unlock() + + allowLists.byIP = make(map[string]AllowList) + + for _, cntr := range containers { + allowedRequests, allowedBindMounts, err := extractLabelData(cntr) + if err != nil { + allowLists.byIP = nil + return err + } + + if len(allowedRequests) > 0 || len(allowedBindMounts) > 0 { + for networkID, cntrNetwork := range cntr.NetworkSettings.Networks { + if slices.Contains(allowLists.networks, networkID) { + allowList := AllowList{ + ID: cntr.ID, + AllowedRequests: allowedRequests, + AllowedBindMounts: allowedBindMounts, + } + + if len(cntrNetwork.IPAddress) > 0 { + allowLists.byIP[cntrNetwork.IPAddress] = allowList + } + if len(cntrNetwork.GlobalIPv6Address) > 0 { + allowLists.byIP[cntrNetwork.GlobalIPv6Address] = allowList + } + } + } + } + } + + return nil +} + +// update the allowlist registry based on the Docker event +func (allowLists *AllowListRegistry) updateFromEvent( + ctx context.Context, dockerClient *client.Client, event events.Message, +) ([]string, []string, error) { + containerID := event.Actor.ID + var ( + addedIPs []string + removedIPs []string + err error + ) + + switch event.Action { + case "start", "restart": + addedIPs, err = allowLists.add(ctx, dockerClient, containerID) + if err != nil { + return nil, nil, err + } + case "die": + removedIPs = allowLists.remove(containerID) + } + return addedIPs, removedIPs, nil +} + +// add the allowlist for the container with the given ID to the allowlist registry +// if it has at least one socket-proxy allow label and is in a same network as the socket-proxy +func (allowLists *AllowListRegistry) add( + ctx context.Context, dockerClient *client.Client, containerID string, +) ([]string, error) { + filter := filters.NewArgs() + filter.Add("id", containerID) + for _, network := range allowLists.networks { + filter.Add("network", network) + } + containers, err := dockerClient.ContainerList(ctx, container.ListOptions{Filters: filter}) + if err != nil { + return nil, err + } + if len(containers) == 0 { + slog.Debug("container is not in a network with socket-proxy or may have stopped", "id", containerID[:12]) + return nil, nil + } + cntr := containers[0] + + allowedRequests, allowedBindMounts, err := extractLabelData(cntr) + if err != nil { + return nil, err + } + + var ips []string + if len(allowedRequests) > 0 || len(allowedBindMounts) > 0 { + allowList := AllowList{ + ID: cntr.ID, + AllowedRequests: allowedRequests, + AllowedBindMounts: allowedBindMounts, + } + + allowLists.mutex.Lock() + defer allowLists.mutex.Unlock() + + for networkID, cntrNetwork := range cntr.NetworkSettings.Networks { + if slices.Contains(allowLists.networks, networkID) { + ipv4Address := cntrNetwork.IPAddress + if len(ipv4Address) > 0 { + allowLists.byIP[ipv4Address] = allowList + ips = append(ips, ipv4Address) + } + ipv6Address := cntrNetwork.GlobalIPv6Address + if len(ipv6Address) > 0 { + allowLists.byIP[ipv6Address] = allowList + ips = append(ips, ipv6Address) + } + } + } + } + + return ips, nil +} + +// remove allowlists having the given container ID from the allowlist registry +func (allowLists *AllowListRegistry) remove(containerID string) []string { + allowLists.mutex.Lock() + defer allowLists.mutex.Unlock() + + var removedIPs []string + for ip, allowList := range allowLists.byIP { + if allowList.ID == containerID { + delete(allowLists.byIP, ip) + removedIPs = append(removedIPs, ip) + } + } + return removedIPs +} + +// Print prints the allowlist, including the IP address of the associated container if it is not empty, +// and in JSON format if logJSON is true +func (allowList AllowList) Print(ip string, logJSON bool) { + // print allowed requests + if logJSON { + if ip == "" { + for method, regex := range allowList.AllowedRequests { + slog.Info("configured default request allowlist", "method", method, "regex", regex) + } + } else { + for method, regex := range allowList.AllowedRequests { + slog.Info("configured request allowlist", + "id", allowList.ID[:12], + "ip", ip, + "method", method, + "regex", regex, + ) + } + } + } else { + // don't use slog here, as we want to print the regexes as they are + // see https://github.com/wollomatic/socket-proxy/issues/11 + if ip == "" { + fmt.Printf("Default request allowlist:\n %-8s %s\n", "Method", "Regex") + } else { + fmt.Printf("Request allowlist for %s (%s):\n %-8s %s\n", allowList.ID[:12], ip, "Method", "Regex") + } + for method, regex := range allowList.AllowedRequests { + fmt.Printf(" %-8s %s\n", method, regex) + } + } + // print allowed bind mounts + if len(allowList.AllowedBindMounts) > 0 { + if ip == "" { + slog.Info("Default Docker bind mount restrictions enabled", + "allowbindmountfrom", allowList.AllowedBindMounts, + ) + } else { + slog.Info("Docker bind mount restrictions enabled", + "allowbindmountfrom", allowList.AllowedBindMounts, + "id", allowList.ID[:12], + "ip", ip, + ) + } + } else { + // we only log this on DEBUG level because bind mount restrictions are a very special use case + if ip == "" { + slog.Debug("no default Docker bind mount restrictions") + } else { + slog.Debug("no Docker bind mount restrictions", "id", allowList.ID[:12], "ip", ip) + } + } +} + +// compile allowed requests regex pattern +func compileRegexp(regex, method, configLocation string) (*regexp.Regexp, error) { + r, err := regexp.Compile("^" + regex + "$") + if err != nil { + return nil, fmt.Errorf("invalid regex \"%s\" for method %s in %s: %w", regex, method, configLocation, err) + } + return r, nil +} + +// parse bind mount from string into list of allowed bind mounts +func parseAllowedBindMounts(allowBindMountFromString string) ([]string, error) { + allowedBindMounts := strings.Split(allowBindMountFromString, ",") + for i, dir := range allowedBindMounts { + if !strings.HasPrefix(dir, "/") { + return nil, fmt.Errorf("bind mount directory must start with /: %q", dir) + } + allowedBindMounts[i] = filepath.Clean(dir) + } + return allowedBindMounts, nil +} + +// return list of docker networks that the socket-proxy container is in +func listSocketProxyNetworks(socketPath, proxyContainerName string) ([]string, error) { + cntr, err := getSocketProxyContainerSummary(socketPath, proxyContainerName) + if err != nil { + return nil, err + } + + networks := make([]string, 0, len(cntr.NetworkSettings.Networks)) + for networkID := range cntr.NetworkSettings.Networks { + networks = append(networks, networkID) + } + return networks, nil +} + +// return Docker container summary for the socket proxy container +func getSocketProxyContainerSummary(socketPath, proxyContainerName string) (container.Summary, error) { + const maxTries = 3 + + dockerClient, err := client.NewClientWithOpts( + client.WithHost("unix://"+socketPath), + client.WithAPIVersionNegotiation(), + ) + if err != nil { + return container.Summary{}, err + } + defer dockerClient.Close() + + ctx := context.Background() + filter := filters.NewArgs() + filter.Add("name", proxyContainerName) + var containers []container.Summary + for i := 1; i < maxTries; i++ { + containers, err = dockerClient.ContainerList(ctx, container.ListOptions{Filters: filter}) + if err != nil { + return container.Summary{}, err + } + if len(containers) > 0 { + return containers[0], nil + } + if i < maxTries { + time.Sleep(time.Duration(i) * time.Second) + } + } + return container.Summary{}, fmt.Errorf("socket-proxy container \"%s\" was not found", proxyContainerName) +} + +// extract Docker container allowlist label data from the container summary +func extractLabelData(cntr container.Summary) (map[string]*regexp.Regexp, []string, error) { + allowedRequests := make(map[string]*regexp.Regexp) + var allowedBindMounts []string + for labelName, labelValue := range cntr.Labels { + if strings.HasPrefix(labelName, allowedDockerLabelPrefix) && labelValue != "" { + allowSpec := strings.ToUpper(strings.TrimPrefix(labelName, allowedDockerLabelPrefix)) + if slices.ContainsFunc(mr, func(rx methodRegex) bool { return rx.method == allowSpec }) { + r, err := compileRegexp(labelValue, allowSpec, "docker container label") + if err != nil { + return nil, nil, err + } + allowedRequests[allowSpec] = r + } else if allowSpec == "BINDMOUNTFROM" { + var err error + allowedBindMounts, err = parseAllowedBindMounts(labelValue) + if err != nil { + return nil, nil, err + } + } + } + } + return allowedRequests, allowedBindMounts, nil +} diff --git a/internal/docker/api/common.go b/internal/docker/api/common.go new file mode 100644 index 0000000..d90b158 --- /dev/null +++ b/internal/docker/api/common.go @@ -0,0 +1,12 @@ +package api + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/common.go +*/ + +// Common constants for daemon and client. +const ( + // DefaultVersion of the current REST API. + DefaultVersion = "1.51" +) diff --git a/internal/docker/api/types/container/container.go b/internal/docker/api/types/container/container.go new file mode 100644 index 0000000..83cd3b9 --- /dev/null +++ b/internal/docker/api/types/container/container.go @@ -0,0 +1,15 @@ +package container + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/container/container.go +*/ + +// Summary contains response of Engine API: +// GET "/containers/json" +type Summary struct { + ID string `json:"Id"` + Names []string + Labels map[string]string + NetworkSettings *NetworkSettingsSummary +} diff --git a/internal/docker/api/types/container/network_settings.go b/internal/docker/api/types/container/network_settings.go new file mode 100644 index 0000000..78e5d4c --- /dev/null +++ b/internal/docker/api/types/container/network_settings.go @@ -0,0 +1,16 @@ +package container + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/container/network_settings.go +*/ + +import ( + "github.com/wollomatic/socket-proxy/internal/docker/api/types/network" +) + +// NetworkSettingsSummary provides a summary of container's networks +// in /containers/json +type NetworkSettingsSummary struct { + Networks map[string]*network.EndpointSettings +} diff --git a/internal/docker/api/types/container/options.go b/internal/docker/api/types/container/options.go new file mode 100644 index 0000000..4893e5f --- /dev/null +++ b/internal/docker/api/types/container/options.go @@ -0,0 +1,13 @@ +package container + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/container/options.go +*/ + +import "github.com/wollomatic/socket-proxy/internal/docker/api/types/filters" + +// ListOptions holds parameters to list containers with. +type ListOptions struct { + Filters filters.Args +} diff --git a/internal/docker/api/types/error_response.go b/internal/docker/api/types/error_response.go new file mode 100644 index 0000000..ab8aa10 --- /dev/null +++ b/internal/docker/api/types/error_response.go @@ -0,0 +1,15 @@ +package types + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/error_response.go +*/ + +// ErrorResponse Represents an error. +// swagger:model ErrorResponse +type ErrorResponse struct { + + // The error message. + // Required: true + Message string `json:"message"` +} diff --git a/internal/docker/api/types/events/events.go b/internal/docker/api/types/events/events.go new file mode 100644 index 0000000..d8c7a63 --- /dev/null +++ b/internal/docker/api/types/events/events.go @@ -0,0 +1,128 @@ +package events + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/events/events.go +*/ + +import "github.com/wollomatic/socket-proxy/internal/docker/api/types/filters" + +// Type is used for event-types. +type Type string + +// List of known event types. +const ( + BuilderEventType Type = "builder" // BuilderEventType is the event type that the builder generates. + ConfigEventType Type = "config" // ConfigEventType is the event type that configs generate. + ContainerEventType Type = "container" // ContainerEventType is the event type that containers generate. + DaemonEventType Type = "daemon" // DaemonEventType is the event type that daemon generate. + ImageEventType Type = "image" // ImageEventType is the event type that images generate. + NetworkEventType Type = "network" // NetworkEventType is the event type that networks generate. + NodeEventType Type = "node" // NodeEventType is the event type that nodes generate. + PluginEventType Type = "plugin" // PluginEventType is the event type that plugins generate. + SecretEventType Type = "secret" // SecretEventType is the event type that secrets generate. + ServiceEventType Type = "service" // ServiceEventType is the event type that services generate. + VolumeEventType Type = "volume" // VolumeEventType is the event type that volumes generate. +) + +// Action is used for event-actions. +type Action string + +const ( + ActionCreate Action = "create" + ActionStart Action = "start" + ActionRestart Action = "restart" + ActionStop Action = "stop" + ActionCheckpoint Action = "checkpoint" + ActionPause Action = "pause" + ActionUnPause Action = "unpause" + ActionAttach Action = "attach" + ActionDetach Action = "detach" + ActionResize Action = "resize" + ActionUpdate Action = "update" + ActionRename Action = "rename" + ActionKill Action = "kill" + ActionDie Action = "die" + ActionOOM Action = "oom" + ActionDestroy Action = "destroy" + ActionRemove Action = "remove" + ActionCommit Action = "commit" + ActionTop Action = "top" + ActionCopy Action = "copy" + ActionArchivePath Action = "archive-path" + ActionExtractToDir Action = "extract-to-dir" + ActionExport Action = "export" + ActionImport Action = "import" + ActionSave Action = "save" + ActionLoad Action = "load" + ActionTag Action = "tag" + ActionUnTag Action = "untag" + ActionPush Action = "push" + ActionPull Action = "pull" + ActionPrune Action = "prune" + ActionDelete Action = "delete" + ActionEnable Action = "enable" + ActionDisable Action = "disable" + ActionConnect Action = "connect" + ActionDisconnect Action = "disconnect" + ActionReload Action = "reload" + ActionMount Action = "mount" + ActionUnmount Action = "unmount" + + // ActionExecCreate is the prefix used for exec_create events. These + // event-actions are commonly followed by a colon and space (": "), + // and the command that's defined for the exec, for example: + // + // exec_create: /bin/sh -c 'echo hello' + // + // This is far from ideal; it's a compromise to allow filtering and + // to preserve backward-compatibility. + ActionExecCreate Action = "exec_create" + // ActionExecStart is the prefix used for exec_create events. These + // event-actions are commonly followed by a colon and space (": "), + // and the command that's defined for the exec, for example: + // + // exec_start: /bin/sh -c 'echo hello' + // + // This is far from ideal; it's a compromise to allow filtering and + // to preserve backward-compatibility. + ActionExecStart Action = "exec_start" + ActionExecDie Action = "exec_die" + ActionExecDetach Action = "exec_detach" + + // ActionHealthStatus is the prefix to use for health_status events. + // + // Health-status events can either have a pre-defined status, in which + // case the "health_status" action is followed by a colon, or can be + // "free-form", in which case they're followed by the output of the + // health-check output. + // + // This is far form ideal, and a compromise to allow filtering, and + // to preserve backward-compatibility. + ActionHealthStatus Action = "health_status" + ActionHealthStatusRunning Action = "health_status: running" + ActionHealthStatusHealthy Action = "health_status: healthy" + ActionHealthStatusUnhealthy Action = "health_status: unhealthy" +) + +// Actor describes something that generates events, +// like a container, or a network, or a volume. +// It has a defined name and a set of attributes. +// The container attributes are its labels, other actors +// can generate these attributes from other properties. +type Actor struct { + ID string + Attributes map[string]string +} + +// Message represents the information an event contains +type Message struct { + Type Type + Action Action + Actor Actor +} + +// ListOptions holds parameters to filter events with. +type ListOptions struct { + Filters filters.Args +} diff --git a/internal/docker/api/types/filters/errors.go b/internal/docker/api/types/filters/errors.go new file mode 100644 index 0000000..0c795d8 --- /dev/null +++ b/internal/docker/api/types/filters/errors.go @@ -0,0 +1,29 @@ +package filters + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/filters/errors.go +*/ + +import "fmt" + +// invalidFilter indicates that the provided filter or its value is invalid +type invalidFilter struct { + Filter string + Value []string +} + +func (e invalidFilter) Error() string { + msg := "invalid filter" + if e.Filter != "" { + msg += " '" + e.Filter + if e.Value != nil { + msg = fmt.Sprintf("%s=%s", msg, e.Value) + } + msg += "'" + } + return msg +} + +// InvalidParameter marks this error as ErrInvalidParameter +func (e invalidFilter) InvalidParameter() {} diff --git a/internal/docker/api/types/filters/parse.go b/internal/docker/api/types/filters/parse.go new file mode 100644 index 0000000..cafebde --- /dev/null +++ b/internal/docker/api/types/filters/parse.go @@ -0,0 +1,305 @@ +/* +Package filters provides tools for encoding a mapping of keys to a set of +multiple values. + +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/filters/parse.go +*/ +package filters + +import ( + "encoding/json" + "regexp" + "strings" +) + +// Args stores a mapping of keys to a set of multiple values. +type Args struct { + fields map[string]map[string]bool +} + +// KeyValuePair are used to initialize a new Args +type KeyValuePair struct { + Key string + Value string +} + +// Arg creates a new KeyValuePair for initializing Args +func Arg(key, value string) KeyValuePair { + return KeyValuePair{Key: key, Value: value} +} + +// NewArgs returns a new Args populated with the initial args +func NewArgs(initialArgs ...KeyValuePair) Args { + args := Args{fields: map[string]map[string]bool{}} + for _, arg := range initialArgs { + args.Add(arg.Key, arg.Value) + } + return args +} + +// Keys returns all the keys in list of Args +func (args Args) Keys() []string { + keys := make([]string, 0, len(args.fields)) + for k := range args.fields { + keys = append(keys, k) + } + return keys +} + +// MarshalJSON returns a JSON byte representation of the Args +func (args Args) MarshalJSON() ([]byte, error) { + if len(args.fields) == 0 { + return []byte("{}"), nil + } + return json.Marshal(args.fields) +} + +// ToJSON returns the Args as a JSON encoded string +func ToJSON(a Args) (string, error) { + if a.Len() == 0 { + return "", nil + } + buf, err := json.Marshal(a) + return string(buf), err +} + +// FromJSON decodes a JSON encoded string into Args +func FromJSON(p string) (Args, error) { + args := NewArgs() + + if p == "" { + return args, nil + } + + raw := []byte(p) + err := json.Unmarshal(raw, &args) + if err == nil { + return args, nil + } + + // Fallback to parsing arguments in the legacy slice format + deprecated := map[string][]string{} + if legacyErr := json.Unmarshal(raw, &deprecated); legacyErr != nil { + return args, &invalidFilter{} + } + + args.fields = deprecatedArgs(deprecated) + return args, nil +} + +// UnmarshalJSON populates the Args from JSON encode bytes +func (args Args) UnmarshalJSON(raw []byte) error { + return json.Unmarshal(raw, &args.fields) +} + +// Get returns the list of values associated with the key +func (args Args) Get(key string) []string { + values := args.fields[key] + if values == nil { + return make([]string, 0) + } + slice := make([]string, 0, len(values)) + for key := range values { + slice = append(slice, key) + } + return slice +} + +// Add a new value to the set of values +func (args Args) Add(key, value string) { + if _, ok := args.fields[key]; ok { + args.fields[key][value] = true + } else { + args.fields[key] = map[string]bool{value: true} + } +} + +// Del removes a value from the set +func (args Args) Del(key, value string) { + if _, ok := args.fields[key]; ok { + delete(args.fields[key], value) + if len(args.fields[key]) == 0 { + delete(args.fields, key) + } + } +} + +// Len returns the number of keys in the mapping +func (args Args) Len() int { + return len(args.fields) +} + +// MatchKVList returns true if all the pairs in sources exist as key=value +// pairs in the mapping at key, or if there are no values at key. +func (args Args) MatchKVList(key string, sources map[string]string) bool { + fieldValues := args.fields[key] + + // do not filter if there is no filter set or cannot determine filter + if len(fieldValues) == 0 { + return true + } + + if len(sources) == 0 { + return false + } + + for value := range fieldValues { + testK, testV, hasValue := strings.Cut(value, "=") + + v, ok := sources[testK] + if !ok { + return false + } + if hasValue && testV != v { + return false + } + } + + return true +} + +// Match returns true if any of the values at key match the source string +func (args Args) Match(field, source string) bool { + if args.ExactMatch(field, source) { + return true + } + + fieldValues := args.fields[field] + for name2match := range fieldValues { + match, err := regexp.MatchString(name2match, source) + if err != nil { + continue + } + if match { + return true + } + } + return false +} + +// GetBoolOrDefault returns a boolean value of the key if the key is present +// and is interpretable as a boolean value. Otherwise the default value is returned. +// Error is not nil only if the filter values are not valid boolean or are conflicting. +func (args Args) GetBoolOrDefault(key string, defaultValue bool) (bool, error) { + fieldValues, ok := args.fields[key] + if !ok { + return defaultValue, nil + } + + if len(fieldValues) == 0 { + return defaultValue, &invalidFilter{key, nil} + } + + isFalse := fieldValues["0"] || fieldValues["false"] + isTrue := fieldValues["1"] || fieldValues["true"] + if isFalse == isTrue { + // Either no or conflicting truthy/falsy value were provided + return defaultValue, &invalidFilter{key, args.Get(key)} + } + return isTrue, nil +} + +// ExactMatch returns true if the source matches exactly one of the values. +func (args Args) ExactMatch(key, source string) bool { + fieldValues, ok := args.fields[key] + // do not filter if there is no filter set or cannot determine filter + if !ok || len(fieldValues) == 0 { + return true + } + + // try to match full name value to avoid O(N) regular expression matching + return fieldValues[source] +} + +// UniqueExactMatch returns true if there is only one value and the source +// matches exactly the value. +func (args Args) UniqueExactMatch(key, source string) bool { + fieldValues := args.fields[key] + // do not filter if there is no filter set or cannot determine filter + if len(fieldValues) == 0 { + return true + } + if len(args.fields[key]) != 1 { + return false + } + + // try to match full name value to avoid O(N) regular expression matching + return fieldValues[source] +} + +// FuzzyMatch returns true if the source matches exactly one value, or the +// source has one of the values as a prefix. +func (args Args) FuzzyMatch(key, source string) bool { + if args.ExactMatch(key, source) { + return true + } + + fieldValues := args.fields[key] + for prefix := range fieldValues { + if strings.HasPrefix(source, prefix) { + return true + } + } + return false +} + +// Contains returns true if the key exists in the mapping +func (args Args) Contains(field string) bool { + _, ok := args.fields[field] + return ok +} + +// Validate compared the set of accepted keys against the keys in the mapping. +// An error is returned if any mapping keys are not in the accepted set. +func (args Args) Validate(accepted map[string]bool) error { + for name := range args.fields { + if !accepted[name] { + return &invalidFilter{name, nil} + } + } + return nil +} + +// WalkValues iterates over the list of values for a key in the mapping and calls +// op() for each value. If op returns an error the iteration stops and the +// error is returned. +func (args Args) WalkValues(field string, op func(value string) error) error { + if _, ok := args.fields[field]; !ok { + return nil + } + for v := range args.fields[field] { + if err := op(v); err != nil { + return err + } + } + return nil +} + +// Clone returns a copy of args. +func (args Args) Clone() (newArgs Args) { + newArgs.fields = make(map[string]map[string]bool, len(args.fields)) + for k, m := range args.fields { + var mm map[string]bool + if m != nil { + mm = make(map[string]bool, len(m)) + for kk, v := range m { + mm[kk] = v + } + } + newArgs.fields[k] = mm + } + return newArgs +} + +func deprecatedArgs(d map[string][]string) map[string]map[string]bool { + m := map[string]map[string]bool{} + for k, v := range d { + values := map[string]bool{} + for _, vv := range v { + values[vv] = true + } + m[k] = values + } + return m +} diff --git a/internal/docker/api/types/network/endpoint.go b/internal/docker/api/types/network/endpoint.go new file mode 100644 index 0000000..1fdec47 --- /dev/null +++ b/internal/docker/api/types/network/endpoint.go @@ -0,0 +1,25 @@ +package network + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/network/endpoint.go +*/ + +// EndpointSettings stores the network endpoint details +type EndpointSettings struct { + // Operational data + NetworkID string + EndpointID string + Gateway string + IPAddress string + IPPrefixLen int + IPv6Gateway string + GlobalIPv6Address string + GlobalIPv6PrefixLen int +} + +// Copy makes a deep copy of `EndpointSettings` +func (es *EndpointSettings) Copy() *EndpointSettings { + epCopy := *es + return &epCopy +} diff --git a/internal/docker/api/types/types.go b/internal/docker/api/types/types.go new file mode 100644 index 0000000..0098f0e --- /dev/null +++ b/internal/docker/api/types/types.go @@ -0,0 +1,12 @@ +package types + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/types.go +*/ + +// Ping contains response of Engine API: +// GET "/_ping" +type Ping struct { + APIVersion string +} diff --git a/internal/docker/api/types/versions/compare.go b/internal/docker/api/types/versions/compare.go new file mode 100644 index 0000000..ee76f1f --- /dev/null +++ b/internal/docker/api/types/versions/compare.go @@ -0,0 +1,50 @@ +package versions + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/api/types/versions/compare.go +*/ + +import ( + "strconv" + "strings" +) + +// compare compares two version strings +// returns -1 if v1 < v2, 1 if v1 > v2, 0 otherwise. +func compare(v1, v2 string) int { + if v1 == v2 { + return 0 + } + var ( + currTab = strings.Split(v1, ".") + otherTab = strings.Split(v2, ".") + ) + + maxVer := len(currTab) + if len(otherTab) > maxVer { + maxVer = len(otherTab) + } + for i := 0; i < maxVer; i++ { + var currInt, otherInt int + + if len(currTab) > i { + currInt, _ = strconv.Atoi(currTab[i]) + } + if len(otherTab) > i { + otherInt, _ = strconv.Atoi(otherTab[i]) + } + if currInt > otherInt { + return 1 + } + if otherInt > currInt { + return -1 + } + } + return 0 +} + +// LessThan checks if a version is less than another +func LessThan(v, other string) bool { + return compare(v, other) == -1 +} diff --git a/internal/docker/client/client.go b/internal/docker/client/client.go new file mode 100644 index 0000000..a77edee --- /dev/null +++ b/internal/docker/client/client.go @@ -0,0 +1,257 @@ +/* +Package client is a Go client for the Docker Engine API. + +For more information about the Engine API, see the documentation: +https://docs.docker.com/reference/api/engine/ + +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/client/client.go +*/ +package client + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "path" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/wollomatic/socket-proxy/internal/docker/api" + "github.com/wollomatic/socket-proxy/internal/docker/api/types" + "github.com/wollomatic/socket-proxy/internal/docker/api/types/versions" + "github.com/wollomatic/socket-proxy/internal/go-connections/sockets" +) + +// DefaultDockerHost defines default host +const DefaultDockerHost = "unix:///var/run/docker.sock" + +// DummyHost is a hostname used for local communication. +const DummyHost = "api.moby.localhost" + +// fallbackAPIVersion is the version to fallback to if API-version negotiation +// fails. This version is the highest version of the API before API-version +// negotiation was introduced. If negotiation fails (or no API version was +// included in the API response), we assume the API server uses the most +// recent version before negotiation was introduced. +const fallbackAPIVersion = "1.24" + +// Client is the API client that performs all operations +// against a docker server. +type Client struct { + // scheme sets the scheme for the client + scheme string + // host holds the server address to connect to + host string + // proto holds the client protocol i.e. unix. + proto string + // addr holds the client address. + addr string + // basePath holds the path to prepend to the requests. + basePath string + // client used to send and receive http requests. + client *http.Client + // version of the server to talk to. + version string + // userAgent is the User-Agent header to use for HTTP requests. It takes + // precedence over User-Agent headers set in customHTTPHeaders, and other + // header variables. When set to an empty string, the User-Agent header + // is removed, and no header is sent. + userAgent *string + // custom HTTP headers configured by users. + customHTTPHeaders map[string]string + + // negotiateVersion indicates if the client should automatically negotiate + // the API version to use when making requests. API version negotiation is + // performed on the first request, after which negotiated is set to "true" + // so that subsequent requests do not re-negotiate. + negotiateVersion bool + + // negotiated indicates that API version negotiation took place + negotiated atomic.Bool + + // negotiateLock is used to single-flight the version negotiation process + negotiateLock sync.Mutex + + // When the client transport is an *http.Transport (default) we need to do some extra things (like closing idle connections). + // Store the original transport as the http.Client transport will be wrapped with tracing libs. + baseTransport *http.Transport +} + +// ErrRedirect is the error returned by checkRedirect when the request is non-GET. +var ErrRedirect = errors.New("unexpected redirect in response") + +// CheckRedirect specifies the policy for dealing with redirect responses. It +// can be set on [http.Client.CheckRedirect] to prevent HTTP redirects for +// non-GET requests. It returns an [ErrRedirect] for non-GET request, otherwise +// returns a [http.ErrUseLastResponse], which is special-cased by http.Client +// to use the last response. +// +// Go 1.8 changed behavior for HTTP redirects (specifically 301, 307, and 308) +// in the client. The client (and by extension API client) can be made to send +// a request like "POST /containers//start" where what would normally be in the +// name section of the URL is empty. This triggers an HTTP 301 from the daemon. +// +// In go 1.8 this 301 is converted to a GET request, and ends up getting +// a 404 from the daemon. This behavior change manifests in the client in that +// before, the 301 was not followed and the client did not generate an error, +// but now results in a message like "Error response from daemon: page not found". +func CheckRedirect(_ *http.Request, via []*http.Request) error { + if via[0].Method == http.MethodGet { + return http.ErrUseLastResponse + } + return ErrRedirect +} + +// NewClientWithOpts initializes a new API client with a default HTTPClient, and +// default API host and version. It also initializes the custom HTTP headers to +// add to each request. +func NewClientWithOpts(ops ...Opt) (*Client, error) { + hostURL, err := ParseHostURL(DefaultDockerHost) + if err != nil { + return nil, err + } + + client, err := defaultHTTPClient(hostURL) + if err != nil { + return nil, err + } + c := &Client{ + host: DefaultDockerHost, + version: api.DefaultVersion, + client: client, + proto: hostURL.Scheme, + addr: hostURL.Host, + scheme: "http", + } + + for _, op := range ops { + if err := op(c); err != nil { + return nil, err + } + } + + if tr, ok := c.client.Transport.(*http.Transport); ok { + // Store the base transport + // This is used, as an example, to close idle connections when the client is closed + c.baseTransport = tr + } + + return c, nil +} + +func defaultHTTPClient(hostURL *url.URL) (*http.Client, error) { + transport := &http.Transport{} + // Necessary to prevent long-lived processes using the + // client from leaking connections due to idle connections + // not being released. + transport.MaxIdleConns = 6 + transport.IdleConnTimeout = 30 * time.Second + err := sockets.ConfigureTransport(transport, hostURL.Scheme, hostURL.Host) + if err != nil { + return nil, err + } + return &http.Client{ + Transport: transport, + CheckRedirect: CheckRedirect, + }, nil +} + +// Close the transport used by the client +func (cli *Client) Close() error { + if cli.baseTransport != nil { + cli.baseTransport.CloseIdleConnections() + return nil + } + return nil +} + +// checkVersion manually triggers API version negotiation (if configured). +// This allows for version-dependent code to use the same version as will +// be negotiated when making the actual requests, and for which cases +// we cannot do the negotiation lazily. +func (cli *Client) checkVersion(ctx context.Context) error { + if cli.negotiateVersion && !cli.negotiated.Load() { + // Ensure exclusive write access to version and negotiated fields + cli.negotiateLock.Lock() + defer cli.negotiateLock.Unlock() + + // May have been set during last execution of critical zone + if cli.negotiated.Load() { + return nil + } + + ping, err := cli.Ping(ctx) + if err != nil { + return err + } + cli.negotiateAPIVersionPing(ping) + } + return nil +} + +// getAPIPath returns the versioned request path to call the API. +// It appends the query parameters to the path if they are not empty. +func (cli *Client) getAPIPath(ctx context.Context, p string, query url.Values) string { + var apiPath string + _ = cli.checkVersion(ctx) + if cli.version != "" { + apiPath = path.Join(cli.basePath, "/v"+strings.TrimPrefix(cli.version, "v"), p) + } else { + apiPath = path.Join(cli.basePath, p) + } + return (&url.URL{Path: apiPath, RawQuery: query.Encode()}).String() +} + +// negotiateAPIVersionPing queries the API and updates the version to match the +// API version from the ping response. +func (cli *Client) negotiateAPIVersionPing(pingResponse types.Ping) { + // default to the latest version before versioning headers existed + if pingResponse.APIVersion == "" { + pingResponse.APIVersion = fallbackAPIVersion + } + + // if the client is not initialized with a version, start with the latest supported version + if cli.version == "" { + cli.version = api.DefaultVersion + } + + // if server version is lower than the client version, downgrade + if versions.LessThan(pingResponse.APIVersion, cli.version) { + cli.version = pingResponse.APIVersion + } + + // Store the results, so that automatic API version negotiation (if enabled) + // won't be performed on the next request. + if cli.negotiateVersion { + cli.negotiated.Store(true) + } +} + +// ParseHostURL parses a url string, validates the string is a host url, and +// returns the parsed URL +func ParseHostURL(host string) (*url.URL, error) { + proto, addr, ok := strings.Cut(host, "://") + if !ok || addr == "" { + return nil, fmt.Errorf("unable to parse docker host `%s`", host) + } + + var basePath string + if proto == "tcp" { + parsed, err := url.Parse("tcp://" + addr) + if err != nil { + return nil, err + } + addr = parsed.Host + basePath = parsed.Path + } + return &url.URL{ + Scheme: proto, + Host: addr, + Path: basePath, + }, nil +} diff --git a/internal/docker/client/container_list.go b/internal/docker/client/container_list.go new file mode 100644 index 0000000..3d12313 --- /dev/null +++ b/internal/docker/client/container_list.go @@ -0,0 +1,39 @@ +package client + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/client/container_list.go +*/ + +import ( + "context" + "encoding/json" + "net/url" + + "github.com/wollomatic/socket-proxy/internal/docker/api/types/container" + "github.com/wollomatic/socket-proxy/internal/docker/api/types/filters" +) + +// ContainerList returns the list of containers in the docker host. +func (cli *Client) ContainerList(ctx context.Context, options container.ListOptions) ([]container.Summary, error) { + query := url.Values{} + + if options.Filters.Len() > 0 { + filterJSON, err := filters.ToJSON(options.Filters) + if err != nil { + return nil, err + } + + query.Set("filters", filterJSON) + } + + resp, err := cli.get(ctx, "/containers/json", query, nil) + defer ensureReaderClosed(resp) + if err != nil { + return nil, err + } + + var containers []container.Summary + err = json.NewDecoder(resp.Body).Decode(&containers) + return containers, err +} diff --git a/internal/docker/client/errors.go b/internal/docker/client/errors.go new file mode 100644 index 0000000..12fd231 --- /dev/null +++ b/internal/docker/client/errors.go @@ -0,0 +1,42 @@ +package client + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/client/errors.go +*/ + +import ( + "errors" + "fmt" +) + +// errConnectionFailed implements an error returned when connection failed. +type errConnectionFailed struct { + error +} + +// Error returns a string representation of an errConnectionFailed +func (e errConnectionFailed) Error() string { + return e.error.Error() +} + +func (e errConnectionFailed) Unwrap() error { + return e.error +} + +// IsErrConnectionFailed returns true if the error is caused by connection failed. +func IsErrConnectionFailed(err error) bool { + return errors.As(err, &errConnectionFailed{}) +} + +// connectionFailed returns an error with host in the error message when connection +// to docker daemon failed. +func connectionFailed(host string) error { + var err error + if host == "" { + err = errors.New("cannot connect to the Docker daemon: is the docker daemon running on this host?") + } else { + err = fmt.Errorf("cannot connect to the Docker daemon at %s: is the docker daemon running?", host) + } + return errConnectionFailed{error: err} +} diff --git a/internal/docker/client/events.go b/internal/docker/client/events.go new file mode 100644 index 0000000..cf7d827 --- /dev/null +++ b/internal/docker/client/events.go @@ -0,0 +1,85 @@ +package client + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/client/events.go +*/ + +import ( + "context" + "encoding/json" + "net/url" + + "github.com/wollomatic/socket-proxy/internal/docker/api/types/events" + "github.com/wollomatic/socket-proxy/internal/docker/api/types/filters" +) + +// Events returns a stream of events in the daemon. It's up to the caller to close the stream +// by cancelling the context. Once the stream has been completely read an io.EOF error will +// be sent over the error channel. If an error is sent all processing will be stopped. It's up +// to the caller to reopen the stream in the event of an error by reinvoking this method. +func (cli *Client) Events(ctx context.Context, options events.ListOptions) (<-chan events.Message, <-chan error) { + messages := make(chan events.Message) + errs := make(chan error, 1) + + started := make(chan struct{}) + go func() { + defer close(errs) + + query, err := buildEventsQueryParams(options) + if err != nil { + close(started) + errs <- err + return + } + + resp, err := cli.get(ctx, "/events", query, nil) + if err != nil { + close(started) + errs <- err + return + } + defer resp.Body.Close() + + decoder := json.NewDecoder(resp.Body) + + close(started) + for { + select { + case <-ctx.Done(): + errs <- ctx.Err() + return + default: + var event events.Message + if err := decoder.Decode(&event); err != nil { + errs <- err + return + } + + select { + case messages <- event: + case <-ctx.Done(): + errs <- ctx.Err() + return + } + } + } + }() + <-started + + return messages, errs +} + +func buildEventsQueryParams(options events.ListOptions) (url.Values, error) { + query := url.Values{} + + if options.Filters.Len() > 0 { + filterJSON, err := filters.ToJSON(options.Filters) + if err != nil { + return nil, err + } + query.Set("filters", filterJSON) + } + + return query, nil +} diff --git a/internal/docker/client/options.go b/internal/docker/client/options.go new file mode 100644 index 0000000..7464a05 --- /dev/null +++ b/internal/docker/client/options.go @@ -0,0 +1,45 @@ +package client + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/client/options.go +*/ + +import ( + "fmt" + "net/http" + + "github.com/wollomatic/socket-proxy/internal/go-connections/sockets" +) + +// Opt is a configuration option to initialize a [Client]. +type Opt func(*Client) error + +// WithHost overrides the client host with the specified one. +func WithHost(host string) Opt { + return func(c *Client) error { + hostURL, err := ParseHostURL(host) + if err != nil { + return err + } + c.host = host + c.proto = hostURL.Scheme + c.addr = hostURL.Host + c.basePath = hostURL.Path + if transport, ok := c.client.Transport.(*http.Transport); ok { + return sockets.ConfigureTransport(transport, c.proto, c.addr) + } + return fmt.Errorf("cannot apply host to transport: %v", c.client.Transport) + } +} + +// WithAPIVersionNegotiation enables automatic API version negotiation for the client. +// With this option enabled, the client automatically negotiates the API version +// to use when making requests. API version negotiation is performed on the first +// request; subsequent requests do not re-negotiate. +func WithAPIVersionNegotiation() Opt { + return func(c *Client) error { + c.negotiateVersion = true + return nil + } +} diff --git a/internal/docker/client/ping.go b/internal/docker/client/ping.go new file mode 100644 index 0000000..619a4d8 --- /dev/null +++ b/internal/docker/client/ping.go @@ -0,0 +1,68 @@ +package client + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/client/ping.go +*/ + +import ( + "context" + "net/http" + "path" + + "github.com/wollomatic/socket-proxy/internal/docker/api/types" +) + +// Ping pings the server and returns the value of the "API-Version" header. +// It attempts to use a HEAD request on the endpoint, but falls back to GET if +// HEAD is not supported by the daemon. It ignores internal server errors +// returned by the API, which may be returned if the daemon is in an unhealthy +// state, but returns errors for other non-success status codes, failing to +// connect to the API, or failing to parse the API response. +func (cli *Client) Ping(ctx context.Context) (types.Ping, error) { + var ping types.Ping + + // Using cli.buildRequest() + cli.doRequest() instead of cli.sendRequest() + // because ping requests are used during API version negotiation, so we want + // to hit the non-versioned /_ping endpoint, not /v1.xx/_ping + req, err := cli.buildRequest(ctx, http.MethodHead, path.Join(cli.basePath, "/_ping"), nil, nil) + if err != nil { + return ping, err + } + resp, err := cli.doRequest(req) + if err != nil { + if IsErrConnectionFailed(err) { + return ping, err + } + // We managed to connect, but got some error; continue and try GET request. + } else { + defer ensureReaderClosed(resp) + switch resp.StatusCode { + case http.StatusOK, http.StatusInternalServerError: + // Server handled the request, so parse the response + return parsePingResponse(cli, resp) + } + } + + // HEAD failed; fallback to GET. + req.Method = http.MethodGet + resp, err = cli.doRequest(req) + defer ensureReaderClosed(resp) + if err != nil { + return ping, err + } + return parsePingResponse(cli, resp) +} + +func parsePingResponse(cli *Client, resp *http.Response) (types.Ping, error) { + if resp == nil { + return types.Ping{}, nil + } + + var ping types.Ping + if resp.Header == nil { + return ping, cli.checkResponseErr(resp) + } + ping.APIVersion = resp.Header.Get("Api-Version") + return ping, cli.checkResponseErr(resp) +} diff --git a/internal/docker/client/request.go b/internal/docker/client/request.go new file mode 100644 index 0000000..af7577f --- /dev/null +++ b/internal/docker/client/request.go @@ -0,0 +1,219 @@ +package client + +/* +This was modified from: +https://github.com/moby/moby/blob/v28.5.1/client/request.go +*/ + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + + "github.com/wollomatic/socket-proxy/internal/docker/api/types" + "github.com/wollomatic/socket-proxy/internal/docker/api/types/versions" +) + +// get sends an http request to the docker API using the method GET with a specific Go context. +func (cli *Client) get(ctx context.Context, path string, query url.Values, headers http.Header) (*http.Response, error) { + return cli.sendRequest(ctx, http.MethodGet, path, query, nil, headers) +} + +func (cli *Client) buildRequest(ctx context.Context, method, path string, body io.Reader, headers http.Header) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, path, body) + if err != nil { + return nil, err + } + req = cli.addHeaders(req, headers) + req.URL.Scheme = cli.scheme + req.URL.Host = cli.addr + + if cli.proto == "unix" { + // Override host header for non-tcp connections. + req.Host = DummyHost + } + + if body != nil && req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "text/plain") + } + return req, nil +} + +func (cli *Client) sendRequest(ctx context.Context, method, path string, query url.Values, body io.Reader, headers http.Header) (*http.Response, error) { + req, err := cli.buildRequest(ctx, method, cli.getAPIPath(ctx, path, query), body, headers) + if err != nil { + return nil, err + } + + resp, err := cli.doRequest(req) + switch { + case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): + return nil, err + case err == nil: + return resp, cli.checkResponseErr(resp) + default: + return resp, err + } +} + +func (cli *Client) doRequest(req *http.Request) (*http.Response, error) { + resp, err := cli.client.Do(req) + if err != nil { + // Don't decorate context sentinel errors; users may be comparing to + // them directly. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + + var uErr *url.Error + if errors.As(err, &uErr) { + var nErr *net.OpError + if errors.As(uErr.Err, &nErr) { + if os.IsPermission(nErr.Err) { + return nil, errConnectionFailed{fmt.Errorf("permission denied while trying to connect to the Docker daemon socket at %v: %v", cli.host, err)} + } + } + } + + var nErr net.Error + if errors.As(err, &nErr) { + if nErr.Timeout() { + return nil, connectionFailed(cli.host) + } + if strings.Contains(nErr.Error(), "connection refused") || strings.Contains(nErr.Error(), "dial unix") { + return nil, connectionFailed(cli.host) + } + } + + return nil, errConnectionFailed{fmt.Errorf("error during connect: %v", err)} + } + + return resp, nil +} + +func (cli *Client) checkResponseErr(serverResp *http.Response) (retErr error) { + if serverResp == nil { + return nil + } + if serverResp.StatusCode >= http.StatusOK && serverResp.StatusCode < http.StatusBadRequest { + return nil + } + defer func() { + if retErr != nil { + retErr = fmt.Errorf("HTTP error %d: %v", serverResp.StatusCode, retErr) + } + }() + + var body []byte + var err error + var reqURL string + if serverResp.Request != nil { + reqURL = serverResp.Request.URL.String() + } + statusMsg := serverResp.Status + if statusMsg == "" { + statusMsg = http.StatusText(serverResp.StatusCode) + } + if serverResp.Body != nil { + bodyMax := 1 * 1024 * 1024 // 1 MiB + bodyR := &io.LimitedReader{ + R: serverResp.Body, + N: int64(bodyMax), + } + body, err = io.ReadAll(bodyR) + if err != nil { + return err + } + if bodyR.N == 0 { + if reqURL != "" { + return fmt.Errorf("request returned %s with a message (> %d bytes) for API route and version %s, check if the server supports the requested API version", statusMsg, bodyMax, reqURL) + } + return fmt.Errorf("request returned %s with a message (> %d bytes); check if the server supports the requested API version", statusMsg, bodyMax) + } + } + if len(body) == 0 { + if reqURL != "" { + return fmt.Errorf("request returned %s for API route and version %s, check if the server supports the requested API version", statusMsg, reqURL) + } + return fmt.Errorf("request returned %s; check if the server supports the requested API version", statusMsg) + } + + var daemonErr error + if serverResp.Header.Get("Content-Type") == "application/json" { + var errorResponse types.ErrorResponse + if err := json.Unmarshal(body, &errorResponse); err != nil { + return fmt.Errorf("Error reading JSON: %v", err) + } + if errorResponse.Message == "" { + // Error-message is empty, which means that we successfully parsed the + // JSON-response (no error produced), but it didn't contain an error + // message. This could either be because the response was empty, or + // the response was valid JSON, but not with the expected schema + // ([types.ErrorResponse]). + // + // We cannot use "strict" JSON handling (json.NewDecoder with DisallowUnknownFields) + // due to the API using an open schema (we must anticipate fields + // being added to [types.ErrorResponse] in the future, and not + // reject those responses. + // + // For these cases, we construct an error with the status-code + // returned, but we could consider returning (a truncated version + // of) the actual response as-is. + + daemonErr = fmt.Errorf(`API returned a %d (%s) but provided no error-message`, + serverResp.StatusCode, + http.StatusText(serverResp.StatusCode), + ) + } else { + daemonErr = errors.New(strings.TrimSpace(errorResponse.Message)) + } + } else { + // Fall back to returning the response as-is for API versions < 1.24 + // that didn't support JSON error responses, and for situations + // where a plain text error is returned. This branch may also catch + // situations where a proxy is involved, returning a HTML response. + daemonErr = errors.New(strings.TrimSpace(string(body))) + } + return fmt.Errorf("Error response from daemon: %v", daemonErr) +} + +func (cli *Client) addHeaders(req *http.Request, headers http.Header) *http.Request { + // Add CLI Config's HTTP Headers BEFORE we set the Docker headers + // then the user can't change OUR headers + for k, v := range cli.customHTTPHeaders { + if versions.LessThan(cli.version, "1.25") && http.CanonicalHeaderKey(k) == "User-Agent" { + continue + } + req.Header.Set(k, v) + } + + for k, v := range headers { + req.Header[http.CanonicalHeaderKey(k)] = v + } + + if cli.userAgent != nil { + if *cli.userAgent == "" { + req.Header.Del("User-Agent") + } else { + req.Header.Set("User-Agent", *cli.userAgent) + } + } + return req +} + +func ensureReaderClosed(response *http.Response) { + if response != nil && response.Body != nil { + // Drain up to 512 bytes and close the body to let the Transport reuse the connection + // see https://github.com/google/go-github/pull/317/files#r57536827 + + _, _ = io.CopyN(io.Discard, response.Body, 512) + _ = response.Body.Close() + } +} diff --git a/internal/go-connections/sockets/sockets.go b/internal/go-connections/sockets/sockets.go new file mode 100644 index 0000000..0d889e8 --- /dev/null +++ b/internal/go-connections/sockets/sockets.go @@ -0,0 +1,69 @@ +/* +Package sockets provides helper functions to create and configure Unix or TCP sockets. + +This was modified from: +https://github.com/docker/go-connections/blob/v0.6.0/sockets/sockets.go +*/ +package sockets + +import ( + "context" + "errors" + "fmt" + "net" + "net/http" + "syscall" + "time" +) + +const ( + defaultTimeout = 10 * time.Second + maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path) +) + +// ErrProtocolNotAvailable is returned when a given transport protocol is not provided by the operating system. +var ErrProtocolNotAvailable = errors.New("protocol not available") + +// ConfigureTransport configures the specified [http.Transport] according to the specified proto +// and addr. +// +// If the proto is unix (using a unix socket to communicate) or npipe the compression is disabled. +// For other protos, compression is enabled. If you want to manually enable/disable compression, +// make sure you do it _after_ any subsequent calls to ConfigureTransport is made against the same +// [http.Transport]. +func ConfigureTransport(tr *http.Transport, proto, addr string) error { + if tr.MaxIdleConns == 0 { + // prevent long-lived processes from leaking connections + // due to idle connections not being released. + // + // TODO: see if we can also address this from the server side; see: https://github.com/moby/moby/issues/45539 + tr.MaxIdleConns = 6 + tr.IdleConnTimeout = 30 * time.Second + } + switch proto { + case "unix": + return configureUnixTransport(tr, addr) + default: + tr.Proxy = http.ProxyFromEnvironment + tr.DisableCompression = false + tr.DialContext = (&net.Dialer{ + Timeout: defaultTimeout, + }).DialContext + } + return nil +} + +func configureUnixTransport(tr *http.Transport, addr string) error { + if len(addr) > maxUnixSocketPathSize { + return fmt.Errorf("unix socket path %q is too long", addr) + } + // No need for compression in local communications. + tr.DisableCompression = true + dialer := &net.Dialer{ + Timeout: defaultTimeout, + } + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, "unix", addr) + } + return nil +}