diff --git a/Makefile b/Makefile index f0e285ab..15d05f94 100644 --- a/Makefile +++ b/Makefile @@ -127,5 +127,5 @@ snapshot: docker run --rm --privileged -v $(PWD):/go/tmp \ -v /var/run/docker.sock:/var/run/docker.sock \ -w /go/tmp \ - ghcr.io/goreleaser/goreleaser-cross:latest --clean --snapshot --skip-publish + ghcr.io/goreleaser/goreleaser-cross:latest --clean --snapshot --skip publish diff --git a/protocol/protocol.go b/protocol/protocol.go index c489e931..9fbe6dbe 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -414,19 +414,21 @@ type Operation struct { ServerIP net.IP SNI string CertID string + ForwardingSvc int64 CustomFuncName string JaegerSpan []byte ReqContext []byte } func (o *Operation) String() string { - return fmt.Sprintf("[Opcode: %v, SKI: %v, Digest: %02x, Client IP: %s, Server IP: %s, SNI: %s]", + return fmt.Sprintf("[Opcode: %v, SKI: %v, Digest: %02x, Client IP: %s, Server IP: %s, SNI: %s, Forwarding Service: %v]", o.Opcode, o.SKI, o.Digest, o.ClientIP, o.ServerIP, o.SNI, + o.ForwardingSvc, ) } diff --git a/server/server.go b/server/server.go index 6286bb59..3506cbd3 100644 --- a/server/server.go +++ b/server/server.go @@ -16,6 +16,7 @@ import ( "net" "net/rpc" "os" + "strings" "sync" "time" @@ -152,6 +153,12 @@ type Sealer interface { Unseal(*protocol.Operation) ([]byte, error) } +// ClientInfo has information on the client of the connection +type ClientInfo struct { + Name string + CertSerial string +} + // handler is associated with a connection and contains bookkeeping // information used across goroutines. The channel tokens limits the // concurrency: before reading a request a token is extracted, when @@ -166,6 +173,7 @@ type handler struct { conn net.Conn timeout time.Duration closed bool + c *ClientInfo } func (h *handler) close(err error) { @@ -197,6 +205,12 @@ func (h *handler) handle(pkt *protocol.Packet, reqTime time.Time) { } else { resp = h.s.unlimitedDo(pkt, h.name) } + + if resp.op.ErrorVal() != protocol.ErrNone { + // log the client certificate information on the connection if the request failed so the caller is apparent + reqID, _ := getOperationRequestID(&pkt.Operation) + log.Errorf("operation from client %s client cert serial: %s errored. sni %s ski %s cert %s request-id %s", h.c.Name, h.c.CertSerial, resp.op.SNI, resp.op.SKI.String(), resp.op.CertID, reqID) + } logRequestExecDuration(pkt.Operation.Opcode, start, resp.op.ErrorVal()) respPkt := protocol.Packet{ Header: protocol.Header{ @@ -289,32 +303,61 @@ func makeErrResponse(pkt *protocol.Packet, err protocol.Error) response { func addOperationRequestID(op *protocol.Operation) string { reqContext := make(map[string]interface{}) var reqID string - var gen bool if len(op.ReqContext) > 0 { - if err := json.Unmarshal(op.ReqContext, &reqContext); err == nil { - if v, ok := reqContext["request_id"]; ok { - return v.(string) - } else { - gen = true - } - } else { - log.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext) + if decodeErr := json.Unmarshal(op.ReqContext, &reqContext); decodeErr != nil { + log.Error(fmt.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext)) + return reqID + } + } + + if v, ok := reqContext["request_id"]; ok { + return v.(string) + } + + reqID = uuid.New().String() + reqContext["request_id"] = reqID + b, err := json.Marshal(reqContext) + if err == nil { + op.ReqContext = b + } else { + log.Errorf("error marshaling operation.ReqContext %v, ignoring error", reqContext) + reqID = "" + } + return reqID +} + +func getOperationRequestID(op *protocol.Operation) (reqID string, err error) { + reqContext := make(map[string]interface{}) + if len(op.ReqContext) == 0 { + return + } + if decodeErr := json.Unmarshal(op.ReqContext, &reqContext); decodeErr == nil { + if v, ok := reqContext["request_id"]; ok { + return v.(string), nil } + } else { + err = fmt.Errorf("malformed operation.ReqContext %v, ignoring error", op.ReqContext) + log.Error(err) + return } + return +} - if len(op.ReqContext) == 0 || gen { - reqID = uuid.New().String() - reqContext["request_id"] = reqID - b, err := json.Marshal(reqContext) - if err == nil { - op.ReqContext = b +func getClientInfoFromCerts(certs []*x509.Certificate) *ClientInfo { + cln := []string(nil) + srls := []string(nil) + for _, cert := range certs { + if cert.Subject.CommonName != "" { + cln = append(cln, cert.Subject.CommonName) } else { - log.Errorf("error marshaling operation.ReqContext %v, ignoring error", reqContext) - reqID = "" + cln = append(cln, cert.DNSNames...) } + srls = append(srls, cert.SerialNumber.String()) } - return reqID + name := strings.Join(cln, " , ") + serial := strings.Join(srls, " , ") + return &ClientInfo{Name: name, CertSerial: serial} } func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response { @@ -328,7 +371,7 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response { reqID := addOperationRequestID(&pkt.Operation) span.SetTag("request_id", reqID) - log.Debugf("connection %s: limited=false opcode=%s id=%d sni=%s ip=%s ski=%v request-id=%s", + log.Debugf("connection %s: limited=false opcode= %s id=%d sni= %s ip= %s ski= %v request-id= %s", connName, pkt.Operation.Opcode, pkt.Header.ID, @@ -412,14 +455,14 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response { sig, err := key.Sign(rand.Reader, pkt.Operation.Payload, crypto.Hash(0)) if err != nil { - log.Errorf("Connection: %s: sni=%s ski=%v request-id=%s: Signing error: %v: request-id:%s:", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err, reqID) + log.Errorf("Connection: %s: sni= %s ski= %v request-id= %s: Signing error: %v", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err) // This indicates that a remote keyserver is being used var remoteConfigurationErr RemoteConfigurationErr if errors.As(err, &remoteConfigurationErr) { - log.Errorf("Connection %v: sni=%s ski=%v request-id=%s: %s: Signing error: %v request-id:%s\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err, reqID) + log.Errorf("Connection %v: sni= %s ski= %v request-id= %s: %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err) return makeErrResponse(pkt, protocol.ErrRemoteConfiguration) } else { - log.Errorf("Connection %v: sni=%s ski=%v request-id=%s: %s: Signing error: %v request-id:%s\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err, reqID) + log.Errorf("Connection %v: sni= %s ski= %v request-id= %s: %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err) return makeErrResponse(pkt, protocol.ErrCrypto) } } @@ -430,15 +473,15 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response { key, err := s.keys.Get(ctx, &pkt.Operation) logKeyLoadDuration(loadStart) if err != nil { - log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err) + log.Errorf("failed to load key with sni= %s ip= %s ski=%v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err) return makeErrResponse(pkt, protocol.ErrInternal) } else if key == nil { - log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound) + log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound) return makeErrResponse(pkt, protocol.ErrKeyNotFound) } if _, ok := key.Public().(*rsa.PublicKey); !ok { - log.Errorf("Connection %v: sni=%s request-id=%s: %s: Key is not RSA", connName, pkt.Operation.SNI, reqID, protocol.ErrCrypto) + log.Errorf("Connection %v: sni= %s request-id= %s: %s: Key is not RSA", connName, pkt.Operation.SNI, reqID, protocol.ErrCrypto) return makeErrResponse(pkt, protocol.ErrCrypto) } @@ -446,7 +489,7 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response { // Decrypt without removing padding; that's the client's responsibility. ptxt, err := textbook_rsa.Decrypt(rsaKey, pkt.Operation.Payload) if err != nil { - log.Errorf("connection %v: sni=%s ip=%s ski=%v request-id=%s: %v", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err) + log.Errorf("connection %v: sni= %s ip= %s ski= %v request-id= %s: %v", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err) return makeErrResponse(pkt, protocol.ErrCrypto) } return makeRespondResponse(pkt, ptxt) @@ -493,10 +536,10 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response { key, err := s.keys.Get(ctx, &pkt.Operation) logKeyLoadDuration(loadStart) if err != nil { - log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err) + log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err) return makeErrResponse(pkt, protocol.ErrInternal) } else if key == nil { - log.Errorf("failed to load key with sni=%s ip=%s ski=%v request-id=%s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound) + log.Errorf("failed to load key with sni= %s ip= %s ski= %v request-id= %s: %v", pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrKeyNotFound) return makeErrResponse(pkt, protocol.ErrKeyNotFound) } @@ -526,17 +569,17 @@ func (s *Server) unlimitedDo(pkt *protocol.Packet, connName string) response { } if err != nil { if attempts > 1 { - log.Debugf("Connection %v sni=%s ip=%s ski=%v request-id=%s : failed sign attempt: %s, %d attempt(s) left", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err, attempts-1) + log.Debugf("Connection %v sni= %s ip= %s ski= %v request-id= %s : failed sign attempt: %s, %d attempt(s) left", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, err, attempts-1) continue } else { tracing.LogError(span, err) // This indicates that a remote keyserver is being used var remoteConfigurationErr RemoteConfigurationErr if errors.As(err, &remoteConfigurationErr) { - log.Errorf("Connection %v sni=%s ip=%s ski=%v request-id=%s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err) + log.Errorf("Connection %v sni= %s ip= %s ski= %v request-id= %s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrRemoteConfiguration, err) return makeErrResponse(pkt, protocol.ErrRemoteConfiguration) } else { - log.Errorf("Connection %v sni=%s ip=%s ski=%v request-id=%s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err) + log.Errorf("Connection %v sni= %s ip= %s ski= %v request-id= %s : %s: Signing error: %v\n", connName, pkt.Operation.SNI, pkt.Operation.ServerIP, pkt.Operation.SKI, reqID, protocol.ErrCrypto, err) return makeErrResponse(pkt, protocol.ErrCrypto) } } @@ -656,6 +699,7 @@ func (s *Server) spawn(l net.Listener, c net.Conn) { } connState := tconn.ConnectionState() certmetrics.Observe(certmetrics.CertSourceFromCerts(fmt.Sprintf("listener: %s", l.Addr().String()), connState.PeerCertificates)...) + cl := getClientInfoFromCerts(connState.PeerCertificates) limited, err := s.config.isLimited(connState) if err != nil { log.Errorf("connection %v: could not determine if limited: %v", c.RemoteAddr(), err) @@ -692,6 +736,7 @@ func (s *Server) spawn(l net.Listener, c net.Conn) { conn: tconn, listener: l, timeout: timeout, + c: cl, } err = handler.loop() diff --git a/tracing/tracing.go b/tracing/tracing.go index c5c4007c..15d76c94 100644 --- a/tracing/tracing.go +++ b/tracing/tracing.go @@ -47,6 +47,7 @@ func SetOperationSpanTags(span opentracing.Span, op *protocol.Operation) { "operation.sni": op.SNI, "operation.certid": op.CertID, "operation.customfuncname": op.CustomFuncName, + "operation.forwardingsvc": fmt.Sprintf("%d", op.ForwardingSvc), } for k, v := range tags { span.SetTag(k, v)