From 2cc58c1c78079d627418bd568a67d71eee145005 Mon Sep 17 00:00:00 2001 From: darkweak Date: Sun, 27 Oct 2024 20:10:33 +0100 Subject: [PATCH] fix(traefik): resolves #558, full rewrite to match the actual cache behavior. Yaegi sucks! --- plugins/traefik/go.mod | 6 +- .../override/configurationtypes/types.go | 18 +- plugins/traefik/override/context/cache.go | 4 + plugins/traefik/override/context/graphql.go | 26 + plugins/traefik/override/context/key.go | 128 +++-- plugins/traefik/override/context/method.go | 4 + plugins/traefik/override/context/mode.go | 6 +- plugins/traefik/override/context/now.go | 4 + plugins/traefik/override/context/timeout.go | 6 +- plugins/traefik/override/context/types.go | 5 +- .../traefik/override/middleware/middleware.go | 521 +++++++++++++----- plugins/traefik/override/middleware/writer.go | 41 +- plugins/traefik/override/rfc/revalidation.go | 141 ++++- .../override/storage/abstractProvider.go | 2 +- .../traefik/override/storage/cacheProvider.go | 48 +- .../traefik/override/storage/types/types.go | 28 +- .../souin/configurationtypes/types.go | 18 +- .../darkweak/souin/context/cache.go | 4 + .../darkweak/souin/context/graphql.go | 26 + .../github.com/darkweak/souin/context/key.go | 128 +++-- .../darkweak/souin/context/method.go | 4 + .../github.com/darkweak/souin/context/mode.go | 6 +- .../github.com/darkweak/souin/context/now.go | 4 + .../darkweak/souin/context/timeout.go | 6 +- .../darkweak/souin/context/types.go | 5 +- .../souin/pkg/middleware/middleware.go | 521 +++++++++++++----- .../darkweak/souin/pkg/middleware/writer.go | 41 +- .../darkweak/souin/pkg/rfc/revalidation.go | 141 ++++- .../souin/pkg/storage/abstractProvider.go | 2 +- .../souin/pkg/storage/cacheProvider.go | 48 +- .../darkweak/souin/pkg/storage/types/types.go | 28 +- 31 files changed, 1431 insertions(+), 539 deletions(-) diff --git a/plugins/traefik/go.mod b/plugins/traefik/go.mod index d9ffeb379..abec048ec 100644 --- a/plugins/traefik/go.mod +++ b/plugins/traefik/go.mod @@ -96,13 +96,13 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/darkweak/go-esi v0.0.5 + github.com/darkweak/go-esi v0.0.5 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/golang/glog v1.2.0 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect - github.com/google/uuid v1.6.0 // indirect + github.com/google/uuid v1.6.0 github.com/imdario/mergo v0.3.13 // indirect github.com/klauspost/compress v1.17.8 // indirect github.com/miekg/dns v1.1.59 // indirect @@ -115,7 +115,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.25.0 // indirect - golang.org/x/sync v0.7.0 // indirect + golang.org/x/sync v0.7.0 golang.org/x/sys v0.20.0 // indirect golang.org/x/text v0.15.0 // indirect golang.org/x/tools v0.21.0 // indirect diff --git a/plugins/traefik/override/configurationtypes/types.go b/plugins/traefik/override/configurationtypes/types.go index 10729fea2..9af88bd8d 100644 --- a/plugins/traefik/override/configurationtypes/types.go +++ b/plugins/traefik/override/configurationtypes/types.go @@ -183,6 +183,10 @@ type URL struct { // CacheProvider config type CacheProvider struct { + // Uuid to identify a unique instance. + Uuid string + // Found to determine if we can use that storage. + Found bool `json:"found" yaml:"found"` // URL to connect to the storage system. URL string `json:"url" yaml:"url"` // Path to the configuration file. @@ -247,7 +251,7 @@ type DefaultCache struct { Timeout Timeout `json:"timeout" yaml:"timeout"` TTL Duration `json:"ttl" yaml:"ttl"` DefaultCacheControl string `json:"default_cache_control" yaml:"default_cache_control"` - MaxBodyBytes uint64 `json:"max_cachable_body_bytes" yaml:"max_cachable_body_bytes"` + MaxBodyBytes uint64 `json:"max_cacheable_body_bytes" yaml:"max_cacheable_body_bytes"` DisableCoalescing bool `json:"disable_coalescing" yaml:"disable_coalescing"` } @@ -326,11 +330,6 @@ func (d *DefaultCache) GetRegex() Regex { return d.Regex } -// GetSimpleFS returns simpleFS configuration -func (d *DefaultCache) GetSimpleFS() CacheProvider { - return d.SimpleFS -} - // GetTimeout returns the backend and cache timeouts func (d *DefaultCache) GetTimeout() Timeout { return d.Timeout @@ -341,6 +340,11 @@ func (d *DefaultCache) GetTTL() time.Duration { return d.TTL.Duration } +// GetSimpleFS returns simplefs configuration +func (d *DefaultCache) GetSimpleFS() CacheProvider { + return d.SimpleFS +} + // GetStale returns the stale duration func (d *DefaultCache) GetStale() time.Duration { return d.Stale.Duration @@ -376,6 +380,7 @@ type DefaultCacheInterface interface { GetEtcd() CacheProvider GetMode() string GetOtter() CacheProvider + GetNats() CacheProvider GetNuts() CacheProvider GetOlric() CacheProvider GetRedis() CacheProvider @@ -389,6 +394,7 @@ type DefaultCacheInterface interface { GetTTL() time.Duration GetDefaultCacheControl() string GetMaxBodyBytes() uint64 + IsCoalescingDisable() bool } // APIEndpoint is the minimal structure to define an endpoint diff --git a/plugins/traefik/override/context/cache.go b/plugins/traefik/override/context/cache.go index fbdd89cdc..c4051f0c4 100644 --- a/plugins/traefik/override/context/cache.go +++ b/plugins/traefik/override/context/cache.go @@ -19,6 +19,10 @@ type cacheContext struct { cacheName string } +func (*cacheContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (cc *cacheContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { cc.cacheName = defaultCacheName if c.GetDefaultCache().GetCacheName() != "" { diff --git a/plugins/traefik/override/context/graphql.go b/plugins/traefik/override/context/graphql.go index 8557de703..dad0e70b8 100644 --- a/plugins/traefik/override/context/graphql.go +++ b/plugins/traefik/override/context/graphql.go @@ -21,6 +21,32 @@ type graphQLContext struct { custom bool } +func (g *graphQLContext) SetContextWithBaseRequest(req *http.Request, baseRq *http.Request) *http.Request { + ctx := req.Context() + ctx = context.WithValue(ctx, GraphQL, g.custom) + ctx = context.WithValue(ctx, HashBody, "") + ctx = context.WithValue(ctx, IsMutationRequest, false) + + if g.custom && req.Body != nil { + b := bytes.NewBuffer([]byte{}) + _, _ = io.Copy(b, req.Body) + req.Body = io.NopCloser(b) + baseRq.Body = io.NopCloser(b) + + if b.Len() > 0 { + if isMutation(b.Bytes()) { + ctx = context.WithValue(ctx, IsMutationRequest, true) + } else { + h := sha256.New() + h.Write(b.Bytes()) + ctx = context.WithValue(ctx, HashBody, fmt.Sprintf("-%x", h.Sum(nil))) + } + } + } + + return req.WithContext(ctx) +} + func (g *graphQLContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { if len(c.GetDefaultCache().GetAllowedHTTPVerbs()) != 0 { g.custom = true diff --git a/plugins/traefik/override/context/key.go b/plugins/traefik/override/context/key.go index 352fd05b6..9a468cf2e 100644 --- a/plugins/traefik/override/context/key.go +++ b/plugins/traefik/override/context/key.go @@ -12,6 +12,7 @@ const ( Key ctxKey = "souin_ctx.CACHE_KEY" DisplayableKey ctxKey = "souin_ctx.DISPLAYABLE_KEY" IgnoredHeaders ctxKey = "souin_ctx.IGNORE_HEADERS" + Hashed ctxKey = "souin_ctx.HASHED" ) type keyContext struct { @@ -20,10 +21,17 @@ type keyContext struct { disable_method bool disable_query bool disable_scheme bool - hash bool displayable bool + hash bool headers []string + template string overrides []map[*regexp.Regexp]keyContext + + initializer func(r *http.Request) *http.Request +} + +func (*keyContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req } func (g *keyContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { @@ -35,89 +43,76 @@ func (g *keyContext) SetupContext(c configurationtypes.AbstractConfigurationInte g.disable_scheme = k.DisableScheme g.hash = k.Hash g.displayable = !k.Hide + g.template = k.Template g.headers = k.Headers g.overrides = make([]map[*regexp.Regexp]keyContext, 0) - for _, cacheKey := range c.GetCacheKeys() { - for r, v := range cacheKey { - g.overrides = append(g.overrides, map[*regexp.Regexp]keyContext{r.Regexp: { - disable_body: v.DisableBody, - disable_host: v.DisableHost, - disable_method: v.DisableMethod, - disable_query: v.DisableQuery, - disable_scheme: v.DisableScheme, - hash: v.Hash, - displayable: !v.Hide, - headers: v.Headers, - }}) - } + // for _, cacheKey := range c.GetCacheKeys() { + // for r, v := range cacheKey { + // g.overrides = append(g.overrides, map[*regexp.Regexp]keyContext{r.Regexp: { + // disable_body: v.DisableBody, + // disable_host: v.DisableHost, + // disable_method: v.DisableMethod, + // disable_query: v.DisableQuery, + // disable_scheme: v.DisableScheme, + // hash: v.Hash, + // displayable: !v.Hide, + // template: v.Template, + // headers: v.Headers, + // }}) + // } + // } + + g.initializer = func(r *http.Request) *http.Request { + return r } } -func (g *keyContext) SetContext(req *http.Request) *http.Request { - key := req.URL.Path - var headers []string +func parseKeyInformations(req *http.Request, kCtx keyContext) (query, body, host, scheme, method, headerValues string, headers []string, displayable, hash bool) { + displayable = kCtx.displayable + hash = kCtx.hash - scheme := "http-" - if req.TLS != nil { - scheme = "https-" - } - query := "" - body := "" - host := "" - method := "" - headerValues := "" - displayable := g.displayable - - if !g.disable_query && len(req.URL.RawQuery) > 0 { + if !kCtx.disable_query && len(req.URL.RawQuery) > 0 { query += "?" + req.URL.RawQuery } - if !g.disable_body { + if !kCtx.disable_body { body = req.Context().Value(HashBody).(string) } - if !g.disable_host { + if !kCtx.disable_host { host = req.Host + "-" } - if !g.disable_method { + if !kCtx.disable_scheme { + scheme = "http-" + if req.TLS != nil { + scheme = "https-" + } + } + + if !kCtx.disable_method { method = req.Method + "-" } - headers = g.headers - for _, hn := range g.headers { + headers = kCtx.headers + for _, hn := range kCtx.headers { headerValues += "-" + req.Header.Get(hn) } + return +} + +func (g *keyContext) computeKey(req *http.Request) (key string, headers []string, hash, displayable bool) { + key = req.URL.Path + query, body, host, scheme, method, headerValues, headers, displayable, hash := parseKeyInformations(req, *g) + hasOverride := false for _, current := range g.overrides { for k, v := range current { if k.MatchString(req.RequestURI) { - displayable = v.displayable - host = "" - method = "" - query = "" - if !v.disable_query && len(req.URL.RawQuery) > 0 { - query = "?" + req.URL.RawQuery - } - if !v.disable_body { - body = req.Context().Value(HashBody).(string) - } - if !v.disable_method { - method = req.Method + "-" - } - if !v.disable_host { - host = req.Host + "-" - } - if len(v.headers) > 0 { - headerValues = "" - for _, hn := range v.headers { - headers = v.headers - headerValues += "-" + req.Header.Get(hn) - } - } + query, body, host, scheme, method, headerValues, headers, displayable, hash = parseKeyInformations(req, v) hasOverride = true break } @@ -128,13 +123,26 @@ func (g *keyContext) SetContext(req *http.Request) *http.Request { } } + key = method + scheme + host + key + query + body + headerValues + + return +} + +func (g *keyContext) SetContext(req *http.Request) *http.Request { + rq := g.initializer(req) + key, headers, hash, displayable := g.computeKey(rq) + return req.WithContext( context.WithValue( context.WithValue( context.WithValue( - req.Context(), - Key, - method+scheme+host+key+query+body+headerValues, + context.WithValue( + req.Context(), + Key, + key, + ), + Hashed, + hash, ), IgnoredHeaders, headers, diff --git a/plugins/traefik/override/context/method.go b/plugins/traefik/override/context/method.go index 1e6417cbb..ee772a3de 100644 --- a/plugins/traefik/override/context/method.go +++ b/plugins/traefik/override/context/method.go @@ -16,6 +16,10 @@ type methodContext struct { custom bool } +func (*methodContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (m *methodContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { m.allowedVerbs = defaultVerbs if len(c.GetDefaultCache().GetAllowedHTTPVerbs()) != 0 { diff --git a/plugins/traefik/override/context/mode.go b/plugins/traefik/override/context/mode.go index b041abb15..ec2d5221d 100644 --- a/plugins/traefik/override/context/mode.go +++ b/plugins/traefik/override/context/mode.go @@ -13,6 +13,10 @@ type ModeContext struct { Strict, Bypass_request, Bypass_response bool } +func (*ModeContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (mc *ModeContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { mode := c.GetDefaultCache().GetMode() mc.Bypass_request = mode == "bypass" || mode == "bypass_request" @@ -24,4 +28,4 @@ func (mc *ModeContext) SetContext(req *http.Request) *http.Request { return req.WithContext(context.WithValue(req.Context(), Mode, mc)) } -var _ ctx = (*cacheContext)(nil) +var _ ctx = (*ModeContext)(nil) diff --git a/plugins/traefik/override/context/now.go b/plugins/traefik/override/context/now.go index 898cc18fe..d0d4e0f3b 100644 --- a/plugins/traefik/override/context/now.go +++ b/plugins/traefik/override/context/now.go @@ -12,6 +12,10 @@ const Now ctxKey = "souin_ctx.NOW" type nowContext struct{} +func (*nowContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (cc *nowContext) SetupContext(_ configurationtypes.AbstractConfigurationInterface) {} func (cc *nowContext) SetContext(req *http.Request) *http.Request { diff --git a/plugins/traefik/override/context/timeout.go b/plugins/traefik/override/context/timeout.go index 6c737d24c..4da27d984 100644 --- a/plugins/traefik/override/context/timeout.go +++ b/plugins/traefik/override/context/timeout.go @@ -22,6 +22,10 @@ type timeoutContext struct { timeoutCache, timeoutBackend time.Duration } +func (*timeoutContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (t *timeoutContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { t.timeoutBackend = defaultTimeoutBackend t.timeoutCache = defaultTimeoutCache @@ -38,4 +42,4 @@ func (t *timeoutContext) SetContext(req *http.Request) *http.Request { return req.WithContext(context.WithValue(context.WithValue(ctx, TimeoutCancel, cancel), TimeoutCache, t.timeoutCache)) } -var _ ctx = (*cacheContext)(nil) +var _ ctx = (*timeoutContext)(nil) diff --git a/plugins/traefik/override/context/types.go b/plugins/traefik/override/context/types.go index 38bf5ed19..34e56363f 100644 --- a/plugins/traefik/override/context/types.go +++ b/plugins/traefik/override/context/types.go @@ -12,6 +12,7 @@ type ( ctx interface { SetupContext(c configurationtypes.AbstractConfigurationInterface) SetContext(req *http.Request) *http.Request + SetContextWithBaseRequest(req *http.Request, baseRq *http.Request) *http.Request } Context struct { @@ -53,6 +54,6 @@ func (c *Context) SetBaseContext(req *http.Request) *http.Request { return c.Mode.SetContext(c.Timeout.SetContext(c.Method.SetContext(c.CacheName.SetContext(c.Now.SetContext(req))))) } -func (c *Context) SetContext(req *http.Request) *http.Request { - return c.Key.SetContext(c.GraphQL.SetContext(req)) +func (c *Context) SetContext(req *http.Request, baseRq *http.Request) *http.Request { + return c.Key.SetContext(c.GraphQL.SetContextWithBaseRequest(req, baseRq)) } diff --git a/plugins/traefik/override/middleware/middleware.go b/plugins/traefik/override/middleware/middleware.go index 34735b88d..53f899690 100644 --- a/plugins/traefik/override/middleware/middleware.go +++ b/plugins/traefik/override/middleware/middleware.go @@ -22,7 +22,9 @@ import ( "github.com/darkweak/souin/pkg/storage/types" "github.com/darkweak/souin/pkg/surrogate" "github.com/darkweak/souin/pkg/surrogate/providers" + "github.com/google/uuid" "github.com/pquerna/cachecontrol/cacheobject" + "golang.org/x/sync/singleflight" ) func NewHTTPCacheHandler(c configurationtypes.AbstractConfigurationInterface) *SouinBaseHandler { @@ -65,6 +67,7 @@ func NewHTTPCacheHandler(c configurationtypes.AbstractConfigurationInterface) *S context: ctx, bufPool: bufPool, storersLen: len(storers), + singleflightPool: singleflight.Group{}, } } @@ -78,14 +81,17 @@ type SouinBaseHandler struct { SurrogateKeyStorer providers.SurrogateInterface DefaultMatchedUrl configurationtypes.URL context *context.Context + singleflightPool singleflight.Group bufPool *sync.Pool storersLen int } -type upsreamError struct{} +var Upstream50xError = upstream50xError{} -func (upsreamError) Error() string { - return "Upstream error" +type upstream50xError struct{} + +func (upstream50xError) Error() string { + return "Upstream 50x error" } func isCacheableCode(code int) bool { @@ -97,6 +103,15 @@ func isCacheableCode(code int) bool { return false } +func canStatusCodeEmptyContent(code int) bool { + switch code { + case 204, 301, 405: + return true + } + + return false +} + func canBypassAuthorizationRestriction(headers http.Header, bypassed []string) bool { for _, header := range bypassed { if strings.ToLower(header) == "authorization" { @@ -113,35 +128,37 @@ func (s *SouinBaseHandler) Store( requestCc *cacheobject.RequestCacheDirectives, cachedKey string, ) error { - if !isCacheableCode(customWriter.statusCode) { - customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=UNCACHEABLE-STATUS-CODE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + statusCode := customWriter.GetStatusCode() + if !isCacheableCode(statusCode) { + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=UNCACHEABLE-STATUS-CODE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) - switch customWriter.statusCode { + switch statusCode { case 500, 502, 503, 504: - return new(upsreamError) + return Upstream50xError } return nil } - if customWriter.Header().Get("Cache-Control") == "" { + headerName, cacheControl := s.SurrogateKeyStorer.GetSurrogateControl(customWriter.Header()) + if cacheControl == "" { // TODO see with @mnot if mandatory to not store the response when no Cache-Control given. // if s.DefaultMatchedUrl.DefaultCacheControl == "" { - // customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=EMPTY-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + // customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=EMPTY-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) // return nil // } - customWriter.Header().Set("Cache-Control", s.DefaultMatchedUrl.DefaultCacheControl) + customWriter.Header().Set(headerName, s.DefaultMatchedUrl.DefaultCacheControl) } - responseCc, _ := cacheobject.ParseResponseCacheControl(customWriter.Header().Get("Cache-Control")) + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(customWriter.Header(), headerName)) if responseCc == nil { - customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=INVALID-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=INVALID-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) return nil } modeContext := rq.Context().Value(context.Mode).(*context.ModeContext) if !modeContext.Bypass_request && (responseCc.PrivatePresent || rq.Header.Get("Authorization") != "") && !canBypassAuthorizationRestriction(customWriter.Header(), rq.Context().Value(context.IgnoredHeaders).([]string)) { - customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=PRIVATE-OR-AUTHENTICATED-RESPONSE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=PRIVATE-OR-AUTHENTICATED-RESPONSE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) return nil } @@ -156,40 +173,55 @@ func (s *SouinBaseHandler) Store( } } + hasFreshness := false ma := currentMatchedURL.TTL.Duration if responseCc.SMaxAge >= 0 { ma = time.Duration(responseCc.SMaxAge) * time.Second } else if responseCc.MaxAge >= 0 { ma = time.Duration(responseCc.MaxAge) * time.Second - } - if ma > currentMatchedURL.TTL.Duration { - ma = currentMatchedURL.TTL.Duration + } else if customWriter.Header().Get("Expires") != "" { + exp, err := time.Parse(time.RFC1123, customWriter.Header().Get("Expires")) + if err != nil { + return nil + } + + duration := time.Until(exp) + if duration <= 0 || duration > 10*types.OneYearDuration { + return nil + } + + date, _ := time.Parse(time.RFC1123, customWriter.Header().Get("Date")) + if date.Sub(exp) > 0 { + return nil + } + + ma = duration + hasFreshness = true } now := rq.Context().Value(context.Now).(time.Time) date, _ := http.ParseTime(now.Format(http.TimeFormat)) - customWriter.Headers.Set(rfc.StoredTTLHeader, ma.String()) + customWriter.Header().Set(rfc.StoredTTLHeader, ma.String()) ma = ma - time.Since(date) - if exp := customWriter.Header().Get("Expires"); exp != "" { - delta, _ := time.Parse(exp, time.RFC1123) - if sub := delta.Sub(now); sub > 0 { - ma = sub - } - } - status := fmt.Sprintf("%s; fwd=uri-miss", rq.Context().Value(context.CacheName)) if (modeContext.Bypass_request || !requestCc.NoStore) && - (modeContext.Bypass_response || !responseCc.NoStore) { - headers := customWriter.Headers.Clone() + (modeContext.Bypass_response || !responseCc.NoStore || hasFreshness) { + headers := customWriter.Header().Clone() for hname, shouldDelete := range responseCc.NoCache { if shouldDelete { headers.Del(hname) } } + + customWriter.mutex.Lock() + b := customWriter.Buf.Bytes() + bLen := customWriter.Buf.Len() + customWriter.mutex.Unlock() + res := http.Response{ - StatusCode: customWriter.statusCode, - Body: io.NopCloser(bytes.NewBuffer(customWriter.Buf.Bytes())), + StatusCode: statusCode, + Body: io.NopCloser(bytes.NewBuffer(b)), Header: headers, } @@ -197,17 +229,23 @@ func (s *SouinBaseHandler) Store( res.Header.Set("Date", now.Format(http.TimeFormat)) } if res.Header.Get("Content-Length") == "" { - res.Header.Set("Content-Length", fmt.Sprint(customWriter.Buf.Len())) + res.Header.Set("Content-Length", fmt.Sprint(bLen)) + } + respBodyMaxSize := int(s.Configuration.GetDefaultCache().GetMaxBodyBytes()) + if respBodyMaxSize > 0 && bLen > respBodyMaxSize { + customWriter.Header().Set("Cache-Status", status+"; detail=UPSTREAM-RESPONSE-TOO-LARGE; key="+rfc.GetCacheKeyFromCtx(rq.Context())) + + return nil } res.Header.Set(rfc.StoredLengthHeader, res.Header.Get("Content-Length")) response, err := httputil.DumpResponse(&res, true) - if err == nil { + if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode)) { variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(res.Header) if isVaryStar { // "Implies that the response is uncacheable" status += "; detail=UPSTREAM-VARY-STAR" } else { - cachedKey += rfc.GetVariedCacheKey(rq, variedHeaders) + variedKey := cachedKey + rfc.GetVariedCacheKey(rq, variedHeaders) var wg sync.WaitGroup mu := sync.Mutex{} @@ -216,11 +254,25 @@ func (s *SouinBaseHandler) Store( case <-rq.Context().Done(): status += "; detail=REQUEST-CANCELED-OR-UPSTREAM-BROKEN-PIPE" default: + vhs := http.Header{} + for _, hname := range variedHeaders { + hn := strings.Split(hname, ":") + vhs.Set(hn[0], rq.Header.Get(hn[0])) + } for _, storer := range s.Storers { wg.Add(1) go func(currentStorer types.Storer) { defer wg.Done() - if currentStorer.Set(cachedKey, response, ma) != nil { + if currentStorer.SetMultiLevel( + cachedKey, + variedKey, + response, + vhs, + res.Header.Get("Etag"), ma, + variedKey, + ) == nil { + res.Request = rq + } else { mu.Lock() fails = append(fails, fmt.Sprintf("; detail=%s-INSERTION-ERROR", currentStorer.Name())) mu.Unlock() @@ -232,7 +284,7 @@ func (s *SouinBaseHandler) Store( if len(fails) < s.storersLen { go func(rs http.Response, key string) { _ = s.SurrogateKeyStorer.Store(&rs, key, "") - }(res, cachedKey) + }(res, variedKey) status += "; stored" } @@ -241,15 +293,25 @@ func (s *SouinBaseHandler) Store( } } } + + } else { + status += "; detail=UPSTREAM-ERROR-OR-EMPTY-RESPONSE" } } else { status += "; detail=NO-STORE-DIRECTIVE" } - customWriter.Headers.Set("Cache-Status", status+"; key="+rfc.GetCacheKeyFromCtx(rq.Context())) + customWriter.Header().Set("Cache-Status", status+"; key="+rfc.GetCacheKeyFromCtx(rq.Context())) return nil } +type singleflightValue struct { + body []byte + headers http.Header + requestHeaders http.Header + code int +} + func (s *SouinBaseHandler) Upstream( customWriter *CustomWriter, rq *http.Request, @@ -257,71 +319,138 @@ func (s *SouinBaseHandler) Upstream( requestCc *cacheobject.RequestCacheDirectives, cachedKey string, ) error { - if err := next(customWriter, rq); err != nil { - customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=SERVE-HTTP-ERROR", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) - return err + var recoveredFromErr error = nil + defer func() { + // In case of "http.ErrAbortHandler" panic, + // prevent singleflight from wrapping it into "singleflight.panicError". + if r := recover(); r != nil { + err, ok := r.(error) + // Sometimes, the error is a string. + if !ok || errors.Is(err, http.ErrAbortHandler) { + recoveredFromErr = http.ErrAbortHandler + } else { + panic(err) + } + } + }() + + singleflightCacheKey := cachedKey + if s.Configuration.GetDefaultCache().IsCoalescingDisable() { + singleflightCacheKey += uuid.NewString() } + sfValue, err, _ := s.singleflightPool.Do(singleflightCacheKey, func() (interface{}, error) { + if e := next(customWriter, rq); e != nil { + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=SERVE-HTTP-ERROR", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + return nil, e + } - s.SurrogateKeyStorer.Invalidate(rq.Method, customWriter.Header()) - if !isCacheableCode(customWriter.statusCode) { - customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=UNCACHEABLE-STATUS-CODE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + s.SurrogateKeyStorer.Invalidate(rq.Method, customWriter.Header()) - switch customWriter.statusCode { - case 500, 502, 503, 504: - return new(upsreamError) + statusCode := customWriter.GetStatusCode() + if !isCacheableCode(statusCode) { + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=UNCACHEABLE-STATUS-CODE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + + switch statusCode { + case 500, 502, 503, 504: + return nil, Upstream50xError + } } - return nil - } + headerName, cacheControl := s.SurrogateKeyStorer.GetSurrogateControl(customWriter.Header()) + if cacheControl == "" { + customWriter.Header().Set(headerName, s.DefaultMatchedUrl.DefaultCacheControl) + } - if customWriter.Header().Get("Cache-Control") == "" { - // TODO see with @mnot if mandatory to not store the response when no Cache-Control given. - // if s.DefaultMatchedUrl.DefaultCacheControl == "" { - // customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=EMPTY-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) - // return nil - // } - customWriter.Header().Set("Cache-Control", s.DefaultMatchedUrl.DefaultCacheControl) + err := s.Store(customWriter, rq, requestCc, cachedKey) + defer customWriter.Buf.Reset() + + return singleflightValue{ + body: customWriter.Buf.Bytes(), + headers: customWriter.Header().Clone(), + requestHeaders: rq.Header, + code: statusCode, + }, err + }) + if recoveredFromErr != nil { + panic(recoveredFromErr) + } + if err != nil { + return err } - select { - case <-rq.Context().Done(): - return baseCtx.Canceled - default: - return s.Store(customWriter, rq, requestCc, cachedKey) + if sfWriter, ok := sfValue.(singleflightValue); ok { + if vary := sfWriter.headers.Get("Vary"); vary != "" { + variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(sfWriter.headers) + if !isVaryStar { + for _, vh := range variedHeaders { + if rq.Header.Get(vh) != sfWriter.requestHeaders.Get(vh) { + // cachedKey += rfc.GetVariedCacheKey(rq, variedHeaders) + return s.Upstream(customWriter, rq, next, requestCc, cachedKey) + } + } + } + } + _, _ = customWriter.Write(sfWriter.body) + // Yaegi sucks, we can't use maps. + for k := range sfWriter.headers { + customWriter.Header().Set(k, sfWriter.headers.Get(k)) + } + customWriter.WriteHeader(sfWriter.code) } + + return nil } -func (s *SouinBaseHandler) Revalidate(validator *rfc.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string) error { - err := next(customWriter, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, customWriter.Header()) +func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string, uri string) error { + singleflightCacheKey := cachedKey + if s.Configuration.GetDefaultCache().IsCoalescingDisable() { + singleflightCacheKey += uuid.NewString() + } + sfValue, err, _ := s.singleflightPool.Do(singleflightCacheKey, func() (interface{}, error) { + err := next(customWriter, rq) + s.SurrogateKeyStorer.Invalidate(rq.Method, customWriter.Header()) - if err == nil { - if validator.IfUnmodifiedSincePresent && customWriter.statusCode != http.StatusNotModified { - customWriter.Buf.Reset() - for h, v := range customWriter.Headers { - if len(v) > 0 { - customWriter.Rw.Header().Set(h, strings.Join(v, ", ")) - } + statusCode := customWriter.GetStatusCode() + if err == nil { + if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified { + customWriter.Buf.Reset() + customWriter.Rw.WriteHeader(http.StatusPreconditionFailed) + + return nil, errors.New("") } - customWriter.Rw.WriteHeader(http.StatusPreconditionFailed) - return errors.New("") + if statusCode != http.StatusNotModified { + err = s.Store(customWriter, rq, requestCc, cachedKey) + } } - if customWriter.statusCode != http.StatusNotModified { - err = s.Store(customWriter, rq, requestCc, cachedKey) + customWriter.Header().Set( + "Cache-Status", + fmt.Sprintf( + "%s; fwd=request; fwd-status=%d; key=%s; detail=REQUEST-REVALIDATION", + rq.Context().Value(context.CacheName), + statusCode, + rfc.GetCacheKeyFromCtx(rq.Context()), + ), + ) + + defer customWriter.Buf.Reset() + return singleflightValue{ + body: customWriter.Buf.Bytes(), + headers: customWriter.Header().Clone(), + code: statusCode, + }, err + }) + + if sfWriter, ok := sfValue.(singleflightValue); ok { + _, _ = customWriter.Write(sfWriter.body) + // Yaegi sucks, we can't use maps. + for k := range sfWriter.headers { + customWriter.Header().Set(k, sfWriter.headers.Get(k)) } + customWriter.WriteHeader(sfWriter.code) } - customWriter.Header().Set( - "Cache-Status", - fmt.Sprintf( - "%s; fwd=request; fwd-status=%d; key=%s; detail=REQUEST-REVALIDATION", - rq.Context().Value(context.CacheName), - customWriter.statusCode, - rfc.GetCacheKeyFromCtx(rq.Context()), - ), - ) return err } @@ -339,89 +468,128 @@ func (s *SouinBaseHandler) HandleInternally(r *http.Request) (bool, http.Handler } type handlerFunc = func(http.ResponseWriter, *http.Request) error +type statusCodeLogger struct { + http.ResponseWriter + statusCode int +} + +func (s *statusCodeLogger) WriteHeader(code int) { + s.statusCode = code + s.ResponseWriter.WriteHeader(code) +} func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, next handlerFunc) error { - b, handler := s.HandleInternally(rq) - if b { + if b, handler := s.HandleInternally(rq); b { handler(rw, rq) return nil } + req := s.context.SetBaseContext(rq) + cacheName := req.Context().Value(context.CacheName).(string) - rq = s.context.SetBaseContext(rq) - cacheName := rq.Context().Value(context.CacheName).(string) - if rq.Header.Get("Upgrade") == "websocket" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) { + if rq.Header.Get("Upgrade") == "websocket" || rq.Header.Get("Accept") == "text/event-stream" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=EXCLUDED-REQUEST-URI") - return next(rw, rq) + return next(rw, req) } - if !rq.Context().Value(context.SupportedMethod).(bool) { + if !req.Context().Value(context.SupportedMethod).(bool) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=UNSUPPORTED-METHOD") + nrw := &statusCodeLogger{ + ResponseWriter: rw, + statusCode: 0, + } - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(nrw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) + + if err == nil && req.Method != http.MethodGet && nrw.statusCode < http.StatusBadRequest { + // Invalidate related GET keys when the method is not allowed and the response is valid + req.Method = http.MethodGet + keyname := s.context.SetContext(req, rq).Context().Value(context.Key).(string) + for _, storer := range s.Storers { + storer.Delete("IDX_" + keyname) + } + } return err } - requestCc, coErr := cacheobject.ParseRequestCacheControl(rq.Header.Get("Cache-Control")) + requestCc, coErr := cacheobject.ParseRequestCacheControl(rfc.HeaderAllCommaSepValuesString(req.Header, "Cache-Control")) + + modeContext := req.Context().Value(context.Mode).(*context.ModeContext) - modeContext := rq.Context().Value(context.Mode).(*context.ModeContext) if !modeContext.Bypass_request && (coErr != nil || requestCc == nil) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=CACHE-CONTROL-EXTRACTION-ERROR") - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(rw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) return err } - rq = s.context.SetContext(rq) + req = s.context.SetContext(req, rq) + + isMutationRequest := false + // Yaegi sucks AGAIN, it considers the value as nil if we directly try to cast as bool + mutationRequestValue := req.Context().Value(context.IsMutationRequest) + if mutationRequestValue != nil { + isMutationRequest = mutationRequestValue.(bool) + } - // Yaegi sucks again, it considers false as true - isMutationRequest := rq.Context().Value(context.IsMutationRequest).(bool) if isMutationRequest { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=IS-MUTATION-REQUEST") - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(rw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) return err } - cachedKey := rq.Context().Value(context.Key).(string) + cachedKey := req.Context().Value(context.Key).(string) + + // Need to copy URL path before calling next because it can alter the URI + uri := req.URL.Path bufPool := s.bufPool.Get().(*bytes.Buffer) bufPool.Reset() defer s.bufPool.Put(bufPool) - customWriter := NewCustomWriter(rq, rw, bufPool) + customWriter := NewCustomWriter(req, rw, bufPool) + go func(req *http.Request, crw *CustomWriter) { <-req.Context().Done() crw.mutex.Lock() crw.headersSent = true crw.mutex.Unlock() - }(rq, customWriter) + }(req, customWriter) + if modeContext.Bypass_request || !requestCc.NoCache { - validator := rfc.ParseRequest(rq) - var response *http.Response + validator := rfc.ParseRequest(req) + var fresh, stale *http.Response + var storerName string for _, currentStorer := range s.Storers { - response = currentStorer.Prefix(cachedKey, rq, validator) - if response != nil { + fresh, stale = currentStorer.GetMultiLevel(cachedKey, req, validator) + + if fresh != nil || stale != nil { + storerName = currentStorer.Name() break } } - if response != nil && (!modeContext.Strict || rfc.ValidateCacheControl(response, requestCc)) { + headerName, _ := s.SurrogateKeyStorer.GetSurrogateControl(customWriter.Header()) + if fresh != nil && (!modeContext.Strict || rfc.ValidateCacheControl(fresh, requestCc)) { + response := fresh if validator.ResponseETag != "" && validator.Matched { - rfc.SetCacheStatusHeader(response, "DEFAULT") - customWriter.Headers = response.Header + rfc.SetCacheStatusHeader(response, storerName) + for h, v := range response.Header { + customWriter.Header()[h] = v + } if validator.NotModified { - customWriter.statusCode = http.StatusNotModified + customWriter.WriteHeader(http.StatusNotModified) customWriter.Buf.Reset() _, _ = customWriter.Send() return nil } - customWriter.statusCode = response.StatusCode + customWriter.WriteHeader(response.StatusCode) _, _ = io.Copy(customWriter.Buf, response.Body) _, _ = customWriter.Send() @@ -429,48 +597,48 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if validator.NeedRevalidation { - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) _, _ = customWriter.Send() return err } - if resCc, _ := cacheobject.ParseResponseCacheControl(response.Header.Get("Cache-Control")); resCc.NoCachePresent { - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + if resCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, headerName)); resCc.NoCachePresent { + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) _, _ = customWriter.Send() return err } - rfc.SetCacheStatusHeader(response, "DEFAULT") + rfc.SetCacheStatusHeader(response, storerName) if !modeContext.Strict || rfc.ValidateMaxAgeCachedResponse(requestCc, response) != nil { - customWriter.Headers = response.Header - customWriter.statusCode = response.StatusCode + for h, v := range response.Header { + customWriter.Header()[h] = v + } + customWriter.WriteHeader(response.StatusCode) _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() return err } - } else if response == nil && !requestCc.OnlyIfCached && (requestCc.MaxStaleSet || requestCc.MaxStale > -1) { - for _, currentStorer := range s.Storers { - response = currentStorer.Prefix(storage.StalePrefix+cachedKey, rq, validator) - if response != nil { - break - } - } + } else if !requestCc.OnlyIfCached && (requestCc.MaxStaleSet || requestCc.MaxStale > -1) { + response := stale + if nil != response && (!modeContext.Strict || rfc.ValidateCacheControl(response, requestCc)) { addTime, _ := time.ParseDuration(response.Header.Get(rfc.StoredTTLHeader)) - rfc.SetCacheStatusHeader(response, "DEFAULT") + rfc.SetCacheStatusHeader(response, storerName) - responseCc, _ := cacheobject.ParseResponseCacheControl(response.Header.Get("Cache-Control")) + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, "Cache-Control")) if responseCc.StaleWhileRevalidate > 0 { - customWriter.Headers = response.Header - customWriter.statusCode = response.StatusCode + for h, v := range response.Header { + customWriter.Header()[h] = v + } + customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() - customWriter = NewCustomWriter(rq, rw, bufPool) - go func(v *rfc.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string) { - _ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk) - }(validator, customWriter, rq, next, requestCc, cachedKey) + customWriter = NewCustomWriter(req, rw, bufPool) + go func(v *types.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) { + _ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk, goUri) + }(validator, customWriter, req, next, requestCc, cachedKey, uri) buf := s.bufPool.Get().(*bytes.Buffer) buf.Reset() defer s.bufPool.Put(buf) @@ -479,15 +647,20 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation { - rq.Header["If-None-Match"] = append(rq.Header["If-None-Match"], validator.ResponseETag) - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag) + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) + statusCode := customWriter.GetStatusCode() if err != nil { if responseCc.StaleIfError > -1 || requestCc.StaleIfError > 0 { - code := fmt.Sprintf("; fwd-status=%d", customWriter.statusCode) - customWriter.Headers = response.Header - customWriter.statusCode = response.StatusCode + code := fmt.Sprintf("; fwd-status=%d", statusCode) rfc.HitStaleCache(&response.Header) response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code) + // Yaegi sucks, we can't use maps. + for k := range response.Header { + customWriter.Header().Set(k, response.Header.Get(k)) + } + customWriter.WriteHeader(response.StatusCode) + customWriter.Buf.Reset() _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() @@ -500,11 +673,14 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n return err } - if customWriter.statusCode == http.StatusNotModified { + if statusCode == http.StatusNotModified { if !validator.Matched { - rfc.SetCacheStatusHeader(response, "DEFAULT") - customWriter.statusCode = response.StatusCode - customWriter.Headers = response.Header + rfc.SetCacheStatusHeader(response, storerName) + customWriter.WriteHeader(response.StatusCode) + // Yaegi sucks, we can't use maps. + for k := range response.Header { + customWriter.Header().Set(k, response.Header.Get(k)) + } _, _ = io.Copy(customWriter.Buf, response.Body) _, _ = customWriter.Send() @@ -512,8 +688,8 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } } - if customWriter.statusCode != http.StatusNotModified && validator.Matched { - customWriter.statusCode = http.StatusNotModified + if statusCode != http.StatusNotModified && validator.Matched { + customWriter.WriteHeader(http.StatusNotModified) customWriter.Buf.Reset() _, _ = customWriter.Send() @@ -525,27 +701,64 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n return err } - if !modeContext.Strict || rfc.ValidateMaxAgeCachedStaleResponse(requestCc, response, int(addTime.Seconds())) != nil { - customWriter.Headers = response.Header - customWriter.statusCode = response.StatusCode + if !modeContext.Strict || rfc.ValidateMaxAgeCachedStaleResponse(requestCc, responseCc, response, int(addTime.Seconds())) != nil { + customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) + // Yaegi sucks, we can't use maps. + for k := range response.Header { + customWriter.Header().Set(k, response.Header.Get(k)) + } _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() return err } } + } else if stale != nil { + response := stale + addTime, _ := time.ParseDuration(response.Header.Get(rfc.StoredTTLHeader)) + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, "Cache-Control")) + + if !modeContext.Strict || rfc.ValidateMaxAgeCachedStaleResponse(requestCc, responseCc, response, int(addTime.Seconds())) != nil { + _, _ = time.ParseDuration(response.Header.Get(rfc.StoredTTLHeader)) + rfc.SetCacheStatusHeader(response, storerName) + + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, "Cache-Control")) + + if responseCc.StaleIfError > -1 || requestCc.StaleIfError > 0 { + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) + statusCode := customWriter.GetStatusCode() + if err != nil { + code := fmt.Sprintf("; fwd-status=%d", statusCode) + rfc.HitStaleCache(&response.Header) + response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code) + // Yaegi sucks, we can't use maps. + for k := range response.Header { + customWriter.Header().Set(k, response.Header.Get(k)) + } + customWriter.WriteHeader(response.StatusCode) + customWriter.Buf.Reset() + _, _ = io.Copy(customWriter.Buf, response.Body) + _, err := customWriter.Send() + + return err + } + } + + } } } errorCacheCh := make(chan error) - go func() { - errorCacheCh <- s.Upstream(customWriter, rq, next, requestCc, cachedKey) - }() + + go func(vr *http.Request, cw *CustomWriter) { + errorCacheCh <- s.Upstream(cw, vr, next, requestCc, cachedKey) + }(req, customWriter) select { - case <-rq.Context().Done(): - switch rq.Context().Err() { + case <-req.Context().Done(): + + switch req.Context().Err() { case baseCtx.DeadlineExceeded: customWriter.WriteHeader(http.StatusGatewayTimeout) rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=DEADLINE-EXCEEDED") @@ -556,9 +769,15 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n default: return nil } + case v := <-errorCacheCh: - if v == nil { + + switch v { + case nil: + _, _ = customWriter.Send() + case Upstream50xError: _, _ = customWriter.Send() + return nil } return v } diff --git a/plugins/traefik/override/middleware/writer.go b/plugins/traefik/override/middleware/writer.go index ddbccfdbd..97b479bd7 100644 --- a/plugins/traefik/override/middleware/writer.go +++ b/plugins/traefik/override/middleware/writer.go @@ -3,10 +3,9 @@ package middleware import ( "bytes" "net/http" - "strings" + "strconv" "sync" - "github.com/darkweak/go-esi/esi" "github.com/darkweak/souin/pkg/rfc" ) @@ -37,19 +36,28 @@ type CustomWriter struct { headersSent bool mutex *sync.Mutex statusCode int - // size int } // Header will write the response headers func (r *CustomWriter) Header() http.Header { r.mutex.Lock() defer r.mutex.Unlock() - if r.headersSent { + + if r.headersSent || r.Req.Context().Err() != nil { return http.Header{} } + return r.Rw.Header() } +// GetStatusCode returns the response status code +func (r *CustomWriter) GetStatusCode() int { + r.mutex.Lock() + defer r.mutex.Unlock() + + return r.statusCode +} + // WriteHeader will write the response headers func (r *CustomWriter) WriteHeader(code int) { r.mutex.Lock() @@ -57,14 +65,13 @@ func (r *CustomWriter) WriteHeader(code int) { if r.headersSent { return } - r.Headers = r.Rw.Header() r.statusCode = code - // r.headersSent = true - // r.Rw.WriteHeader(code) } // Write will write the response body func (r *CustomWriter) Write(b []byte) (int, error) { + r.mutex.Lock() + defer r.mutex.Unlock() r.Buf.Grow(len(b)) _, _ = r.Buf.Write(b) @@ -73,24 +80,20 @@ func (r *CustomWriter) Write(b []byte) (int, error) { // Send delays the response to handle Cache-Status func (r *CustomWriter) Send() (int, error) { - contentLength := r.Headers.Get(rfc.StoredLengthHeader) - if contentLength != "" { - r.Header().Set("Content-Length", contentLength) - } defer r.Buf.Reset() - b := esi.Parse(r.Buf.Bytes(), r.Req) - for h, v := range r.Headers { - if len(v) > 0 { - r.Rw.Header().Set(h, strings.Join(v, ", ")) - } + storedLength := r.Header().Get(rfc.StoredLengthHeader) + if storedLength != "" { + r.Header().Set("Content-Length", storedLength) + } + b := r.Buf.Bytes() + if len(b) != 0 { + r.Header().Set("Content-Length", strconv.Itoa(len(b))) } r.Header().Del(rfc.StoredLengthHeader) r.Header().Del(rfc.StoredTTLHeader) if !r.headersSent { - - // r.Rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(b))) - r.Rw.WriteHeader(r.statusCode) + r.Rw.WriteHeader(r.GetStatusCode()) r.headersSent = true } diff --git a/plugins/traefik/override/rfc/revalidation.go b/plugins/traefik/override/rfc/revalidation.go index c0eb363b3..ccb6e73be 100644 --- a/plugins/traefik/override/rfc/revalidation.go +++ b/plugins/traefik/override/rfc/revalidation.go @@ -1,29 +1,17 @@ package rfc import ( + "bufio" + "bytes" + "encoding/json" "net/http" "strings" "time" -) -type Revalidator struct { - Matched bool - IfNoneMatchPresent bool - IfMatchPresent bool - IfModifiedSincePresent bool - IfUnmodifiedSincePresent bool - IfUnmotModifiedSincePresent bool - NeedRevalidation bool - NotModified bool - IfModifiedSince time.Time - IfUnmodifiedSince time.Time - IfNoneMatch []string - IfMatch []string - RequestETags []string - ResponseETag string -} + "github.com/darkweak/souin/pkg/storage/types" +) -func ValidateETagFromHeader(etag string, validator *Revalidator) { +func ValidateETagFromHeader(etag string, validator *types.Revalidator) { validator.ResponseETag = etag validator.NeedRevalidation = validator.NeedRevalidation || validator.ResponseETag != "" validator.Matched = validator.ResponseETag == "" || (validator.ResponseETag != "" && len(validator.RequestETags) == 0) @@ -72,7 +60,7 @@ func ValidateETagFromHeader(etag string, validator *Revalidator) { } } -func ParseRequest(req *http.Request) *Revalidator { +func ParseRequest(req *http.Request) *types.Revalidator { var rqEtags []string if len(req.Header.Get("If-None-Match")) > 0 { rqEtags = strings.Split(req.Header.Get("If-None-Match"), ",") @@ -80,7 +68,7 @@ func ParseRequest(req *http.Request) *Revalidator { for i, tag := range rqEtags { rqEtags[i] = strings.Trim(tag, " ") } - validator := Revalidator{ + validator := types.Revalidator{ NotModified: len(rqEtags) > 0, RequestETags: rqEtags, } @@ -106,3 +94,116 @@ func ParseRequest(req *http.Request) *Revalidator { return &validator } + +func DecodeMapping(item []byte) (*StorageMapper, error) { + mapping := &StorageMapper{} + e := json.Unmarshal(item, mapping) + + return mapping, e +} + +func MappingElection(provider types.Storer, item []byte, req *http.Request, validator *types.Revalidator) (resultFresh *http.Response, resultStale *http.Response, e error) { + mapping := &StorageMapper{} + + if len(item) != 0 { + mapping, e = DecodeMapping(item) + if e != nil { + return resultFresh, resultStale, e + } + } + + for keyName, keyItem := range mapping.Mapping { + valid := true + + for hname, hval := range keyItem.VariedHeaders { + if req.Header.Get(hname) != strings.Join(hval, ", ") { + valid = false + + break + } + } + + if !valid { + continue + } + + ValidateETagFromHeader(keyItem.Etag, validator) + + if validator.Matched { + // If the key is fresh enough. + if time.Since(keyItem.FreshTime) < 0 { + response := provider.Get(keyName) + if response != nil { + if resultFresh, e = http.ReadResponse(bufio.NewReader(bytes.NewBuffer(response)), req); e != nil { + return resultFresh, resultStale, e + } + + return resultFresh, resultStale, e + } + } + + // If the key is still stale. + if time.Since(keyItem.StaleTime) < 0 { + response := provider.Get(keyName) + if response != nil { + if resultStale, e = http.ReadResponse(bufio.NewReader(bytes.NewBuffer(response)), req); e != nil { + return resultFresh, resultStale, e + } + } + } + } + } + + return resultFresh, resultStale, e +} + +type KeyIndex struct { + StoredAt time.Time `json:"stored_at,omitempty"` + FreshTime time.Time `json:"fresh_time,omitempty"` + StaleTime time.Time `json:"stale_time,omitempty"` + VariedHeaders map[string][]string `json:"varied_headers,omitempty"` + Etag string `json:"etag,omitempty"` + RealKey string `json:"real_key,omitempty"` +} +type StorageMapper struct { + Mapping map[string]*KeyIndex `json:"mapping,omitempty"` +} + +func MappingUpdater(key string, item []byte, now, freshTime, staleTime time.Time, variedHeaders http.Header, etag, realKey string) (val []byte, e error) { + mapping := &StorageMapper{} + if len(item) != 0 { + e = json.Unmarshal(item, mapping) + if e != nil { + return nil, e + } + } + + if mapping.Mapping == nil { + mapping.Mapping = make(map[string]*KeyIndex) + } + + var pbvariedeheader map[string][]string + if variedHeaders != nil { + pbvariedeheader = make(map[string][]string) + } + + for k, v := range variedHeaders { + pbvariedeheader[k] = append(pbvariedeheader[k], v...) + } + + mapping.Mapping[key] = &KeyIndex{ + StoredAt: now, + FreshTime: freshTime, + StaleTime: staleTime, + VariedHeaders: pbvariedeheader, + Etag: etag, + RealKey: realKey, + } + + val, e = json.Marshal(mapping) + if e != nil { + return nil, e + } + + return val, e +} diff --git a/plugins/traefik/override/storage/abstractProvider.go b/plugins/traefik/override/storage/abstractProvider.go index 65b1c81f3..577d5fade 100644 --- a/plugins/traefik/override/storage/abstractProvider.go +++ b/plugins/traefik/override/storage/abstractProvider.go @@ -20,7 +20,7 @@ const ( type StorerInstanciator func(configurationtypes.AbstractConfigurationInterface) (types.Storer, error) func NewStorages(configuration configurationtypes.AbstractConfigurationInterface) ([]types.Storer, error) { - s, err := CacheConnectionFactory(configuration) + s, err := Factory(configuration) return []types.Storer{s}, err } diff --git a/plugins/traefik/override/storage/cacheProvider.go b/plugins/traefik/override/storage/cacheProvider.go index 8331486c6..dff56aa09 100644 --- a/plugins/traefik/override/storage/cacheProvider.go +++ b/plugins/traefik/override/storage/cacheProvider.go @@ -22,8 +22,8 @@ type Cache struct { var sharedCache *Cache -// CacheConnectionFactory function create new Cache instance -func CacheConnectionFactory(c t.AbstractConfigurationInterface) (types.Storer, error) { +// Factory function create new Cache instance +func Factory(c t.AbstractConfigurationInterface) (types.Storer, error) { provider := cache.New(1 * time.Second) if sharedCache == nil { @@ -38,6 +38,11 @@ func (provider *Cache) Name() string { return "CACHE" } +// Uuid returns an unique identifier +func (provider *Cache) Uuid() string { + return "" +} + // ListKeys method returns the list of existing keys func (provider *Cache) ListKeys() []string { var keys []string @@ -74,8 +79,44 @@ func (provider *Cache) Get(key string) []byte { return result.([]byte) } +// GetMultiLevel tries to load the key and check if one of linked keys is a fresh/stale candidate. +func (provider *Cache) GetMultiLevel(key string, req *http.Request, validator *types.Revalidator) (fresh *http.Response, stale *http.Response) { + result, found := provider.Cache.Get("IDX_" + key) + if !found { + return + } + + fresh, stale, _ = rfc.MappingElection(provider, result.([]byte), req, validator) + + return +} + +// SetMultiLevel tries to store the key with the given value and update the mapping key to store metadata. +func (provider *Cache) SetMultiLevel(baseKey, variedKey string, value []byte, variedHeaders http.Header, etag string, duration time.Duration, realKey string) error { + now := time.Now() + + var e error + + provider.Cache.Set(variedKey, value, duration) + + mappingKey := "IDX_" + baseKey + item, ok := provider.Cache.Get(mappingKey) + var val []byte + if ok { + val = item.([]byte) + } + + val, e = rfc.MappingUpdater(variedKey, val, now, now.Add(duration), now.Add(duration+provider.stale), variedHeaders, etag, realKey) + if e != nil { + return e + } + + provider.Cache.Set(mappingKey, val, 0) + return nil +} + // Prefix method returns the populated response if exists, empty response then -func (provider *Cache) Prefix(key string, req *http.Request, validator *rfc.Revalidator) *http.Response { +func (provider *Cache) Prefix(key string, req *http.Request, validator *types.Revalidator) *http.Response { var result *http.Response provider.Cache.Range(func(k, v interface{}) bool { @@ -103,7 +144,6 @@ func (provider *Cache) Prefix(key string, req *http.Request, validator *rfc.Reva // Set method will store the response in Cache provider func (provider *Cache) Set(key string, value []byte, duration time.Duration) error { provider.Cache.Set(key, value, duration) - provider.Cache.Set(StalePrefix+key, value, provider.stale+duration) return nil } diff --git a/plugins/traefik/override/storage/types/types.go b/plugins/traefik/override/storage/types/types.go index 67031c15e..e41007745 100644 --- a/plugins/traefik/override/storage/types/types.go +++ b/plugins/traefik/override/storage/types/types.go @@ -3,19 +3,41 @@ package types import ( "net/http" "time" - - "github.com/darkweak/souin/pkg/rfc" ) +type Revalidator struct { + Matched bool + IfNoneMatchPresent bool + IfMatchPresent bool + IfModifiedSincePresent bool + IfUnmodifiedSincePresent bool + IfUnmotModifiedSincePresent bool + NeedRevalidation bool + NotModified bool + IfModifiedSince time.Time + IfUnmodifiedSince time.Time + IfNoneMatch []string + IfMatch []string + RequestETags []string + ResponseETag string +} + +const DefaultStorageName = "CACHE" +const OneYearDuration = 365 * 24 * time.Hour + type Storer interface { MapKeys(prefix string) map[string]string ListKeys() []string - Prefix(key string, req *http.Request, validator *rfc.Revalidator) *http.Response Get(key string) []byte Set(key string, value []byte, duration time.Duration) error Delete(key string) DeleteMany(key string) Init() error Name() string + Uuid() string Reset() error + + // Multi level storer to handle fresh/stale at once + GetMultiLevel(key string, req *http.Request, validator *Revalidator) (fresh *http.Response, stale *http.Response) + SetMultiLevel(baseKey, variedKey string, value []byte, variedHeaders http.Header, etag string, duration time.Duration, realKey string) error } diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/configurationtypes/types.go b/plugins/traefik/vendor/github.com/darkweak/souin/configurationtypes/types.go index 10729fea2..9af88bd8d 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/configurationtypes/types.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/configurationtypes/types.go @@ -183,6 +183,10 @@ type URL struct { // CacheProvider config type CacheProvider struct { + // Uuid to identify a unique instance. + Uuid string + // Found to determine if we can use that storage. + Found bool `json:"found" yaml:"found"` // URL to connect to the storage system. URL string `json:"url" yaml:"url"` // Path to the configuration file. @@ -247,7 +251,7 @@ type DefaultCache struct { Timeout Timeout `json:"timeout" yaml:"timeout"` TTL Duration `json:"ttl" yaml:"ttl"` DefaultCacheControl string `json:"default_cache_control" yaml:"default_cache_control"` - MaxBodyBytes uint64 `json:"max_cachable_body_bytes" yaml:"max_cachable_body_bytes"` + MaxBodyBytes uint64 `json:"max_cacheable_body_bytes" yaml:"max_cacheable_body_bytes"` DisableCoalescing bool `json:"disable_coalescing" yaml:"disable_coalescing"` } @@ -326,11 +330,6 @@ func (d *DefaultCache) GetRegex() Regex { return d.Regex } -// GetSimpleFS returns simpleFS configuration -func (d *DefaultCache) GetSimpleFS() CacheProvider { - return d.SimpleFS -} - // GetTimeout returns the backend and cache timeouts func (d *DefaultCache) GetTimeout() Timeout { return d.Timeout @@ -341,6 +340,11 @@ func (d *DefaultCache) GetTTL() time.Duration { return d.TTL.Duration } +// GetSimpleFS returns simplefs configuration +func (d *DefaultCache) GetSimpleFS() CacheProvider { + return d.SimpleFS +} + // GetStale returns the stale duration func (d *DefaultCache) GetStale() time.Duration { return d.Stale.Duration @@ -376,6 +380,7 @@ type DefaultCacheInterface interface { GetEtcd() CacheProvider GetMode() string GetOtter() CacheProvider + GetNats() CacheProvider GetNuts() CacheProvider GetOlric() CacheProvider GetRedis() CacheProvider @@ -389,6 +394,7 @@ type DefaultCacheInterface interface { GetTTL() time.Duration GetDefaultCacheControl() string GetMaxBodyBytes() uint64 + IsCoalescingDisable() bool } // APIEndpoint is the minimal structure to define an endpoint diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/context/cache.go b/plugins/traefik/vendor/github.com/darkweak/souin/context/cache.go index fbdd89cdc..c4051f0c4 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/context/cache.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/context/cache.go @@ -19,6 +19,10 @@ type cacheContext struct { cacheName string } +func (*cacheContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (cc *cacheContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { cc.cacheName = defaultCacheName if c.GetDefaultCache().GetCacheName() != "" { diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/context/graphql.go b/plugins/traefik/vendor/github.com/darkweak/souin/context/graphql.go index 8557de703..dad0e70b8 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/context/graphql.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/context/graphql.go @@ -21,6 +21,32 @@ type graphQLContext struct { custom bool } +func (g *graphQLContext) SetContextWithBaseRequest(req *http.Request, baseRq *http.Request) *http.Request { + ctx := req.Context() + ctx = context.WithValue(ctx, GraphQL, g.custom) + ctx = context.WithValue(ctx, HashBody, "") + ctx = context.WithValue(ctx, IsMutationRequest, false) + + if g.custom && req.Body != nil { + b := bytes.NewBuffer([]byte{}) + _, _ = io.Copy(b, req.Body) + req.Body = io.NopCloser(b) + baseRq.Body = io.NopCloser(b) + + if b.Len() > 0 { + if isMutation(b.Bytes()) { + ctx = context.WithValue(ctx, IsMutationRequest, true) + } else { + h := sha256.New() + h.Write(b.Bytes()) + ctx = context.WithValue(ctx, HashBody, fmt.Sprintf("-%x", h.Sum(nil))) + } + } + } + + return req.WithContext(ctx) +} + func (g *graphQLContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { if len(c.GetDefaultCache().GetAllowedHTTPVerbs()) != 0 { g.custom = true diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/context/key.go b/plugins/traefik/vendor/github.com/darkweak/souin/context/key.go index 352fd05b6..9a468cf2e 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/context/key.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/context/key.go @@ -12,6 +12,7 @@ const ( Key ctxKey = "souin_ctx.CACHE_KEY" DisplayableKey ctxKey = "souin_ctx.DISPLAYABLE_KEY" IgnoredHeaders ctxKey = "souin_ctx.IGNORE_HEADERS" + Hashed ctxKey = "souin_ctx.HASHED" ) type keyContext struct { @@ -20,10 +21,17 @@ type keyContext struct { disable_method bool disable_query bool disable_scheme bool - hash bool displayable bool + hash bool headers []string + template string overrides []map[*regexp.Regexp]keyContext + + initializer func(r *http.Request) *http.Request +} + +func (*keyContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req } func (g *keyContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { @@ -35,89 +43,76 @@ func (g *keyContext) SetupContext(c configurationtypes.AbstractConfigurationInte g.disable_scheme = k.DisableScheme g.hash = k.Hash g.displayable = !k.Hide + g.template = k.Template g.headers = k.Headers g.overrides = make([]map[*regexp.Regexp]keyContext, 0) - for _, cacheKey := range c.GetCacheKeys() { - for r, v := range cacheKey { - g.overrides = append(g.overrides, map[*regexp.Regexp]keyContext{r.Regexp: { - disable_body: v.DisableBody, - disable_host: v.DisableHost, - disable_method: v.DisableMethod, - disable_query: v.DisableQuery, - disable_scheme: v.DisableScheme, - hash: v.Hash, - displayable: !v.Hide, - headers: v.Headers, - }}) - } + // for _, cacheKey := range c.GetCacheKeys() { + // for r, v := range cacheKey { + // g.overrides = append(g.overrides, map[*regexp.Regexp]keyContext{r.Regexp: { + // disable_body: v.DisableBody, + // disable_host: v.DisableHost, + // disable_method: v.DisableMethod, + // disable_query: v.DisableQuery, + // disable_scheme: v.DisableScheme, + // hash: v.Hash, + // displayable: !v.Hide, + // template: v.Template, + // headers: v.Headers, + // }}) + // } + // } + + g.initializer = func(r *http.Request) *http.Request { + return r } } -func (g *keyContext) SetContext(req *http.Request) *http.Request { - key := req.URL.Path - var headers []string +func parseKeyInformations(req *http.Request, kCtx keyContext) (query, body, host, scheme, method, headerValues string, headers []string, displayable, hash bool) { + displayable = kCtx.displayable + hash = kCtx.hash - scheme := "http-" - if req.TLS != nil { - scheme = "https-" - } - query := "" - body := "" - host := "" - method := "" - headerValues := "" - displayable := g.displayable - - if !g.disable_query && len(req.URL.RawQuery) > 0 { + if !kCtx.disable_query && len(req.URL.RawQuery) > 0 { query += "?" + req.URL.RawQuery } - if !g.disable_body { + if !kCtx.disable_body { body = req.Context().Value(HashBody).(string) } - if !g.disable_host { + if !kCtx.disable_host { host = req.Host + "-" } - if !g.disable_method { + if !kCtx.disable_scheme { + scheme = "http-" + if req.TLS != nil { + scheme = "https-" + } + } + + if !kCtx.disable_method { method = req.Method + "-" } - headers = g.headers - for _, hn := range g.headers { + headers = kCtx.headers + for _, hn := range kCtx.headers { headerValues += "-" + req.Header.Get(hn) } + return +} + +func (g *keyContext) computeKey(req *http.Request) (key string, headers []string, hash, displayable bool) { + key = req.URL.Path + query, body, host, scheme, method, headerValues, headers, displayable, hash := parseKeyInformations(req, *g) + hasOverride := false for _, current := range g.overrides { for k, v := range current { if k.MatchString(req.RequestURI) { - displayable = v.displayable - host = "" - method = "" - query = "" - if !v.disable_query && len(req.URL.RawQuery) > 0 { - query = "?" + req.URL.RawQuery - } - if !v.disable_body { - body = req.Context().Value(HashBody).(string) - } - if !v.disable_method { - method = req.Method + "-" - } - if !v.disable_host { - host = req.Host + "-" - } - if len(v.headers) > 0 { - headerValues = "" - for _, hn := range v.headers { - headers = v.headers - headerValues += "-" + req.Header.Get(hn) - } - } + query, body, host, scheme, method, headerValues, headers, displayable, hash = parseKeyInformations(req, v) hasOverride = true break } @@ -128,13 +123,26 @@ func (g *keyContext) SetContext(req *http.Request) *http.Request { } } + key = method + scheme + host + key + query + body + headerValues + + return +} + +func (g *keyContext) SetContext(req *http.Request) *http.Request { + rq := g.initializer(req) + key, headers, hash, displayable := g.computeKey(rq) + return req.WithContext( context.WithValue( context.WithValue( context.WithValue( - req.Context(), - Key, - method+scheme+host+key+query+body+headerValues, + context.WithValue( + req.Context(), + Key, + key, + ), + Hashed, + hash, ), IgnoredHeaders, headers, diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/context/method.go b/plugins/traefik/vendor/github.com/darkweak/souin/context/method.go index 1e6417cbb..ee772a3de 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/context/method.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/context/method.go @@ -16,6 +16,10 @@ type methodContext struct { custom bool } +func (*methodContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (m *methodContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { m.allowedVerbs = defaultVerbs if len(c.GetDefaultCache().GetAllowedHTTPVerbs()) != 0 { diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/context/mode.go b/plugins/traefik/vendor/github.com/darkweak/souin/context/mode.go index b041abb15..ec2d5221d 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/context/mode.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/context/mode.go @@ -13,6 +13,10 @@ type ModeContext struct { Strict, Bypass_request, Bypass_response bool } +func (*ModeContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (mc *ModeContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { mode := c.GetDefaultCache().GetMode() mc.Bypass_request = mode == "bypass" || mode == "bypass_request" @@ -24,4 +28,4 @@ func (mc *ModeContext) SetContext(req *http.Request) *http.Request { return req.WithContext(context.WithValue(req.Context(), Mode, mc)) } -var _ ctx = (*cacheContext)(nil) +var _ ctx = (*ModeContext)(nil) diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/context/now.go b/plugins/traefik/vendor/github.com/darkweak/souin/context/now.go index 898cc18fe..d0d4e0f3b 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/context/now.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/context/now.go @@ -12,6 +12,10 @@ const Now ctxKey = "souin_ctx.NOW" type nowContext struct{} +func (*nowContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (cc *nowContext) SetupContext(_ configurationtypes.AbstractConfigurationInterface) {} func (cc *nowContext) SetContext(req *http.Request) *http.Request { diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/context/timeout.go b/plugins/traefik/vendor/github.com/darkweak/souin/context/timeout.go index 6c737d24c..4da27d984 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/context/timeout.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/context/timeout.go @@ -22,6 +22,10 @@ type timeoutContext struct { timeoutCache, timeoutBackend time.Duration } +func (*timeoutContext) SetContextWithBaseRequest(req *http.Request, _ *http.Request) *http.Request { + return req +} + func (t *timeoutContext) SetupContext(c configurationtypes.AbstractConfigurationInterface) { t.timeoutBackend = defaultTimeoutBackend t.timeoutCache = defaultTimeoutCache @@ -38,4 +42,4 @@ func (t *timeoutContext) SetContext(req *http.Request) *http.Request { return req.WithContext(context.WithValue(context.WithValue(ctx, TimeoutCancel, cancel), TimeoutCache, t.timeoutCache)) } -var _ ctx = (*cacheContext)(nil) +var _ ctx = (*timeoutContext)(nil) diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/context/types.go b/plugins/traefik/vendor/github.com/darkweak/souin/context/types.go index 38bf5ed19..34e56363f 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/context/types.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/context/types.go @@ -12,6 +12,7 @@ type ( ctx interface { SetupContext(c configurationtypes.AbstractConfigurationInterface) SetContext(req *http.Request) *http.Request + SetContextWithBaseRequest(req *http.Request, baseRq *http.Request) *http.Request } Context struct { @@ -53,6 +54,6 @@ func (c *Context) SetBaseContext(req *http.Request) *http.Request { return c.Mode.SetContext(c.Timeout.SetContext(c.Method.SetContext(c.CacheName.SetContext(c.Now.SetContext(req))))) } -func (c *Context) SetContext(req *http.Request) *http.Request { - return c.Key.SetContext(c.GraphQL.SetContext(req)) +func (c *Context) SetContext(req *http.Request, baseRq *http.Request) *http.Request { + return c.Key.SetContext(c.GraphQL.SetContextWithBaseRequest(req, baseRq)) } diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/middleware.go b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/middleware.go index 34735b88d..53f899690 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/middleware.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/middleware.go @@ -22,7 +22,9 @@ import ( "github.com/darkweak/souin/pkg/storage/types" "github.com/darkweak/souin/pkg/surrogate" "github.com/darkweak/souin/pkg/surrogate/providers" + "github.com/google/uuid" "github.com/pquerna/cachecontrol/cacheobject" + "golang.org/x/sync/singleflight" ) func NewHTTPCacheHandler(c configurationtypes.AbstractConfigurationInterface) *SouinBaseHandler { @@ -65,6 +67,7 @@ func NewHTTPCacheHandler(c configurationtypes.AbstractConfigurationInterface) *S context: ctx, bufPool: bufPool, storersLen: len(storers), + singleflightPool: singleflight.Group{}, } } @@ -78,14 +81,17 @@ type SouinBaseHandler struct { SurrogateKeyStorer providers.SurrogateInterface DefaultMatchedUrl configurationtypes.URL context *context.Context + singleflightPool singleflight.Group bufPool *sync.Pool storersLen int } -type upsreamError struct{} +var Upstream50xError = upstream50xError{} -func (upsreamError) Error() string { - return "Upstream error" +type upstream50xError struct{} + +func (upstream50xError) Error() string { + return "Upstream 50x error" } func isCacheableCode(code int) bool { @@ -97,6 +103,15 @@ func isCacheableCode(code int) bool { return false } +func canStatusCodeEmptyContent(code int) bool { + switch code { + case 204, 301, 405: + return true + } + + return false +} + func canBypassAuthorizationRestriction(headers http.Header, bypassed []string) bool { for _, header := range bypassed { if strings.ToLower(header) == "authorization" { @@ -113,35 +128,37 @@ func (s *SouinBaseHandler) Store( requestCc *cacheobject.RequestCacheDirectives, cachedKey string, ) error { - if !isCacheableCode(customWriter.statusCode) { - customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=UNCACHEABLE-STATUS-CODE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + statusCode := customWriter.GetStatusCode() + if !isCacheableCode(statusCode) { + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=UNCACHEABLE-STATUS-CODE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) - switch customWriter.statusCode { + switch statusCode { case 500, 502, 503, 504: - return new(upsreamError) + return Upstream50xError } return nil } - if customWriter.Header().Get("Cache-Control") == "" { + headerName, cacheControl := s.SurrogateKeyStorer.GetSurrogateControl(customWriter.Header()) + if cacheControl == "" { // TODO see with @mnot if mandatory to not store the response when no Cache-Control given. // if s.DefaultMatchedUrl.DefaultCacheControl == "" { - // customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=EMPTY-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + // customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=EMPTY-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) // return nil // } - customWriter.Header().Set("Cache-Control", s.DefaultMatchedUrl.DefaultCacheControl) + customWriter.Header().Set(headerName, s.DefaultMatchedUrl.DefaultCacheControl) } - responseCc, _ := cacheobject.ParseResponseCacheControl(customWriter.Header().Get("Cache-Control")) + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(customWriter.Header(), headerName)) if responseCc == nil { - customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=INVALID-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=INVALID-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) return nil } modeContext := rq.Context().Value(context.Mode).(*context.ModeContext) if !modeContext.Bypass_request && (responseCc.PrivatePresent || rq.Header.Get("Authorization") != "") && !canBypassAuthorizationRestriction(customWriter.Header(), rq.Context().Value(context.IgnoredHeaders).([]string)) { - customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=PRIVATE-OR-AUTHENTICATED-RESPONSE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=PRIVATE-OR-AUTHENTICATED-RESPONSE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) return nil } @@ -156,40 +173,55 @@ func (s *SouinBaseHandler) Store( } } + hasFreshness := false ma := currentMatchedURL.TTL.Duration if responseCc.SMaxAge >= 0 { ma = time.Duration(responseCc.SMaxAge) * time.Second } else if responseCc.MaxAge >= 0 { ma = time.Duration(responseCc.MaxAge) * time.Second - } - if ma > currentMatchedURL.TTL.Duration { - ma = currentMatchedURL.TTL.Duration + } else if customWriter.Header().Get("Expires") != "" { + exp, err := time.Parse(time.RFC1123, customWriter.Header().Get("Expires")) + if err != nil { + return nil + } + + duration := time.Until(exp) + if duration <= 0 || duration > 10*types.OneYearDuration { + return nil + } + + date, _ := time.Parse(time.RFC1123, customWriter.Header().Get("Date")) + if date.Sub(exp) > 0 { + return nil + } + + ma = duration + hasFreshness = true } now := rq.Context().Value(context.Now).(time.Time) date, _ := http.ParseTime(now.Format(http.TimeFormat)) - customWriter.Headers.Set(rfc.StoredTTLHeader, ma.String()) + customWriter.Header().Set(rfc.StoredTTLHeader, ma.String()) ma = ma - time.Since(date) - if exp := customWriter.Header().Get("Expires"); exp != "" { - delta, _ := time.Parse(exp, time.RFC1123) - if sub := delta.Sub(now); sub > 0 { - ma = sub - } - } - status := fmt.Sprintf("%s; fwd=uri-miss", rq.Context().Value(context.CacheName)) if (modeContext.Bypass_request || !requestCc.NoStore) && - (modeContext.Bypass_response || !responseCc.NoStore) { - headers := customWriter.Headers.Clone() + (modeContext.Bypass_response || !responseCc.NoStore || hasFreshness) { + headers := customWriter.Header().Clone() for hname, shouldDelete := range responseCc.NoCache { if shouldDelete { headers.Del(hname) } } + + customWriter.mutex.Lock() + b := customWriter.Buf.Bytes() + bLen := customWriter.Buf.Len() + customWriter.mutex.Unlock() + res := http.Response{ - StatusCode: customWriter.statusCode, - Body: io.NopCloser(bytes.NewBuffer(customWriter.Buf.Bytes())), + StatusCode: statusCode, + Body: io.NopCloser(bytes.NewBuffer(b)), Header: headers, } @@ -197,17 +229,23 @@ func (s *SouinBaseHandler) Store( res.Header.Set("Date", now.Format(http.TimeFormat)) } if res.Header.Get("Content-Length") == "" { - res.Header.Set("Content-Length", fmt.Sprint(customWriter.Buf.Len())) + res.Header.Set("Content-Length", fmt.Sprint(bLen)) + } + respBodyMaxSize := int(s.Configuration.GetDefaultCache().GetMaxBodyBytes()) + if respBodyMaxSize > 0 && bLen > respBodyMaxSize { + customWriter.Header().Set("Cache-Status", status+"; detail=UPSTREAM-RESPONSE-TOO-LARGE; key="+rfc.GetCacheKeyFromCtx(rq.Context())) + + return nil } res.Header.Set(rfc.StoredLengthHeader, res.Header.Get("Content-Length")) response, err := httputil.DumpResponse(&res, true) - if err == nil { + if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode)) { variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(res.Header) if isVaryStar { // "Implies that the response is uncacheable" status += "; detail=UPSTREAM-VARY-STAR" } else { - cachedKey += rfc.GetVariedCacheKey(rq, variedHeaders) + variedKey := cachedKey + rfc.GetVariedCacheKey(rq, variedHeaders) var wg sync.WaitGroup mu := sync.Mutex{} @@ -216,11 +254,25 @@ func (s *SouinBaseHandler) Store( case <-rq.Context().Done(): status += "; detail=REQUEST-CANCELED-OR-UPSTREAM-BROKEN-PIPE" default: + vhs := http.Header{} + for _, hname := range variedHeaders { + hn := strings.Split(hname, ":") + vhs.Set(hn[0], rq.Header.Get(hn[0])) + } for _, storer := range s.Storers { wg.Add(1) go func(currentStorer types.Storer) { defer wg.Done() - if currentStorer.Set(cachedKey, response, ma) != nil { + if currentStorer.SetMultiLevel( + cachedKey, + variedKey, + response, + vhs, + res.Header.Get("Etag"), ma, + variedKey, + ) == nil { + res.Request = rq + } else { mu.Lock() fails = append(fails, fmt.Sprintf("; detail=%s-INSERTION-ERROR", currentStorer.Name())) mu.Unlock() @@ -232,7 +284,7 @@ func (s *SouinBaseHandler) Store( if len(fails) < s.storersLen { go func(rs http.Response, key string) { _ = s.SurrogateKeyStorer.Store(&rs, key, "") - }(res, cachedKey) + }(res, variedKey) status += "; stored" } @@ -241,15 +293,25 @@ func (s *SouinBaseHandler) Store( } } } + + } else { + status += "; detail=UPSTREAM-ERROR-OR-EMPTY-RESPONSE" } } else { status += "; detail=NO-STORE-DIRECTIVE" } - customWriter.Headers.Set("Cache-Status", status+"; key="+rfc.GetCacheKeyFromCtx(rq.Context())) + customWriter.Header().Set("Cache-Status", status+"; key="+rfc.GetCacheKeyFromCtx(rq.Context())) return nil } +type singleflightValue struct { + body []byte + headers http.Header + requestHeaders http.Header + code int +} + func (s *SouinBaseHandler) Upstream( customWriter *CustomWriter, rq *http.Request, @@ -257,71 +319,138 @@ func (s *SouinBaseHandler) Upstream( requestCc *cacheobject.RequestCacheDirectives, cachedKey string, ) error { - if err := next(customWriter, rq); err != nil { - customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=SERVE-HTTP-ERROR", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) - return err + var recoveredFromErr error = nil + defer func() { + // In case of "http.ErrAbortHandler" panic, + // prevent singleflight from wrapping it into "singleflight.panicError". + if r := recover(); r != nil { + err, ok := r.(error) + // Sometimes, the error is a string. + if !ok || errors.Is(err, http.ErrAbortHandler) { + recoveredFromErr = http.ErrAbortHandler + } else { + panic(err) + } + } + }() + + singleflightCacheKey := cachedKey + if s.Configuration.GetDefaultCache().IsCoalescingDisable() { + singleflightCacheKey += uuid.NewString() } + sfValue, err, _ := s.singleflightPool.Do(singleflightCacheKey, func() (interface{}, error) { + if e := next(customWriter, rq); e != nil { + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=SERVE-HTTP-ERROR", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + return nil, e + } - s.SurrogateKeyStorer.Invalidate(rq.Method, customWriter.Header()) - if !isCacheableCode(customWriter.statusCode) { - customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=UNCACHEABLE-STATUS-CODE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + s.SurrogateKeyStorer.Invalidate(rq.Method, customWriter.Header()) - switch customWriter.statusCode { - case 500, 502, 503, 504: - return new(upsreamError) + statusCode := customWriter.GetStatusCode() + if !isCacheableCode(statusCode) { + customWriter.Header().Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=UNCACHEABLE-STATUS-CODE", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) + + switch statusCode { + case 500, 502, 503, 504: + return nil, Upstream50xError + } } - return nil - } + headerName, cacheControl := s.SurrogateKeyStorer.GetSurrogateControl(customWriter.Header()) + if cacheControl == "" { + customWriter.Header().Set(headerName, s.DefaultMatchedUrl.DefaultCacheControl) + } - if customWriter.Header().Get("Cache-Control") == "" { - // TODO see with @mnot if mandatory to not store the response when no Cache-Control given. - // if s.DefaultMatchedUrl.DefaultCacheControl == "" { - // customWriter.Headers.Set("Cache-Status", fmt.Sprintf("%s; fwd=uri-miss; key=%s; detail=EMPTY-RESPONSE-CACHE-CONTROL", rq.Context().Value(context.CacheName), rfc.GetCacheKeyFromCtx(rq.Context()))) - // return nil - // } - customWriter.Header().Set("Cache-Control", s.DefaultMatchedUrl.DefaultCacheControl) + err := s.Store(customWriter, rq, requestCc, cachedKey) + defer customWriter.Buf.Reset() + + return singleflightValue{ + body: customWriter.Buf.Bytes(), + headers: customWriter.Header().Clone(), + requestHeaders: rq.Header, + code: statusCode, + }, err + }) + if recoveredFromErr != nil { + panic(recoveredFromErr) + } + if err != nil { + return err } - select { - case <-rq.Context().Done(): - return baseCtx.Canceled - default: - return s.Store(customWriter, rq, requestCc, cachedKey) + if sfWriter, ok := sfValue.(singleflightValue); ok { + if vary := sfWriter.headers.Get("Vary"); vary != "" { + variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(sfWriter.headers) + if !isVaryStar { + for _, vh := range variedHeaders { + if rq.Header.Get(vh) != sfWriter.requestHeaders.Get(vh) { + // cachedKey += rfc.GetVariedCacheKey(rq, variedHeaders) + return s.Upstream(customWriter, rq, next, requestCc, cachedKey) + } + } + } + } + _, _ = customWriter.Write(sfWriter.body) + // Yaegi sucks, we can't use maps. + for k := range sfWriter.headers { + customWriter.Header().Set(k, sfWriter.headers.Get(k)) + } + customWriter.WriteHeader(sfWriter.code) } + + return nil } -func (s *SouinBaseHandler) Revalidate(validator *rfc.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string) error { - err := next(customWriter, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, customWriter.Header()) +func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handlerFunc, customWriter *CustomWriter, rq *http.Request, requestCc *cacheobject.RequestCacheDirectives, cachedKey string, uri string) error { + singleflightCacheKey := cachedKey + if s.Configuration.GetDefaultCache().IsCoalescingDisable() { + singleflightCacheKey += uuid.NewString() + } + sfValue, err, _ := s.singleflightPool.Do(singleflightCacheKey, func() (interface{}, error) { + err := next(customWriter, rq) + s.SurrogateKeyStorer.Invalidate(rq.Method, customWriter.Header()) - if err == nil { - if validator.IfUnmodifiedSincePresent && customWriter.statusCode != http.StatusNotModified { - customWriter.Buf.Reset() - for h, v := range customWriter.Headers { - if len(v) > 0 { - customWriter.Rw.Header().Set(h, strings.Join(v, ", ")) - } + statusCode := customWriter.GetStatusCode() + if err == nil { + if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified { + customWriter.Buf.Reset() + customWriter.Rw.WriteHeader(http.StatusPreconditionFailed) + + return nil, errors.New("") } - customWriter.Rw.WriteHeader(http.StatusPreconditionFailed) - return errors.New("") + if statusCode != http.StatusNotModified { + err = s.Store(customWriter, rq, requestCc, cachedKey) + } } - if customWriter.statusCode != http.StatusNotModified { - err = s.Store(customWriter, rq, requestCc, cachedKey) + customWriter.Header().Set( + "Cache-Status", + fmt.Sprintf( + "%s; fwd=request; fwd-status=%d; key=%s; detail=REQUEST-REVALIDATION", + rq.Context().Value(context.CacheName), + statusCode, + rfc.GetCacheKeyFromCtx(rq.Context()), + ), + ) + + defer customWriter.Buf.Reset() + return singleflightValue{ + body: customWriter.Buf.Bytes(), + headers: customWriter.Header().Clone(), + code: statusCode, + }, err + }) + + if sfWriter, ok := sfValue.(singleflightValue); ok { + _, _ = customWriter.Write(sfWriter.body) + // Yaegi sucks, we can't use maps. + for k := range sfWriter.headers { + customWriter.Header().Set(k, sfWriter.headers.Get(k)) } + customWriter.WriteHeader(sfWriter.code) } - customWriter.Header().Set( - "Cache-Status", - fmt.Sprintf( - "%s; fwd=request; fwd-status=%d; key=%s; detail=REQUEST-REVALIDATION", - rq.Context().Value(context.CacheName), - customWriter.statusCode, - rfc.GetCacheKeyFromCtx(rq.Context()), - ), - ) return err } @@ -339,89 +468,128 @@ func (s *SouinBaseHandler) HandleInternally(r *http.Request) (bool, http.Handler } type handlerFunc = func(http.ResponseWriter, *http.Request) error +type statusCodeLogger struct { + http.ResponseWriter + statusCode int +} + +func (s *statusCodeLogger) WriteHeader(code int) { + s.statusCode = code + s.ResponseWriter.WriteHeader(code) +} func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, next handlerFunc) error { - b, handler := s.HandleInternally(rq) - if b { + if b, handler := s.HandleInternally(rq); b { handler(rw, rq) return nil } + req := s.context.SetBaseContext(rq) + cacheName := req.Context().Value(context.CacheName).(string) - rq = s.context.SetBaseContext(rq) - cacheName := rq.Context().Value(context.CacheName).(string) - if rq.Header.Get("Upgrade") == "websocket" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) { + if rq.Header.Get("Upgrade") == "websocket" || rq.Header.Get("Accept") == "text/event-stream" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=EXCLUDED-REQUEST-URI") - return next(rw, rq) + return next(rw, req) } - if !rq.Context().Value(context.SupportedMethod).(bool) { + if !req.Context().Value(context.SupportedMethod).(bool) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=UNSUPPORTED-METHOD") + nrw := &statusCodeLogger{ + ResponseWriter: rw, + statusCode: 0, + } - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(nrw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) + + if err == nil && req.Method != http.MethodGet && nrw.statusCode < http.StatusBadRequest { + // Invalidate related GET keys when the method is not allowed and the response is valid + req.Method = http.MethodGet + keyname := s.context.SetContext(req, rq).Context().Value(context.Key).(string) + for _, storer := range s.Storers { + storer.Delete("IDX_" + keyname) + } + } return err } - requestCc, coErr := cacheobject.ParseRequestCacheControl(rq.Header.Get("Cache-Control")) + requestCc, coErr := cacheobject.ParseRequestCacheControl(rfc.HeaderAllCommaSepValuesString(req.Header, "Cache-Control")) + + modeContext := req.Context().Value(context.Mode).(*context.ModeContext) - modeContext := rq.Context().Value(context.Mode).(*context.ModeContext) if !modeContext.Bypass_request && (coErr != nil || requestCc == nil) { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=CACHE-CONTROL-EXTRACTION-ERROR") - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(rw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) return err } - rq = s.context.SetContext(rq) + req = s.context.SetContext(req, rq) + + isMutationRequest := false + // Yaegi sucks AGAIN, it considers the value as nil if we directly try to cast as bool + mutationRequestValue := req.Context().Value(context.IsMutationRequest) + if mutationRequestValue != nil { + isMutationRequest = mutationRequestValue.(bool) + } - // Yaegi sucks again, it considers false as true - isMutationRequest := rq.Context().Value(context.IsMutationRequest).(bool) if isMutationRequest { rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=IS-MUTATION-REQUEST") - err := next(rw, rq) - s.SurrogateKeyStorer.Invalidate(rq.Method, rw.Header()) + err := next(rw, req) + s.SurrogateKeyStorer.Invalidate(req.Method, rw.Header()) return err } - cachedKey := rq.Context().Value(context.Key).(string) + cachedKey := req.Context().Value(context.Key).(string) + + // Need to copy URL path before calling next because it can alter the URI + uri := req.URL.Path bufPool := s.bufPool.Get().(*bytes.Buffer) bufPool.Reset() defer s.bufPool.Put(bufPool) - customWriter := NewCustomWriter(rq, rw, bufPool) + customWriter := NewCustomWriter(req, rw, bufPool) + go func(req *http.Request, crw *CustomWriter) { <-req.Context().Done() crw.mutex.Lock() crw.headersSent = true crw.mutex.Unlock() - }(rq, customWriter) + }(req, customWriter) + if modeContext.Bypass_request || !requestCc.NoCache { - validator := rfc.ParseRequest(rq) - var response *http.Response + validator := rfc.ParseRequest(req) + var fresh, stale *http.Response + var storerName string for _, currentStorer := range s.Storers { - response = currentStorer.Prefix(cachedKey, rq, validator) - if response != nil { + fresh, stale = currentStorer.GetMultiLevel(cachedKey, req, validator) + + if fresh != nil || stale != nil { + storerName = currentStorer.Name() break } } - if response != nil && (!modeContext.Strict || rfc.ValidateCacheControl(response, requestCc)) { + headerName, _ := s.SurrogateKeyStorer.GetSurrogateControl(customWriter.Header()) + if fresh != nil && (!modeContext.Strict || rfc.ValidateCacheControl(fresh, requestCc)) { + response := fresh if validator.ResponseETag != "" && validator.Matched { - rfc.SetCacheStatusHeader(response, "DEFAULT") - customWriter.Headers = response.Header + rfc.SetCacheStatusHeader(response, storerName) + for h, v := range response.Header { + customWriter.Header()[h] = v + } if validator.NotModified { - customWriter.statusCode = http.StatusNotModified + customWriter.WriteHeader(http.StatusNotModified) customWriter.Buf.Reset() _, _ = customWriter.Send() return nil } - customWriter.statusCode = response.StatusCode + customWriter.WriteHeader(response.StatusCode) _, _ = io.Copy(customWriter.Buf, response.Body) _, _ = customWriter.Send() @@ -429,48 +597,48 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if validator.NeedRevalidation { - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) _, _ = customWriter.Send() return err } - if resCc, _ := cacheobject.ParseResponseCacheControl(response.Header.Get("Cache-Control")); resCc.NoCachePresent { - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + if resCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, headerName)); resCc.NoCachePresent { + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) _, _ = customWriter.Send() return err } - rfc.SetCacheStatusHeader(response, "DEFAULT") + rfc.SetCacheStatusHeader(response, storerName) if !modeContext.Strict || rfc.ValidateMaxAgeCachedResponse(requestCc, response) != nil { - customWriter.Headers = response.Header - customWriter.statusCode = response.StatusCode + for h, v := range response.Header { + customWriter.Header()[h] = v + } + customWriter.WriteHeader(response.StatusCode) _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() return err } - } else if response == nil && !requestCc.OnlyIfCached && (requestCc.MaxStaleSet || requestCc.MaxStale > -1) { - for _, currentStorer := range s.Storers { - response = currentStorer.Prefix(storage.StalePrefix+cachedKey, rq, validator) - if response != nil { - break - } - } + } else if !requestCc.OnlyIfCached && (requestCc.MaxStaleSet || requestCc.MaxStale > -1) { + response := stale + if nil != response && (!modeContext.Strict || rfc.ValidateCacheControl(response, requestCc)) { addTime, _ := time.ParseDuration(response.Header.Get(rfc.StoredTTLHeader)) - rfc.SetCacheStatusHeader(response, "DEFAULT") + rfc.SetCacheStatusHeader(response, storerName) - responseCc, _ := cacheobject.ParseResponseCacheControl(response.Header.Get("Cache-Control")) + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, "Cache-Control")) if responseCc.StaleWhileRevalidate > 0 { - customWriter.Headers = response.Header - customWriter.statusCode = response.StatusCode + for h, v := range response.Header { + customWriter.Header()[h] = v + } + customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() - customWriter = NewCustomWriter(rq, rw, bufPool) - go func(v *rfc.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string) { - _ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk) - }(validator, customWriter, rq, next, requestCc, cachedKey) + customWriter = NewCustomWriter(req, rw, bufPool) + go func(v *types.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) { + _ = s.Revalidate(v, goNext, goCw, goRq, goCc, goCk, goUri) + }(validator, customWriter, req, next, requestCc, cachedKey, uri) buf := s.bufPool.Get().(*bytes.Buffer) buf.Reset() defer s.bufPool.Put(buf) @@ -479,15 +647,20 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation { - rq.Header["If-None-Match"] = append(rq.Header["If-None-Match"], validator.ResponseETag) - err := s.Revalidate(validator, next, customWriter, rq, requestCc, cachedKey) + req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag) + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) + statusCode := customWriter.GetStatusCode() if err != nil { if responseCc.StaleIfError > -1 || requestCc.StaleIfError > 0 { - code := fmt.Sprintf("; fwd-status=%d", customWriter.statusCode) - customWriter.Headers = response.Header - customWriter.statusCode = response.StatusCode + code := fmt.Sprintf("; fwd-status=%d", statusCode) rfc.HitStaleCache(&response.Header) response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code) + // Yaegi sucks, we can't use maps. + for k := range response.Header { + customWriter.Header().Set(k, response.Header.Get(k)) + } + customWriter.WriteHeader(response.StatusCode) + customWriter.Buf.Reset() _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() @@ -500,11 +673,14 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n return err } - if customWriter.statusCode == http.StatusNotModified { + if statusCode == http.StatusNotModified { if !validator.Matched { - rfc.SetCacheStatusHeader(response, "DEFAULT") - customWriter.statusCode = response.StatusCode - customWriter.Headers = response.Header + rfc.SetCacheStatusHeader(response, storerName) + customWriter.WriteHeader(response.StatusCode) + // Yaegi sucks, we can't use maps. + for k := range response.Header { + customWriter.Header().Set(k, response.Header.Get(k)) + } _, _ = io.Copy(customWriter.Buf, response.Body) _, _ = customWriter.Send() @@ -512,8 +688,8 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n } } - if customWriter.statusCode != http.StatusNotModified && validator.Matched { - customWriter.statusCode = http.StatusNotModified + if statusCode != http.StatusNotModified && validator.Matched { + customWriter.WriteHeader(http.StatusNotModified) customWriter.Buf.Reset() _, _ = customWriter.Send() @@ -525,27 +701,64 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n return err } - if !modeContext.Strict || rfc.ValidateMaxAgeCachedStaleResponse(requestCc, response, int(addTime.Seconds())) != nil { - customWriter.Headers = response.Header - customWriter.statusCode = response.StatusCode + if !modeContext.Strict || rfc.ValidateMaxAgeCachedStaleResponse(requestCc, responseCc, response, int(addTime.Seconds())) != nil { + customWriter.WriteHeader(response.StatusCode) rfc.HitStaleCache(&response.Header) + // Yaegi sucks, we can't use maps. + for k := range response.Header { + customWriter.Header().Set(k, response.Header.Get(k)) + } _, _ = io.Copy(customWriter.Buf, response.Body) _, err := customWriter.Send() return err } } + } else if stale != nil { + response := stale + addTime, _ := time.ParseDuration(response.Header.Get(rfc.StoredTTLHeader)) + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, "Cache-Control")) + + if !modeContext.Strict || rfc.ValidateMaxAgeCachedStaleResponse(requestCc, responseCc, response, int(addTime.Seconds())) != nil { + _, _ = time.ParseDuration(response.Header.Get(rfc.StoredTTLHeader)) + rfc.SetCacheStatusHeader(response, storerName) + + responseCc, _ := cacheobject.ParseResponseCacheControl(rfc.HeaderAllCommaSepValuesString(response.Header, "Cache-Control")) + + if responseCc.StaleIfError > -1 || requestCc.StaleIfError > 0 { + err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri) + statusCode := customWriter.GetStatusCode() + if err != nil { + code := fmt.Sprintf("; fwd-status=%d", statusCode) + rfc.HitStaleCache(&response.Header) + response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code) + // Yaegi sucks, we can't use maps. + for k := range response.Header { + customWriter.Header().Set(k, response.Header.Get(k)) + } + customWriter.WriteHeader(response.StatusCode) + customWriter.Buf.Reset() + _, _ = io.Copy(customWriter.Buf, response.Body) + _, err := customWriter.Send() + + return err + } + } + + } } } errorCacheCh := make(chan error) - go func() { - errorCacheCh <- s.Upstream(customWriter, rq, next, requestCc, cachedKey) - }() + + go func(vr *http.Request, cw *CustomWriter) { + errorCacheCh <- s.Upstream(cw, vr, next, requestCc, cachedKey) + }(req, customWriter) select { - case <-rq.Context().Done(): - switch rq.Context().Err() { + case <-req.Context().Done(): + + switch req.Context().Err() { case baseCtx.DeadlineExceeded: customWriter.WriteHeader(http.StatusGatewayTimeout) rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=DEADLINE-EXCEEDED") @@ -556,9 +769,15 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n default: return nil } + case v := <-errorCacheCh: - if v == nil { + + switch v { + case nil: + _, _ = customWriter.Send() + case Upstream50xError: _, _ = customWriter.Send() + return nil } return v } diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/writer.go b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/writer.go index ddbccfdbd..97b479bd7 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/writer.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/middleware/writer.go @@ -3,10 +3,9 @@ package middleware import ( "bytes" "net/http" - "strings" + "strconv" "sync" - "github.com/darkweak/go-esi/esi" "github.com/darkweak/souin/pkg/rfc" ) @@ -37,19 +36,28 @@ type CustomWriter struct { headersSent bool mutex *sync.Mutex statusCode int - // size int } // Header will write the response headers func (r *CustomWriter) Header() http.Header { r.mutex.Lock() defer r.mutex.Unlock() - if r.headersSent { + + if r.headersSent || r.Req.Context().Err() != nil { return http.Header{} } + return r.Rw.Header() } +// GetStatusCode returns the response status code +func (r *CustomWriter) GetStatusCode() int { + r.mutex.Lock() + defer r.mutex.Unlock() + + return r.statusCode +} + // WriteHeader will write the response headers func (r *CustomWriter) WriteHeader(code int) { r.mutex.Lock() @@ -57,14 +65,13 @@ func (r *CustomWriter) WriteHeader(code int) { if r.headersSent { return } - r.Headers = r.Rw.Header() r.statusCode = code - // r.headersSent = true - // r.Rw.WriteHeader(code) } // Write will write the response body func (r *CustomWriter) Write(b []byte) (int, error) { + r.mutex.Lock() + defer r.mutex.Unlock() r.Buf.Grow(len(b)) _, _ = r.Buf.Write(b) @@ -73,24 +80,20 @@ func (r *CustomWriter) Write(b []byte) (int, error) { // Send delays the response to handle Cache-Status func (r *CustomWriter) Send() (int, error) { - contentLength := r.Headers.Get(rfc.StoredLengthHeader) - if contentLength != "" { - r.Header().Set("Content-Length", contentLength) - } defer r.Buf.Reset() - b := esi.Parse(r.Buf.Bytes(), r.Req) - for h, v := range r.Headers { - if len(v) > 0 { - r.Rw.Header().Set(h, strings.Join(v, ", ")) - } + storedLength := r.Header().Get(rfc.StoredLengthHeader) + if storedLength != "" { + r.Header().Set("Content-Length", storedLength) + } + b := r.Buf.Bytes() + if len(b) != 0 { + r.Header().Set("Content-Length", strconv.Itoa(len(b))) } r.Header().Del(rfc.StoredLengthHeader) r.Header().Del(rfc.StoredTTLHeader) if !r.headersSent { - - // r.Rw.Header().Set("Content-Length", fmt.Sprintf("%d", len(b))) - r.Rw.WriteHeader(r.statusCode) + r.Rw.WriteHeader(r.GetStatusCode()) r.headersSent = true } diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/rfc/revalidation.go b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/rfc/revalidation.go index c0eb363b3..ccb6e73be 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/rfc/revalidation.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/rfc/revalidation.go @@ -1,29 +1,17 @@ package rfc import ( + "bufio" + "bytes" + "encoding/json" "net/http" "strings" "time" -) -type Revalidator struct { - Matched bool - IfNoneMatchPresent bool - IfMatchPresent bool - IfModifiedSincePresent bool - IfUnmodifiedSincePresent bool - IfUnmotModifiedSincePresent bool - NeedRevalidation bool - NotModified bool - IfModifiedSince time.Time - IfUnmodifiedSince time.Time - IfNoneMatch []string - IfMatch []string - RequestETags []string - ResponseETag string -} + "github.com/darkweak/souin/pkg/storage/types" +) -func ValidateETagFromHeader(etag string, validator *Revalidator) { +func ValidateETagFromHeader(etag string, validator *types.Revalidator) { validator.ResponseETag = etag validator.NeedRevalidation = validator.NeedRevalidation || validator.ResponseETag != "" validator.Matched = validator.ResponseETag == "" || (validator.ResponseETag != "" && len(validator.RequestETags) == 0) @@ -72,7 +60,7 @@ func ValidateETagFromHeader(etag string, validator *Revalidator) { } } -func ParseRequest(req *http.Request) *Revalidator { +func ParseRequest(req *http.Request) *types.Revalidator { var rqEtags []string if len(req.Header.Get("If-None-Match")) > 0 { rqEtags = strings.Split(req.Header.Get("If-None-Match"), ",") @@ -80,7 +68,7 @@ func ParseRequest(req *http.Request) *Revalidator { for i, tag := range rqEtags { rqEtags[i] = strings.Trim(tag, " ") } - validator := Revalidator{ + validator := types.Revalidator{ NotModified: len(rqEtags) > 0, RequestETags: rqEtags, } @@ -106,3 +94,116 @@ func ParseRequest(req *http.Request) *Revalidator { return &validator } + +func DecodeMapping(item []byte) (*StorageMapper, error) { + mapping := &StorageMapper{} + e := json.Unmarshal(item, mapping) + + return mapping, e +} + +func MappingElection(provider types.Storer, item []byte, req *http.Request, validator *types.Revalidator) (resultFresh *http.Response, resultStale *http.Response, e error) { + mapping := &StorageMapper{} + + if len(item) != 0 { + mapping, e = DecodeMapping(item) + if e != nil { + return resultFresh, resultStale, e + } + } + + for keyName, keyItem := range mapping.Mapping { + valid := true + + for hname, hval := range keyItem.VariedHeaders { + if req.Header.Get(hname) != strings.Join(hval, ", ") { + valid = false + + break + } + } + + if !valid { + continue + } + + ValidateETagFromHeader(keyItem.Etag, validator) + + if validator.Matched { + // If the key is fresh enough. + if time.Since(keyItem.FreshTime) < 0 { + response := provider.Get(keyName) + if response != nil { + if resultFresh, e = http.ReadResponse(bufio.NewReader(bytes.NewBuffer(response)), req); e != nil { + return resultFresh, resultStale, e + } + + return resultFresh, resultStale, e + } + } + + // If the key is still stale. + if time.Since(keyItem.StaleTime) < 0 { + response := provider.Get(keyName) + if response != nil { + if resultStale, e = http.ReadResponse(bufio.NewReader(bytes.NewBuffer(response)), req); e != nil { + return resultFresh, resultStale, e + } + } + } + } + } + + return resultFresh, resultStale, e +} + +type KeyIndex struct { + StoredAt time.Time `json:"stored_at,omitempty"` + FreshTime time.Time `json:"fresh_time,omitempty"` + StaleTime time.Time `json:"stale_time,omitempty"` + VariedHeaders map[string][]string `json:"varied_headers,omitempty"` + Etag string `json:"etag,omitempty"` + RealKey string `json:"real_key,omitempty"` +} +type StorageMapper struct { + Mapping map[string]*KeyIndex `json:"mapping,omitempty"` +} + +func MappingUpdater(key string, item []byte, now, freshTime, staleTime time.Time, variedHeaders http.Header, etag, realKey string) (val []byte, e error) { + mapping := &StorageMapper{} + if len(item) != 0 { + e = json.Unmarshal(item, mapping) + if e != nil { + return nil, e + } + } + + if mapping.Mapping == nil { + mapping.Mapping = make(map[string]*KeyIndex) + } + + var pbvariedeheader map[string][]string + if variedHeaders != nil { + pbvariedeheader = make(map[string][]string) + } + + for k, v := range variedHeaders { + pbvariedeheader[k] = append(pbvariedeheader[k], v...) + } + + mapping.Mapping[key] = &KeyIndex{ + StoredAt: now, + FreshTime: freshTime, + StaleTime: staleTime, + VariedHeaders: pbvariedeheader, + Etag: etag, + RealKey: realKey, + } + + val, e = json.Marshal(mapping) + if e != nil { + return nil, e + } + + return val, e +} diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/abstractProvider.go b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/abstractProvider.go index 65b1c81f3..577d5fade 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/abstractProvider.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/abstractProvider.go @@ -20,7 +20,7 @@ const ( type StorerInstanciator func(configurationtypes.AbstractConfigurationInterface) (types.Storer, error) func NewStorages(configuration configurationtypes.AbstractConfigurationInterface) ([]types.Storer, error) { - s, err := CacheConnectionFactory(configuration) + s, err := Factory(configuration) return []types.Storer{s}, err } diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/cacheProvider.go b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/cacheProvider.go index 8331486c6..dff56aa09 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/cacheProvider.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/cacheProvider.go @@ -22,8 +22,8 @@ type Cache struct { var sharedCache *Cache -// CacheConnectionFactory function create new Cache instance -func CacheConnectionFactory(c t.AbstractConfigurationInterface) (types.Storer, error) { +// Factory function create new Cache instance +func Factory(c t.AbstractConfigurationInterface) (types.Storer, error) { provider := cache.New(1 * time.Second) if sharedCache == nil { @@ -38,6 +38,11 @@ func (provider *Cache) Name() string { return "CACHE" } +// Uuid returns an unique identifier +func (provider *Cache) Uuid() string { + return "" +} + // ListKeys method returns the list of existing keys func (provider *Cache) ListKeys() []string { var keys []string @@ -74,8 +79,44 @@ func (provider *Cache) Get(key string) []byte { return result.([]byte) } +// GetMultiLevel tries to load the key and check if one of linked keys is a fresh/stale candidate. +func (provider *Cache) GetMultiLevel(key string, req *http.Request, validator *types.Revalidator) (fresh *http.Response, stale *http.Response) { + result, found := provider.Cache.Get("IDX_" + key) + if !found { + return + } + + fresh, stale, _ = rfc.MappingElection(provider, result.([]byte), req, validator) + + return +} + +// SetMultiLevel tries to store the key with the given value and update the mapping key to store metadata. +func (provider *Cache) SetMultiLevel(baseKey, variedKey string, value []byte, variedHeaders http.Header, etag string, duration time.Duration, realKey string) error { + now := time.Now() + + var e error + + provider.Cache.Set(variedKey, value, duration) + + mappingKey := "IDX_" + baseKey + item, ok := provider.Cache.Get(mappingKey) + var val []byte + if ok { + val = item.([]byte) + } + + val, e = rfc.MappingUpdater(variedKey, val, now, now.Add(duration), now.Add(duration+provider.stale), variedHeaders, etag, realKey) + if e != nil { + return e + } + + provider.Cache.Set(mappingKey, val, 0) + return nil +} + // Prefix method returns the populated response if exists, empty response then -func (provider *Cache) Prefix(key string, req *http.Request, validator *rfc.Revalidator) *http.Response { +func (provider *Cache) Prefix(key string, req *http.Request, validator *types.Revalidator) *http.Response { var result *http.Response provider.Cache.Range(func(k, v interface{}) bool { @@ -103,7 +144,6 @@ func (provider *Cache) Prefix(key string, req *http.Request, validator *rfc.Reva // Set method will store the response in Cache provider func (provider *Cache) Set(key string, value []byte, duration time.Duration) error { provider.Cache.Set(key, value, duration) - provider.Cache.Set(StalePrefix+key, value, provider.stale+duration) return nil } diff --git a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/types/types.go b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/types/types.go index 67031c15e..e41007745 100644 --- a/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/types/types.go +++ b/plugins/traefik/vendor/github.com/darkweak/souin/pkg/storage/types/types.go @@ -3,19 +3,41 @@ package types import ( "net/http" "time" - - "github.com/darkweak/souin/pkg/rfc" ) +type Revalidator struct { + Matched bool + IfNoneMatchPresent bool + IfMatchPresent bool + IfModifiedSincePresent bool + IfUnmodifiedSincePresent bool + IfUnmotModifiedSincePresent bool + NeedRevalidation bool + NotModified bool + IfModifiedSince time.Time + IfUnmodifiedSince time.Time + IfNoneMatch []string + IfMatch []string + RequestETags []string + ResponseETag string +} + +const DefaultStorageName = "CACHE" +const OneYearDuration = 365 * 24 * time.Hour + type Storer interface { MapKeys(prefix string) map[string]string ListKeys() []string - Prefix(key string, req *http.Request, validator *rfc.Revalidator) *http.Response Get(key string) []byte Set(key string, value []byte, duration time.Duration) error Delete(key string) DeleteMany(key string) Init() error Name() string + Uuid() string Reset() error + + // Multi level storer to handle fresh/stale at once + GetMultiLevel(key string, req *http.Request, validator *Revalidator) (fresh *http.Response, stale *http.Response) + SetMultiLevel(baseKey, variedKey string, value []byte, variedHeaders http.Header, etag string, duration time.Duration, realKey string) error }