diff --git a/experimental/middleware/transaction.go b/experimental/middleware/transaction.go new file mode 100644 index 000000000..f143b2ab6 --- /dev/null +++ b/experimental/middleware/transaction.go @@ -0,0 +1,51 @@ +// Copyright 2024 Juan Pablo Tosso and the OWASP Coraza contributors +// SPDX-License-Identifier: Apache-2.0 + +package middleware + +import ( + "context" + + "github.com/corazawaf/coraza/v3/collection" + "github.com/corazawaf/coraza/v3/debuglog" + "github.com/corazawaf/coraza/v3/experimental/plugins/plugintypes" + "github.com/corazawaf/coraza/v3/internal/corazawaf" + "github.com/corazawaf/coraza/v3/types" + "github.com/corazawaf/coraza/v3/types/variables" +) + +type TransactionState interface { + // ID returns the ID of the transaction. + ID() string + + // Variables returns the TransactionVariables of the transaction. + Variables() plugintypes.TransactionVariables + + // Collection returns a collection from the transaction. + Collection(idx variables.RuleVariable) collection.Collection + + // DebugLogger returns the logger for this transaction. + DebugLogger() debuglog.Logger + + // IsInterrupted will return true if the transaction was interrupted + IsInterrupted() bool + + // Interruption returns the transaction interruption + Interruption() *types.Interruption + + // MatchedRules returns the matched rules of the transaction + MatchedRules() []types.MatchedRule + + // LastPhase that was evaluated + LastPhase() types.RulePhase +} + +// GetContext returns the context of the transaction and a boolean indicating if the +// transaction has a context or not. +func GetContext(tx TransactionState) (context.Context, bool) { + itx, ok := tx.(*corazawaf.Transaction) + if !ok { + return context.Background(), false + } + return itx.Context(), true +} diff --git a/experimental/plugins/plugintypes/transaction.go b/experimental/plugins/plugintypes/transaction.go index c520088d0..aae497daa 100644 --- a/experimental/plugins/plugintypes/transaction.go +++ b/experimental/plugins/plugintypes/transaction.go @@ -4,6 +4,8 @@ package plugintypes import ( + "context" + "github.com/corazawaf/coraza/v3/collection" "github.com/corazawaf/coraza/v3/debuglog" "github.com/corazawaf/coraza/v3/types" @@ -34,7 +36,11 @@ type TransactionState interface { // CaptureField captures a field. CaptureField(idx int, value string) + // LastPhase that was evaluated LastPhase() types.RulePhase + + // Context returns the context of the transaction. + Context() context.Context } // TransactionVariables has pointers to all the variables of the transaction diff --git a/http/middleware.go b/http/middleware.go index f8c4477ee..9c77c868c 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -16,6 +16,7 @@ import ( "github.com/corazawaf/coraza/v3" "github.com/corazawaf/coraza/v3/experimental" + "github.com/corazawaf/coraza/v3/experimental/middleware" "github.com/corazawaf/coraza/v3/types" ) @@ -113,7 +114,34 @@ func processRequest(tx types.Transaction, req *http.Request) (*types.Interruptio return tx.ProcessRequestBody() } +// Options is a set of options for the middleware +type Options struct { + // BeforeCloseTransaction is called before the transaction is closed, after the response has + // been written. It is useful to complement observability signals like metrics, traces and + // logs by providing additional context about the transaction and the rules that were matched. + BeforeCloseTransaction func(tx middleware.TransactionState) +} + +var defaultOptions = Options{ + BeforeCloseTransaction: func(middleware.TransactionState) {}, +} + +func (o *Options) loadDefaults() { + if o.BeforeCloseTransaction == nil { + o.BeforeCloseTransaction = defaultOptions.BeforeCloseTransaction + } +} + func WrapHandler(waf coraza.WAF, h http.Handler) http.Handler { + return wrapHandler(waf, h, defaultOptions) +} + +func WrapHandlerWithOptions(waf coraza.WAF, h http.Handler, opts Options) http.Handler { + opts.loadDefaults() + return wrapHandler(waf, h, opts) +} + +func wrapHandler(waf coraza.WAF, h http.Handler, opts Options) http.Handler { if waf == nil { return h } @@ -132,9 +160,13 @@ func WrapHandler(waf coraza.WAF, h http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { tx := newTX(r) + txs := tx.(middleware.TransactionState) defer func() { // We run phase 5 rules and create audit logs (if enabled) tx.ProcessLogging() + + opts.BeforeCloseTransaction(txs) + // we remove temporary files and free some memory if err := tx.Close(); err != nil { tx.DebugLogger().Error().Err(err).Msg("Failed to close the transaction") diff --git a/http/middleware_test.go b/http/middleware_test.go index d1bc2cf08..d20b4026d 100644 --- a/http/middleware_test.go +++ b/http/middleware_test.go @@ -10,6 +10,7 @@ package http import ( "bufio" "bytes" + "context" "fmt" "io" "mime/multipart" @@ -22,6 +23,7 @@ import ( "github.com/corazawaf/coraza/v3" "github.com/corazawaf/coraza/v3/debuglog" + "github.com/corazawaf/coraza/v3/experimental/middleware" "github.com/corazawaf/coraza/v3/experimental/plugins/macro" "github.com/corazawaf/coraza/v3/internal/corazawaf" "github.com/corazawaf/coraza/v3/internal/seclang" @@ -213,6 +215,7 @@ func TestChainEvaluation(t *testing.T) { } func errLogger(t *testing.T) func(rule types.MatchedRule) { + t.Helper() return func(rule types.MatchedRule) { t.Log(rule.ErrorLog()) } @@ -606,3 +609,28 @@ func TestHandlerAPI(t *testing.T) { }) } } + +type ctxKey struct{} + +func TestWrapHandlerWithOptions(t *testing.T) { + waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) + delegateHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + + ctx := context.WithValue(context.Background(), ctxKey{}, "value") + req, _ := http.NewRequestWithContext(ctx, "GET", "https://www.coraza.io/test", nil) + + wrappedHandler := WrapHandlerWithOptions(waf, delegateHandler, Options{ + BeforeCloseTransaction: func(tx middleware.TransactionState) { + ctx, ok := middleware.GetContext(tx) + if !ok { + t.Error("unexpected context") + } + + if want, have := "value", ctx.Value(ctxKey{}).(string); want != have { + t.Errorf("unexpected context value, want: %s, have: %s", want, have) + } + }, + }).(http.HandlerFunc) + + wrappedHandler(httptest.NewRecorder(), req) +} diff --git a/internal/corazawaf/transaction.go b/internal/corazawaf/transaction.go index b41efc8f4..2d0b5f886 100644 --- a/internal/corazawaf/transaction.go +++ b/internal/corazawaf/transaction.go @@ -126,6 +126,10 @@ func (tx *Transaction) ID() string { return tx.id } +func (tx *Transaction) Context() context.Context { + return tx.context +} + func (tx *Transaction) Variables() plugintypes.TransactionVariables { return &tx.variables }