Skip to content

Commit f4d53fb

Browse files
balamanova and abalamanova authored
Rotate cert conf refresh fix (#90)
Signed-off-by: abalamanova <assem.balamanova@yahooinc.com> Co-authored-by: abalamanova <assem.balamanova@yahooinc.com>
1 parent 286fce4 commit f4d53fb

File tree

5 files changed

+175
-59
lines changed

5 files changed

+175
-59
lines changed

deploy/example/example-app.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,3 @@ spec:
7676
volumeAttributes:
7777
csi.cert-manager.athenz.io/pod-subdomain: "my-subdomain"
7878
csi.cert-manager.athenz.io/pod-hostname: "my-hostname"
79-
csi.cert-manager.athenz.io/refresh-interval: "1h"
80-

internal/csi/driver/driver.go

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -464,27 +464,14 @@ func (d *Driver) writeKeypair(meta metadata.Metadata, key crypto.PrivateKey, cha
464464

465465
// Calculate the next issuance time before we write any data to file,
466466
// so if the write fails, we are not left in a bad state.
467-
var nextIssuanceTime time.Time
468-
469-
// Check if a custom refresh interval is specified in the volume context
467+
// Parse refresh interval from volume context, defaults to 24h if not specified or invalid
470468
refreshIntervalStr := meta.VolumeContext[attrRefreshInterval]
471-
if refreshIntervalStr != "" {
472-
// Use the custom refresh interval from volume context
473-
refreshInterval, err := parseRefreshInterval(refreshIntervalStr, defaultRefreshInterval)
474-
if err != nil {
475-
return fmt.Errorf("failed to parse refresh interval: %w", err)
476-
}
477-
nextIssuanceTime = calculateNextIssuanceTimeWithRefreshInterval(refreshInterval)
478-
d.log.Info("using custom refresh interval", "refreshInterval", refreshInterval.String(), "nextIssuanceTime", nextIssuanceTime.Format(time.RFC3339))
479-
} else {
480-
// Fall back to certificate-based calculation (2/3 of validity period)
481-
var err error
482-
nextIssuanceTime, err = calculateNextIssuanceTime(chain)
483-
if err != nil {
484-
return fmt.Errorf("failed to calculate next issuance time: %w", err)
485-
}
486-
d.log.Info("using certificate-based refresh interval", "nextIssuanceTime", nextIssuanceTime.Format(time.RFC3339))
469+
refreshInterval, err := parseRefreshInterval(refreshIntervalStr, defaultRefreshInterval)
470+
if err != nil {
471+
d.log.Error(err, "invalid refresh interval, using default", "default", defaultRefreshInterval.String())
487472
}
473+
nextIssuanceTime := calculateNextIssuanceTimeWithRefreshInterval(refreshInterval)
474+
d.log.Info("using refresh interval", "refreshInterval", refreshInterval.String(), "nextIssuanceTime", nextIssuanceTime.Format(time.RFC3339))
488475

489476
data := map[string][]byte{
490477
d.certFileName: chain,

internal/csi/driver/util.go

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -57,24 +57,6 @@ func signRequest(_ metadata.Metadata, key crypto.PrivateKey, request *x509.Certi
5757
}), nil
5858
}
5959

60-
// calculateNextIssuanceTime returns the time when the certificate should be
61-
// renewed. This will be 2/3rds the duration of the leaf certificate's validity period.
62-
func calculateNextIssuanceTime(chain []byte) (time.Time, error) {
63-
block, _ := pem.Decode(chain)
64-
65-
crt, err := x509.ParseCertificate(block.Bytes)
66-
if err != nil {
67-
return time.Time{}, fmt.Errorf("parsing issued certificate: %w", err)
68-
}
69-
70-
// Renew once a certificate is 2/3rds of the way through its actual lifetime.
71-
actualDuration := crt.NotAfter.Sub(crt.NotBefore)
72-
73-
renewBeforeNotAfter := actualDuration / 3
74-
75-
return crt.NotAfter.Add(-renewBeforeNotAfter), nil
76-
}
77-
7860
// extract domain and service from the service account name
7961
// e.g. athenz.prod.api -> domain: athenz.prod, service: api
8062
func extractDomainService(saName string) (string, string) {
@@ -127,23 +109,24 @@ func getDomainFromNamespaceAnnotations(annotations map[string]string) string {
127109
return ""
128110
}
129111

130-
// parseRefreshInterval parses a refresh interval string in hours (e.g., "24h", "12h", "1h")
131-
// and returns the duration. If the string is empty, it returns the default refresh interval.
132-
// If the string is invalid or less than 1 hour, it returns an error.
112+
// parseRefreshInterval parses a refresh interval string (e.g., "60m", "120m", "1h", "2h")
113+
// and returns the duration. If the string is empty, it returns the default refresh interval together with a non-nil error noting that the default is being used.
114+
// If invalid, it returns the default interval with an error describing the issue.
115+
// Minimum allowed interval is 60 minutes.
133116
func parseRefreshInterval(intervalStr string, defaultInterval time.Duration) (time.Duration, error) {
134117
if intervalStr == "" {
135-
return defaultInterval, nil
118+
return defaultInterval, fmt.Errorf("no refresh interval specified, using default %s", defaultInterval.String())
136119
}
137120

138-
// Parse the hours value (e.g., "24h" -> 24 hours)
121+
// Parse the duration value (e.g., "60m" -> 60 minutes, "1h" -> 1 hour)
139122
duration, err := time.ParseDuration(intervalStr)
140123
if err != nil {
141-
return 0, fmt.Errorf("invalid refresh interval %q: %w", intervalStr, err)
124+
return defaultInterval, fmt.Errorf("failed to parse refresh interval %q: %w", intervalStr, err)
142125
}
143126

144-
// Ensure the refresh interval is at least 1 hour
145-
if duration < time.Hour {
146-
return 0, fmt.Errorf("refresh interval %q must be at least 1 hour", intervalStr)
127+
// Ensure the refresh interval is at least 60 minutes
128+
if duration < 60*time.Minute {
129+
return defaultInterval, fmt.Errorf("refresh interval %q is less than minimum 60 minutes", intervalStr)
147130
}
148131

149132
return duration, nil

internal/csi/driver/util_test.go

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -198,32 +198,54 @@ func Test_parseRefreshInterval(t *testing.T) {
198198
expectError bool
199199
}{
200200
{
201-
name: "valid 1h interval (minimum)",
202-
intervalStr: "1h",
201+
name: "valid 60m interval (minimum)",
202+
intervalStr: "60m",
203203
defaultInterval: defaultInterval,
204-
expected: 1 * time.Hour,
204+
expected: 60 * time.Minute,
205+
expectError: false,
206+
},
207+
{
208+
name: "valid 120m interval",
209+
intervalStr: "120m",
210+
defaultInterval: defaultInterval,
211+
expected: 120 * time.Minute,
205212
expectError: false,
206213
},
207214
{
208-
name: "valid 72h interval",
209-
intervalStr: "72h",
215+
name: "valid 1440m interval (24h)",
216+
intervalStr: "1440m",
210217
defaultInterval: defaultInterval,
211-
expected: 72 * time.Hour,
218+
expected: 1440 * time.Minute,
212219
expectError: false,
213220
},
214221
{
215-
name: "empty interval returns default",
222+
name: "empty interval returns default with info message",
216223
intervalStr: "",
217224
defaultInterval: defaultInterval,
218225
expected: defaultInterval,
219-
expectError: false,
226+
expectError: true,
220227
},
221228
{
222-
name: "invalid string returns error",
223-
intervalStr: "interval",
229+
name: "invalid string returns default with error",
230+
intervalStr: "invalid",
224231
defaultInterval: defaultInterval,
232+
expected: defaultInterval,
225233
expectError: true,
226234
},
235+
{
236+
name: "interval less than 60m returns default with error",
237+
intervalStr: "30m",
238+
defaultInterval: defaultInterval,
239+
expected: defaultInterval,
240+
expectError: true,
241+
},
242+
{
243+
name: "valid 1h interval (hours format also accepted)",
244+
intervalStr: "1h",
245+
defaultInterval: defaultInterval,
246+
expected: 1 * time.Hour,
247+
expectError: false,
248+
},
227249
}
228250

229251
for _, tt := range tests {
@@ -233,8 +255,8 @@ func Test_parseRefreshInterval(t *testing.T) {
233255
assert.Error(t, err)
234256
} else {
235257
assert.NoError(t, err)
236-
assert.Equal(t, tt.expected, result)
237258
}
259+
assert.Equal(t, tt.expected, result)
238260
})
239261
}
240262
}

test/e2e/suite/refreshinterval/refreshinterval.go

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package refreshinterval
1919
import (
2020
"bytes"
2121
"os/exec"
22+
"strings"
2223

2324
corev1 "k8s.io/api/core/v1"
2425
rbacv1 "k8s.io/api/rbac/v1"
@@ -32,6 +33,8 @@ import (
3233
. "github.com/onsi/gomega"
3334
)
3435

36+
const csiDriverNamespace = "cert-manager"
37+
3538
var _ = framework.CasesDescribe("RefreshInterval", func() {
3639
f := framework.NewDefaultFramework("RefreshInterval")
3740

@@ -136,6 +139,129 @@ var _ = framework.CasesDescribe("RefreshInterval", func() {
136139
Expect(cmd.Run()).To(Succeed())
137140
Expect(buf.Len()).To(BeNumerically(">", 0), "expected certificate file to not be empty")
138141

142+
By("Verifying CSI driver logs show 1h refresh interval")
143+
logBuf := new(bytes.Buffer)
144+
logCmd := exec.Command(f.Config().KubectlBinPath, "logs", "-n"+csiDriverNamespace, "-l", "app=csi-driver-athenz", "-c", "csi-driver-athenz", "--tail=100")
145+
logCmd.Stdout = logBuf
146+
logCmd.Stderr = GinkgoWriter
147+
Expect(logCmd.Run()).To(Succeed())
148+
Expect(strings.Contains(logBuf.String(), "refreshInterval\"=\"1h0m0s\"")).To(BeTrue(), "expected logs to show 1h refresh interval")
149+
150+
By("Cleaning up resources")
151+
Expect(f.Client().Delete(f.Context(), &pod)).NotTo(HaveOccurred())
152+
Expect(f.Client().Delete(f.Context(), &rolebinding)).NotTo(HaveOccurred())
153+
Expect(f.Client().Delete(f.Context(), &role)).NotTo(HaveOccurred())
154+
Expect(f.Client().Delete(f.Context(), &serviceAccount)).NotTo(HaveOccurred())
155+
})
156+
157+
It("should issue certificate with default refresh interval (24h)", func() {
158+
By("Creating service account, role, and rolebinding")
159+
160+
serviceAccount := corev1.ServiceAccount{
161+
ObjectMeta: metav1.ObjectMeta{
162+
Name: "athenz.default-refresh-test",
163+
Namespace: f.Namespace.Name,
164+
},
165+
}
166+
Expect(f.Client().Create(f.Context(), &serviceAccount)).NotTo(HaveOccurred())
167+
168+
role := rbacv1.Role{
169+
ObjectMeta: metav1.ObjectMeta{
170+
Name: "default-refresh-test",
171+
Namespace: f.Namespace.Name,
172+
},
173+
Rules: []rbacv1.PolicyRule{{
174+
Verbs: []string{"create"},
175+
APIGroups: []string{"cert-manager.io"},
176+
Resources: []string{"certificaterequests"},
177+
}},
178+
}
179+
Expect(f.Client().Create(f.Context(), &role)).NotTo(HaveOccurred())
180+
181+
rolebinding := rbacv1.RoleBinding{
182+
ObjectMeta: metav1.ObjectMeta{
183+
Name: "default-refresh-test",
184+
Namespace: f.Namespace.Name,
185+
},
186+
RoleRef: rbacv1.RoleRef{
187+
APIGroup: "rbac.authorization.k8s.io",
188+
Kind: "Role",
189+
Name: role.Name,
190+
},
191+
Subjects: []rbacv1.Subject{{
192+
Kind: "ServiceAccount",
193+
Name: serviceAccount.Name,
194+
Namespace: f.Namespace.Name,
195+
}},
196+
}
197+
Expect(f.Client().Create(f.Context(), &rolebinding)).NotTo(HaveOccurred())
198+
199+
By("Creating pod without refresh-interval (should use default 24h)")
200+
pod := corev1.Pod{
201+
ObjectMeta: metav1.ObjectMeta{
202+
Name: "default-refresh-interval-test",
203+
Namespace: f.Namespace.Name,
204+
},
205+
Spec: corev1.PodSpec{
206+
Volumes: []corev1.Volume{{
207+
Name: "csi-driver-athenz",
208+
VolumeSource: corev1.VolumeSource{
209+
CSI: &corev1.CSIVolumeSource{
210+
Driver: "csi.cert-manager.athenz.io",
211+
ReadOnly: pointer.Bool(true),
212+
// No refresh-interval specified - should use default 24h
213+
VolumeAttributes: map[string]string{},
214+
},
215+
},
216+
}},
217+
ServiceAccountName: "athenz.default-refresh-test",
218+
Containers: []corev1.Container{
219+
{
220+
Name: "my-container",
221+
Image: "busybox",
222+
Command: []string{"sleep", "10000"},
223+
VolumeMounts: []corev1.VolumeMount{
224+
{
225+
Name: "csi-driver-athenz",
226+
MountPath: "/var/run/secrets/athenz.io",
227+
},
228+
},
229+
},
230+
},
231+
},
232+
}
233+
Expect(f.Client().Create(f.Context(), &pod)).NotTo(HaveOccurred())
234+
235+
By("Waiting for pod to become ready")
236+
Eventually(func() bool {
237+
var p corev1.Pod
238+
Expect(f.Client().Get(f.Context(), client.ObjectKey{Namespace: f.Namespace.Name, Name: pod.Name}, &p)).NotTo(HaveOccurred())
239+
240+
for _, c := range p.Status.Conditions {
241+
if c.Type == corev1.PodReady {
242+
return c.Status == corev1.ConditionTrue
243+
}
244+
}
245+
246+
return false
247+
}, "60s", "1s").Should(BeTrue(), "expected pod to become ready in time")
248+
249+
By("Verifying certificate was issued")
250+
buf := new(bytes.Buffer)
251+
cmd := exec.Command(f.Config().KubectlBinPath, "exec", "-n"+f.Namespace.Name, pod.Name, "-cmy-container", "--", "cat", "/var/run/secrets/athenz.io/tls.crt")
252+
cmd.Stdout = buf
253+
cmd.Stderr = GinkgoWriter
254+
Expect(cmd.Run()).To(Succeed())
255+
Expect(buf.Len()).To(BeNumerically(">", 0), "expected certificate file to not be empty")
256+
257+
By("Verifying CSI driver logs show 24h refresh interval (default)")
258+
logBuf := new(bytes.Buffer)
259+
logCmd := exec.Command(f.Config().KubectlBinPath, "logs", "-n"+csiDriverNamespace, "-l", "app=csi-driver-athenz", "-c", "csi-driver-athenz", "--tail=100")
260+
logCmd.Stdout = logBuf
261+
logCmd.Stderr = GinkgoWriter
262+
Expect(logCmd.Run()).To(Succeed())
263+
Expect(strings.Contains(logBuf.String(), "refreshInterval\"=\"24h0m0s\"")).To(BeTrue(), "expected logs to show 24h refresh interval (default)")
264+
139265
By("Cleaning up resources")
140266
Expect(f.Client().Delete(f.Context(), &pod)).NotTo(HaveOccurred())
141267
Expect(f.Client().Delete(f.Context(), &rolebinding)).NotTo(HaveOccurred())

0 commit comments

Comments
 (0)