Skip to content

Commit

Permalink
fix(writer): buffer race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
darkweak committed Nov 15, 2024
1 parent dec0714 commit 556cf83
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 54 deletions.
61 changes: 45 additions & 16 deletions pkg/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ func (s *SouinBaseHandler) Upstream(
}

err := s.Store(customWriter, rq, requestCc, cachedKey, uri)
defer customWriter.Buf.Reset()
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})

return singleflightValue{
body: customWriter.Buf.Bytes(),
Expand Down Expand Up @@ -521,7 +523,9 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
statusCode := customWriter.GetStatusCode()
if err == nil {
if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified {
customWriter.Buf.Reset()
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
customWriter.Rw.WriteHeader(http.StatusPreconditionFailed)

return nil, errors.New("")
Expand All @@ -542,7 +546,9 @@ func (s *SouinBaseHandler) Revalidate(validator *core.Revalidator, next handlerF
),
)

defer customWriter.Buf.Reset()
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
return singleflightValue{
body: customWriter.Buf.Bytes(),
headers: customWriter.Header().Clone(),
Expand Down Expand Up @@ -598,6 +604,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n

req := s.context.SetBaseContext(rq)
cacheName := req.Context().Value(context.CacheName).(string)

if rq.Header.Get("Upgrade") == "websocket" || rq.Header.Get("Accept") == "text/event-stream" || (s.ExcludeRegex != nil && s.ExcludeRegex.MatchString(rq.RequestURI)) {
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=EXCLUDED-REQUEST-URI")
return next(rw, req)
Expand Down Expand Up @@ -689,14 +696,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
if validator.NotModified {
customWriter.WriteHeader(http.StatusNotModified)
customWriter.Buf.Reset()
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
_, _ = customWriter.Send()

return nil
}

customWriter.WriteHeader(response.StatusCode)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.Send()

return nil
Expand All @@ -722,7 +733,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
customWriter.WriteHeader(response.StatusCode)
s.Configuration.GetLogger().Debugf("Serve from cache %+v", req)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()
prometheus.Increment(prometheus.CachedResponseCounter)

Expand All @@ -742,7 +755,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
customWriter.WriteHeader(response.StatusCode)
rfc.HitStaleCache(&response.Header)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()
customWriter = NewCustomWriter(req, rw, bufPool)
go func(v *core.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) {
Expand All @@ -766,14 +781,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code)
maps.Copy(customWriter.Header(), response.Header)
customWriter.WriteHeader(response.StatusCode)
customWriter.Buf.Reset()
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()

return err
}
rw.WriteHeader(http.StatusGatewayTimeout)
customWriter.Buf.Reset()
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
_, err := customWriter.Send()

return err
Expand All @@ -784,7 +803,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
rfc.SetCacheStatusHeader(response, storerName)
customWriter.WriteHeader(response.StatusCode)
maps.Copy(customWriter.Header(), response.Header)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.Send()

return err
Expand All @@ -793,7 +814,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n

if statusCode != http.StatusNotModified && validator.Matched {
customWriter.WriteHeader(http.StatusNotModified)
customWriter.Buf.Reset()
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
_, _ = customWriter.Send()

return err
Expand All @@ -808,7 +831,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
customWriter.WriteHeader(response.StatusCode)
rfc.HitStaleCache(&response.Header)
maps.Copy(customWriter.Header(), response.Header)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()

return err
Expand All @@ -822,7 +847,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
customWriter.WriteHeader(response.StatusCode)
rfc.HitStaleCache(&response.Header)
maps.Copy(customWriter.Header(), response.Header)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()

return err
Expand All @@ -846,8 +873,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
response.Header.Set("Cache-Status", response.Header.Get("Cache-Status")+code)
maps.Copy(customWriter.Header(), response.Header)
customWriter.WriteHeader(response.StatusCode)
customWriter.Buf.Reset()
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()

return err
Expand Down
6 changes: 6 additions & 0 deletions pkg/middleware/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ type CustomWriter struct {
statusCode int
}

func (r *CustomWriter) handleBuffer(callback func(*bytes.Buffer)) {
r.mutex.Lock()
callback(r.Buf)
r.mutex.Unlock()
}

// Header will write the response headers
func (r *CustomWriter) Header() http.Header {
r.mutex.Lock()
Expand Down
64 changes: 45 additions & 19 deletions plugins/traefik/override/middleware/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (s *SouinBaseHandler) Store(
ma = time.Duration(responseCc.SMaxAge) * time.Second
} else if responseCc.MaxAge >= 0 {
ma = time.Duration(responseCc.MaxAge) * time.Second
} else if customWriter.Header().Get("Expires") != "" {
} else if !modeContext.Bypass_response && customWriter.Header().Get("Expires") != "" {
exp, err := time.Parse(time.RFC1123, customWriter.Header().Get("Expires"))
if err != nil {
return nil
Expand Down Expand Up @@ -249,7 +249,7 @@ func (s *SouinBaseHandler) Store(
}
res.Header.Set(rfc.StoredLengthHeader, res.Header.Get("Content-Length"))
response, err := httputil.DumpResponse(&res, true)
if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode)) {
if err == nil && (bLen > 0 || canStatusCodeEmptyContent(statusCode) || s.hasAllowedAdditionalStatusCodesToCache(statusCode)) {
variedHeaders, isVaryStar := rfc.VariedHeaderAllCommaSepValues(res.Header)
if isVaryStar {
// "Implies that the response is uncacheable"
Expand Down Expand Up @@ -372,7 +372,9 @@ func (s *SouinBaseHandler) Upstream(
}

err := s.Store(customWriter, rq, requestCc, cachedKey)
defer customWriter.Buf.Reset()
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})

return singleflightValue{
body: customWriter.Buf.Bytes(),
Expand Down Expand Up @@ -423,7 +425,9 @@ func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handler
statusCode := customWriter.GetStatusCode()
if err == nil {
if validator.IfUnmodifiedSincePresent && statusCode != http.StatusNotModified {
customWriter.Buf.Reset()
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
customWriter.Rw.WriteHeader(http.StatusPreconditionFailed)

return nil, errors.New("")
Expand All @@ -444,7 +448,9 @@ func (s *SouinBaseHandler) Revalidate(validator *types.Revalidator, next handler
),
)

defer customWriter.Buf.Reset()
defer customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
return singleflightValue{
body: customWriter.Buf.Bytes(),
headers: customWriter.Header().Clone(),
Expand Down Expand Up @@ -493,6 +499,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
handler(rw, rq)
return nil
}

req := s.context.SetBaseContext(rq)
cacheName := req.Context().Value(context.CacheName).(string)

Expand Down Expand Up @@ -526,7 +533,6 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
requestCc, coErr := cacheobject.ParseRequestCacheControl(rfc.HeaderAllCommaSepValuesString(req.Header, "Cache-Control"))

modeContext := req.Context().Value(context.Mode).(*context.ModeContext)

if !modeContext.Bypass_request && (coErr != nil || requestCc == nil) {
rw.Header().Set("Cache-Status", cacheName+"; fwd=bypass; detail=CACHE-CONTROL-EXTRACTION-ERROR")

Expand Down Expand Up @@ -593,14 +599,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
if validator.NotModified {
customWriter.WriteHeader(http.StatusNotModified)
customWriter.Buf.Reset()
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
_, _ = customWriter.Send()

return nil
}

customWriter.WriteHeader(response.StatusCode)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.Send()

return nil
Expand All @@ -624,7 +634,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
customWriter.Header()[h] = v
}
customWriter.WriteHeader(response.StatusCode)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()

return err
Expand All @@ -643,7 +655,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
}
customWriter.WriteHeader(response.StatusCode)
rfc.HitStaleCache(&response.Header)
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()
customWriter = NewCustomWriter(req, rw, bufPool)
go func(v *types.Revalidator, goCw *CustomWriter, goRq *http.Request, goNext func(http.ResponseWriter, *http.Request) error, goCc *cacheobject.RequestCacheDirectives, goCk string, goUri string) {
Expand All @@ -656,7 +670,7 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
return err
}

if responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation {
if modeContext.Bypass_response || responseCc.MustRevalidate || responseCc.NoCachePresent || validator.NeedRevalidation {
req.Header["If-None-Match"] = append(req.Header["If-None-Match"], validator.ResponseETag)
err := s.Revalidate(validator, next, customWriter, req, requestCc, cachedKey, uri)
statusCode := customWriter.GetStatusCode()
Expand All @@ -670,14 +684,18 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
customWriter.Header().Set(k, response.Header.Get(k))
}
customWriter.WriteHeader(response.StatusCode)
customWriter.Buf.Reset()
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()

return err
}
rw.WriteHeader(http.StatusGatewayTimeout)
customWriter.Buf.Reset()
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
_, err := customWriter.Send()

return err
Expand All @@ -691,7 +709,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
for k := range response.Header {
customWriter.Header().Set(k, response.Header.Get(k))
}
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, _ = customWriter.Send()

return err
Expand All @@ -700,7 +720,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n

if statusCode != http.StatusNotModified && validator.Matched {
customWriter.WriteHeader(http.StatusNotModified)
customWriter.Buf.Reset()
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
})
_, _ = customWriter.Send()

return err
Expand All @@ -718,7 +740,9 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
for k := range response.Header {
customWriter.Header().Set(k, response.Header.Get(k))
}
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()

return err
Expand Down Expand Up @@ -747,8 +771,10 @@ func (s *SouinBaseHandler) ServeHTTP(rw http.ResponseWriter, rq *http.Request, n
customWriter.Header().Set(k, response.Header.Get(k))
}
customWriter.WriteHeader(response.StatusCode)
customWriter.Buf.Reset()
_, _ = io.Copy(customWriter.Buf, response.Body)
customWriter.handleBuffer(func(b *bytes.Buffer) {
b.Reset()
_, _ = io.Copy(b, response.Body)
})
_, err := customWriter.Send()

return err
Expand Down
6 changes: 6 additions & 0 deletions plugins/traefik/override/middleware/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ type CustomWriter struct {
statusCode int
}

func (r *CustomWriter) handleBuffer(callback func(*bytes.Buffer)) {
r.mutex.Lock()
callback(r.Buf)
r.mutex.Unlock()
}

// Header will write the response headers
func (r *CustomWriter) Header() http.Header {
r.mutex.Lock()
Expand Down
Loading

0 comments on commit 556cf83

Please sign in to comment.