Skip to content

Commit 12a07bd

Browse files
authored
feat: Support multiple certificates for the ca certificate (#270)
* Separate certificate handler code * Append certs to cert pool * Fix linting * Add unit test * Code Review comments * close the temp file
1 parent af3e481 commit 12a07bd

File tree

4 files changed

+181
-20
lines changed

4 files changed

+181
-20
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package cacertificatehandler
2+
3+
import (
4+
"crypto/x509"
5+
"fmt"
6+
"os"
7+
)
8+
9+
func GetCertificatePool(certPath string) (*x509.CertPool, error) {
10+
certBytes, err := getCertBytes(certPath)
11+
if err != nil {
12+
return nil, err
13+
}
14+
rootCertPool := x509.NewCertPool()
15+
ok := rootCertPool.AppendCertsFromPEM(certBytes)
16+
if !ok {
17+
msg := "failed to append certificate to pool"
18+
return nil, fmt.Errorf("%s :%w", msg, err)
19+
}
20+
return rootCertPool, nil
21+
}
22+
23+
func getCertBytes(certPath string) ([]byte, error) {
24+
certBytes, err := os.ReadFile(certPath)
25+
if err != nil {
26+
msg := "could not load CA certificate"
27+
return nil, fmt.Errorf("%s :%w", msg, err)
28+
}
29+
30+
return certBytes, nil
31+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
package cacertificatehandler_test
2+
3+
import (
4+
"crypto/rand"
5+
"crypto/x509"
6+
"crypto/x509/pkix"
7+
"encoding/pem"
8+
"fmt"
9+
"math/big"
10+
"net"
11+
"os"
12+
"testing"
13+
"time"
14+
15+
"github.com/stretchr/testify/require"
16+
17+
"github.com/kyma-project/runtime-watcher/skr/internal/cacertificatehandler"
18+
"github.com/kyma-project/runtime-watcher/skr/internal/tlstest"
19+
)
20+
21+
func TestGetCertificatePool1(t *testing.T) {
22+
t.Parallel()
23+
tests := []struct {
24+
name string
25+
certificateCount int
26+
certPath string
27+
}{
28+
{
29+
name: "certificate pool with one certificate",
30+
certificateCount: 1,
31+
certPath: "ca-1.cert",
32+
},
33+
{
34+
name: "certificate pool with two certificates",
35+
certificateCount: 2,
36+
certPath: "ca-2.cert",
37+
},
38+
}
39+
for _, tt := range tests {
40+
testCase := tt
41+
t.Run(testCase.name, func(t *testing.T) {
42+
t.Parallel()
43+
file, err := os.CreateTemp("", testCase.certPath)
44+
require.NoError(t, err)
45+
46+
err = writeCertificatesToFile(file, testCase.certificateCount)
47+
require.NoError(t, err)
48+
49+
got, err := cacertificatehandler.GetCertificatePool(file.Name())
50+
require.NoError(t, err)
51+
require.False(t, got.Equal(x509.NewCertPool()))
52+
53+
certificates, err := getCertificates(file.Name())
54+
require.NoError(t, err)
55+
err = os.Remove(file.Name())
56+
require.NoError(t, err)
57+
expectedCertPool := x509.NewCertPool()
58+
for _, certificate := range certificates {
59+
expectedCertPool.AddCert(certificate)
60+
}
61+
require.True(t, got.Equal(expectedCertPool))
62+
})
63+
}
64+
}
65+
66+
func getCertificates(certPath string) ([]*x509.Certificate, error) {
67+
caCertBytes, err := os.ReadFile(certPath)
68+
if err != nil {
69+
return nil, fmt.Errorf("could not load CA certificate :%w", err)
70+
}
71+
var certs []*x509.Certificate
72+
remainingCert := caCertBytes
73+
for len(remainingCert) > 0 {
74+
var publicPemBlock *pem.Block
75+
publicPemBlock, remainingCert = pem.Decode(remainingCert)
76+
rootPubCrt, errParse := x509.ParseCertificate(publicPemBlock.Bytes)
77+
if errParse != nil {
78+
msg := "failed to parse public key"
79+
return nil, fmt.Errorf("%s :%w", msg, errParse)
80+
}
81+
certs = append(certs, rootPubCrt)
82+
}
83+
84+
return certs, nil
85+
}
86+
87+
func createCertificate() *x509.Certificate {
88+
sn, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
89+
cert := &x509.Certificate{
90+
SerialNumber: sn,
91+
Subject: pkix.Name{
92+
CommonName: "127.0.0.1",
93+
},
94+
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
95+
NotBefore: time.Now(),
96+
NotAfter: time.Now().Add(time.Hour),
97+
IsCA: true,
98+
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
99+
BasicConstraintsValid: true,
100+
}
101+
102+
return cert
103+
}
104+
105+
func writeCertificatesToFile(certFile *os.File, certificateCount int) error {
106+
var certs []byte
107+
108+
for i := 0; i < certificateCount; i++ {
109+
rootKey, err := tlstest.GenerateRootKey()
110+
if err != nil {
111+
return fmt.Errorf("failed to generate root key: %w", err)
112+
}
113+
114+
certificate := createCertificate()
115+
cert, err := tlstest.CreateCert(certificate, certificate, rootKey, rootKey)
116+
if err != nil {
117+
return fmt.Errorf("failed to create certificate: %w", err)
118+
}
119+
certBytes := pem.EncodeToMemory(&pem.Block{
120+
Type: "CERTIFICATE",
121+
Bytes: cert.Certificate[0],
122+
})
123+
certs = append(certs, certBytes...)
124+
}
125+
126+
if _, err := certFile.Write(certs); err != nil {
127+
certFile.Close()
128+
return fmt.Errorf("failed to write certificates to file: %w", err)
129+
}
130+
131+
return certFile.Close()
132+
}

runtime-watcher/internal/handler.go

+4-14
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,11 @@ package internal
33
import (
44
"bytes"
55
"crypto/tls"
6-
"crypto/x509"
76
"encoding/json"
8-
"encoding/pem"
97
"errors"
108
"fmt"
119
"io"
1210
"net/http"
13-
"os"
1411
"reflect"
1512
"strings"
1613
"time"
@@ -26,6 +23,7 @@ import (
2623
"github.com/go-logr/logr"
2724
listenerTypes "github.com/kyma-project/runtime-watcher/listener/pkg/types"
2825

26+
"github.com/kyma-project/runtime-watcher/skr/internal/cacertificatehandler"
2927
"github.com/kyma-project/runtime-watcher/skr/internal/requestparser"
3028
"github.com/kyma-project/runtime-watcher/skr/internal/serverconfig"
3129
"github.com/kyma-project/runtime-watcher/skr/internal/watchermetrics"
@@ -324,19 +322,11 @@ func (h *Handler) getHTTPSClient() (*http.Client, error) {
324322
msg := "could not load tls certificate"
325323
return nil, fmt.Errorf("%s :%w", msg, err)
326324
}
327-
caCertBytes, err := os.ReadFile(h.config.CACertPath)
325+
326+
rootCertPool, err := cacertificatehandler.GetCertificatePool(h.config.CACertPath)
328327
if err != nil {
329-
msg := "could not load CA certificate"
330-
return nil, fmt.Errorf("%s :%w", msg, err)
331-
}
332-
publicPemBlock, _ := pem.Decode(caCertBytes)
333-
rootPubCrt, errParse := x509.ParseCertificate(publicPemBlock.Bytes)
334-
if errParse != nil {
335-
msg := "failed to parse public key"
336-
return nil, fmt.Errorf("%s :%w", msg, errParse)
328+
return nil, fmt.Errorf("failed to get certificate pool:%w", err)
337329
}
338-
rootCertPool := x509.NewCertPool()
339-
rootCertPool.AddCert(rootPubCrt)
340330

341331
httpsClient.Timeout = HTTPTimeout
342332
//nolint:gosec

runtime-watcher/internal/tlstest/certificate_provider.go

+14-6
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ func createCertTemplate(isCA bool) (*x509.Certificate, error) {
120120
return template, nil
121121
}
122122

123-
func createCert(template, parent *x509.Certificate, privateKey *rsa.PrivateKey, rootKey *rsa.PrivateKey) (
123+
func CreateCert(template, parent *x509.Certificate, privateKey *rsa.PrivateKey, rootKey *rsa.PrivateKey) (
124124
*tls.Certificate, error,
125125
) {
126126
certBytes, err := x509.CreateCertificate(rand.Reader, template, parent, &privateKey.PublicKey, rootKey)
@@ -141,16 +141,24 @@ func createCert(template, parent *x509.Certificate, privateKey *rsa.PrivateKey,
141141
return &cert, nil
142142
}
143143

144-
func (p *CertProvider) GenerateCerts() error {
144+
func GenerateRootKey() (*rsa.PrivateKey, error) {
145145
rootKey, err := rsa.GenerateKey(rand.Reader, privateKeyBits)
146146
if err != nil {
147-
return fmt.Errorf("%s: %w", errMsgCreatingPrivateKey, err)
147+
return nil, fmt.Errorf("%s: %w", errMsgCreatingPrivateKey, err)
148+
}
149+
return rootKey, nil
150+
}
151+
152+
func (p *CertProvider) GenerateCerts() error {
153+
rootKey, err := GenerateRootKey()
154+
if err != nil {
155+
return err
148156
}
149157
rootTemplate, err := createCertTemplate(true)
150158
if err != nil {
151159
return err
152160
}
153-
p.RootCert, err = createCert(rootTemplate, rootTemplate, rootKey, rootKey)
161+
p.RootCert, err = CreateCert(rootTemplate, rootTemplate, rootKey, rootKey)
154162
if err != nil {
155163
return err
156164
}
@@ -168,7 +176,7 @@ func (p *CertProvider) GenerateCerts() error {
168176
return err
169177
}
170178
serverTemplate.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}
171-
p.ServerCert, err = createCert(serverTemplate, rootTemplate, serverKey, rootKey)
179+
p.ServerCert, err = CreateCert(serverTemplate, rootTemplate, serverKey, rootKey)
172180
if err != nil {
173181
return err
174182
}
@@ -182,7 +190,7 @@ func (p *CertProvider) GenerateCerts() error {
182190
return err
183191
}
184192
clientTemplate.ExtKeyUsage = []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}
185-
clientCert, err := createCert(clientTemplate, rootTemplate, clientKey, rootKey)
193+
clientCert, err := CreateCert(clientTemplate, rootTemplate, clientKey, rootKey)
186194
if err != nil {
187195
return err
188196
}

0 commit comments

Comments
 (0)