Skip to content

Commit 6c01400

Browse files
committed
feat: Support key versionless
Retrieve latest key version from akv and put key version into annotation for decryption. Signed-off-by: Zhecheng Li <[email protected]>
1 parent 2b68d2f commit 6c01400

File tree

4 files changed

+158
-73
lines changed

4 files changed

+158
-73
lines changed

cmd/server/main.go

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,14 @@ import (
3131
)
3232

3333
var (
34-
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
35-
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
36-
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
37-
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
38-
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
39-
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
40-
logLevel = flag.Int("v", 0, "In order of increasing verbosity: 0=warning/error, 2=info, 4=debug, 6=trace, 10=all")
34+
listenAddr = flag.String("listen-addr", "unix:///opt/azurekms.socket", "gRPC listen address")
35+
keyvaultName = flag.String("keyvault-name", "", "Azure Key Vault name")
36+
keyName = flag.String("key-name", "", "Azure Key Vault KMS key name")
37+
keyVersion = flag.String("key-version", "", "Azure Key Vault KMS key version")
38+
keyVersionlessEnabled = flag.Bool("key-versionless-enabled", false, "Azure Key Vault KMS key versionless enabled")
39+
managedHSM = flag.Bool("managed-hsm", false, "Azure Key Vault Managed HSM. Refer to https://docs.microsoft.com/en-us/azure/key-vault/managed-hsm/overview for more details.")
40+
logFormatJSON = flag.Bool("log-format-json", false, "set log formatter to json")
41+
logLevel = flag.Int("v", 0, "In order of increasing verbosity: 0=warning/error, 2=info, 4=debug, 6=trace, 10=all")
4142
// TODO remove this flag in future release.
4243
_ = flag.String("configFilePath", "/etc/kubernetes/azure.json", "[DEPRECATED] Path for Azure Cloud Provider config file")
4344
configFilePath = flag.String("config-file-path", "/etc/kubernetes/azure.json", "Path for Azure Cloud Provider config file")
@@ -90,14 +91,15 @@ func setupKMSPlugin() error {
9091
mlog.Always("Starting KeyManagementServiceServer service", "version", version.BuildVersion, "buildDate", version.BuildDate)
9192

9293
pluginConfig := &plugin.Config{
93-
KeyVaultName: *keyvaultName,
94-
KeyName: *keyName,
95-
KeyVersion: *keyVersion,
96-
ManagedHSM: *managedHSM,
97-
ProxyMode: *proxyMode,
98-
ProxyAddress: *proxyAddress,
99-
ProxyPort: *proxyPort,
100-
ConfigFilePath: *configFilePath,
94+
KeyVaultName: *keyvaultName,
95+
KeyName: *keyName,
96+
KeyVersion: *keyVersion,
97+
KeyVersionlessEnabled: *keyVersionlessEnabled,
98+
ManagedHSM: *managedHSM,
99+
ProxyMode: *proxyMode,
100+
ProxyAddress: *proxyAddress,
101+
ProxyPort: *proxyPort,
102+
ConfigFilePath: *configFilePath,
101103
}
102104

103105
azureConfig, err := config.GetAzureConfig(pluginConfig.ConfigFilePath)
@@ -110,6 +112,7 @@ func setupKMSPlugin() error {
110112
pluginConfig.KeyVaultName,
111113
pluginConfig.KeyName,
112114
pluginConfig.KeyVersion,
115+
pluginConfig.KeyVersionlessEnabled,
113116
pluginConfig.ProxyMode,
114117
pluginConfig.ProxyAddress,
115118
pluginConfig.ProxyPort,

pkg/plugin/keyvault.go

Lines changed: 97 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"path"
1515
"regexp"
1616
"strings"
17+
"time"
1718

1819
"github.com/Azure/kubernetes-kms/pkg/auth"
1920
"github.com/Azure/kubernetes-kms/pkg/config"
@@ -38,6 +39,8 @@ const (
3839
keyvaultRegionAnnotationKey = "x-ms-keyvault-region.azure.akv.io"
3940
versionAnnotationKey = "version.azure.akv.io"
4041
algorithmAnnotationKey = "algorithm.azure.akv.io"
42+
keyVersionAnnotationKey = "keyversion.azure.akv.io"
43+
keyIDHashAnnotationKey = "keyidhash.azure.akv.io"
4144
dateAnnotationValue = "Date"
4245
requestIDAnnotationValue = "X-Ms-Request-Id"
4346
keyvaultRegionAnnotationValue = "X-Ms-Keyvault-Region"
@@ -64,20 +67,22 @@ type Client interface {
6467

6568
// KeyVaultClient is a client for interacting with Keyvault.
6669
type KeyVaultClient struct {
67-
baseClient kv.BaseClient
68-
config *config.AzureConfig
69-
vaultName string
70-
keyName string
71-
keyVersion string
72-
vaultURL string
73-
keyIDHash string
74-
azureEnvironment *azure.Environment
70+
baseClient kv.BaseClient
71+
config *config.AzureConfig
72+
vaultName string
73+
keyName string
74+
keyVersion string
75+
keyVersionlessEnabled bool
76+
vaultURL string
77+
keyIDHash string
78+
azureEnvironment *azure.Environment
7579
}
7680

7781
// NewKeyVaultClient returns a new key vault client to use for kms operations.
7882
func NewKeyVaultClient(
7983
config *config.AzureConfig,
8084
vaultName, keyName, keyVersion string,
85+
keyVersionlessEnabled bool,
8186
proxyMode bool,
8287
proxyAddress string,
8388
proxyPort int,
@@ -90,9 +95,10 @@ func NewKeyVaultClient(
9095

9196
// this should be the case for bring your own key, clusters bootstrapped with
9297
// aks-engine or aks and standalone kms plugin deployments
93-
if len(vaultName) == 0 || len(keyName) == 0 || len(keyVersion) == 0 {
94-
return nil, fmt.Errorf("key vault name, key name and key version are required")
98+
if len(vaultName) == 0 || len(keyName) == 0 || (!keyVersionlessEnabled && len(keyVersion) == 0) {
99+
return nil, fmt.Errorf("key vault name, key name and key version (not key versionless enabled) are required")
95100
}
101+
96102
kvClient := kv.New()
97103
err := kvClient.AddToUserAgent(version.GetUserAgent())
98104
if err != nil {
@@ -121,9 +127,12 @@ func NewKeyVaultClient(
121127
return nil, fmt.Errorf("failed to get vault url, error: %+v", err)
122128
}
123129

124-
keyIDHash, err := getKeyIDHash(*vaultURL, keyName, keyVersion)
125-
if err != nil {
126-
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
130+
var keyIDHash string
131+
if len(keyVersion) != 0 {
132+
keyIDHash, err = getKeyIDHash(*vaultURL, keyName, keyVersion)
133+
if err != nil {
134+
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
135+
}
127136
}
128137

129138
if proxyMode {
@@ -134,18 +143,51 @@ func NewKeyVaultClient(
134143
mlog.Always("using kms key for encrypt/decrypt", "vaultURL", *vaultURL, "keyName", keyName, "keyVersion", keyVersion)
135144

136145
client := &KeyVaultClient{
137-
baseClient: kvClient,
138-
config: config,
139-
vaultName: vaultName,
140-
keyName: keyName,
141-
keyVersion: keyVersion,
142-
vaultURL: *vaultURL,
143-
azureEnvironment: env,
144-
keyIDHash: keyIDHash,
146+
baseClient: kvClient,
147+
config: config,
148+
vaultName: vaultName,
149+
keyName: keyName,
150+
keyVersion: keyVersion,
151+
keyVersionlessEnabled: keyVersionlessEnabled,
152+
vaultURL: *vaultURL,
153+
azureEnvironment: env,
154+
keyIDHash: keyIDHash,
145155
}
146156
return client, nil
147157
}
148158

159+
func (kvc *KeyVaultClient) GetLatestKeyVersion(ctx context.Context) (string, error) {
160+
keyVersionResultPage, err := kvc.baseClient.GetKeyVersions(ctx, kvc.vaultURL, kvc.keyName, nil)
161+
if err != nil {
162+
return "", fmt.Errorf("failed to get key versions, error: %+v", err)
163+
}
164+
var latestKeyVersionItem kv.KeyItem
165+
for keyVersionResultPage.NotDone() {
166+
for _, value := range keyVersionResultPage.Values() {
167+
if latestKeyVersionItem.Kid == nil {
168+
latestKeyVersionItem = value
169+
} else {
170+
updatedTimeCurrent := time.Time(*value.Attributes.Updated)
171+
updatedTimeLatest := time.Time(*latestKeyVersionItem.Attributes.Updated)
172+
if updatedTimeCurrent.After(updatedTimeLatest) {
173+
latestKeyVersionItem = value
174+
}
175+
}
176+
}
177+
keyVersionResultPage.Next()
178+
}
179+
180+
if latestKeyVersionItem.Kid == nil {
181+
return "", fmt.Errorf("failed to get latest key version, key id is nil")
182+
}
183+
kidSplitted := strings.Split(*latestKeyVersionItem.Kid, "/")
184+
if len(kidSplitted) == 0 {
185+
return "", fmt.Errorf("failed to get latest key version, key id is invalid %q", *latestKeyVersionItem.Kid)
186+
}
187+
latestKeyVersion := kidSplitted[len(kidSplitted)-1]
188+
return latestKeyVersion, nil
189+
}
190+
149191
// Encrypt encrypts the given plain text using the keyvault key.
150192
func (kvc *KeyVaultClient) Encrypt(
151193
ctx context.Context,
@@ -158,15 +200,29 @@ func (kvc *KeyVaultClient) Encrypt(
158200
Algorithm: encryptionAlgorithm,
159201
Value: &value,
160202
}
161-
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)
203+
204+
keyVersion := kvc.keyVersion
205+
keyIDHash := kvc.keyIDHash
206+
if kvc.keyVersionlessEnabled {
207+
var err error
208+
if keyVersion, err = kvc.GetLatestKeyVersion(ctx); err != nil {
209+
return nil, fmt.Errorf("failed to get latest key version, error: %+v", err)
210+
}
211+
212+
if keyIDHash, err = getKeyIDHash(kvc.vaultURL, kvc.keyName, keyVersion); err != nil {
213+
return nil, fmt.Errorf("failed to get key id hash, error: %w", err)
214+
}
215+
}
216+
217+
result, err := kvc.baseClient.Encrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
162218
if err != nil {
163219
return nil, fmt.Errorf("failed to encrypt, error: %+v", err)
164220
}
165221

166-
if kvc.keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
222+
if keyIDHash != fmt.Sprintf("%x", sha256.Sum256([]byte(*result.Kid))) {
167223
return nil, fmt.Errorf(
168224
"key id initialized does not match with the key id from encryption result, expected: %s, got: %s",
169-
kvc.keyIDHash,
225+
keyIDHash,
170226
*result.Kid,
171227
)
172228
}
@@ -177,11 +233,14 @@ func (kvc *KeyVaultClient) Encrypt(
177233
keyvaultRegionAnnotationKey: []byte(result.Header.Get(keyvaultRegionAnnotationValue)),
178234
versionAnnotationKey: []byte(encryptionResponseVersion),
179235
algorithmAnnotationKey: []byte(encryptionAlgorithm),
236+
keyVersionAnnotationKey: []byte(keyVersion),
237+
keyIDHashAnnotationKey: []byte(keyIDHash),
180238
}
181239

240+
mlog.Info("Encryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
182241
return &service.EncryptResponse{
183242
Ciphertext: []byte(*result.Result),
184-
KeyID: kvc.keyIDHash,
243+
KeyID: keyIDHash,
185244
Annotations: annotations,
186245
}, nil
187246
}
@@ -208,7 +267,12 @@ func (kvc *KeyVaultClient) Decrypt(
208267
Value: &value,
209268
}
210269

211-
result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, kvc.keyVersion, params)
270+
keyVersion := kvc.keyVersion
271+
if len(annotations[keyVersionAnnotationKey]) != 0 {
272+
keyVersion = string(annotations[keyVersionAnnotationKey])
273+
}
274+
275+
result, err := kvc.baseClient.Decrypt(ctx, kvc.vaultURL, kvc.keyName, keyVersion, params)
212276
if err != nil {
213277
return nil, fmt.Errorf("failed to decrypt, error: %+v", err)
214278
}
@@ -217,6 +281,7 @@ func (kvc *KeyVaultClient) Decrypt(
217281
return nil, fmt.Errorf("failed to base64 decode result, error: %+v", err)
218282
}
219283

284+
mlog.Info("Decryption succeeded", "vaultName", kvc.vaultName, "keyName", kvc.keyName, "keyVersion", keyVersion)
220285
return bytes, nil
221286
}
222287

@@ -241,11 +306,15 @@ func (kvc *KeyVaultClient) validateAnnotations(
241306
return fmt.Errorf("invalid annotations, annotations cannot be empty")
242307
}
243308

244-
if keyID != kvc.keyIDHash {
309+
expectedKeyIDHash := kvc.keyIDHash
310+
if len(annotations[keyIDHashAnnotationKey]) != 0 {
311+
expectedKeyIDHash = string(annotations[keyIDHashAnnotationKey])
312+
}
313+
if keyID != expectedKeyIDHash {
245314
return fmt.Errorf(
246315
"key id %s does not match expected key id %s used for encryption",
247316
keyID,
248-
kvc.keyIDHash,
317+
expectedKeyIDHash,
249318
)
250319
}
251320

pkg/plugin/keyvault_test.go

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,16 @@ var (
2121

2222
func TestNewKeyVaultClientError(t *testing.T) {
2323
tests := []struct {
24-
desc string
25-
config *config.AzureConfig
26-
vaultName string
27-
keyName string
28-
keyVersion string
29-
proxyMode bool
30-
proxyAddress string
31-
proxyPort int
32-
managedHSM bool
24+
desc string
25+
config *config.AzureConfig
26+
vaultName string
27+
keyName string
28+
keyVersion string
29+
keyVersionlessEnabled bool
30+
proxyMode bool
31+
proxyAddress string
32+
proxyPort int
33+
managedHSM bool
3334
}{
3435
{
3536
desc: "vault name not provided",
@@ -43,7 +44,7 @@ func TestNewKeyVaultClientError(t *testing.T) {
4344
proxyMode: false,
4445
},
4546
{
46-
desc: "key version not provided",
47+
desc: "key version not provided when not keyVersionlessEnabled",
4748
config: &config.AzureConfig{},
4849
vaultName: "testkv",
4950
keyName: "k8s",
@@ -68,7 +69,7 @@ func TestNewKeyVaultClientError(t *testing.T) {
6869

6970
for _, test := range tests {
7071
t.Run(test.desc, func(t *testing.T) {
71-
if _, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil {
72+
if _, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.keyVersionlessEnabled, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM); err == nil {
7273
t.Fatalf("newKeyVaultClient() expected error, got nil")
7374
}
7475
})
@@ -77,16 +78,17 @@ func TestNewKeyVaultClientError(t *testing.T) {
7778

7879
func TestNewKeyVaultClient(t *testing.T) {
7980
tests := []struct {
80-
desc string
81-
config *config.AzureConfig
82-
vaultName string
83-
keyName string
84-
keyVersion string
85-
proxyMode bool
86-
proxyAddress string
87-
proxyPort int
88-
managedHSM bool
89-
expectedVaultURL string
81+
desc string
82+
config *config.AzureConfig
83+
vaultName string
84+
keyName string
85+
keyVersion string
86+
keyVersionlessEnabled bool
87+
proxyMode bool
88+
proxyAddress string
89+
proxyPort int
90+
managedHSM bool
91+
expectedVaultURL string
9092
}{
9193
{
9294
desc: "no error",
@@ -127,11 +129,21 @@ func TestNewKeyVaultClient(t *testing.T) {
127129
proxyMode: false,
128130
expectedVaultURL: "https://testkv.managedhsm.azure.net/",
129131
},
132+
{
133+
desc: "no error when no key version with keyVersionlessEnabled",
134+
config: &config.AzureConfig{ClientID: "clientid", ClientSecret: "clientsecret"},
135+
vaultName: "testkv",
136+
keyName: "key1",
137+
keyVersion: "",
138+
keyVersionlessEnabled: true,
139+
proxyMode: false,
140+
expectedVaultURL: "https://testkv.vault.azure.net/",
141+
},
130142
}
131143

132144
for _, test := range tests {
133145
t.Run(test.desc, func(t *testing.T) {
134-
kvClient, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM)
146+
kvClient, err := NewKeyVaultClient(test.config, test.vaultName, test.keyName, test.keyVersion, test.keyVersionlessEnabled, test.proxyMode, test.proxyAddress, test.proxyPort, test.managedHSM)
135147
if err != nil {
136148
t.Fatalf("newKeyVaultClient() failed with error: %v", err)
137149
}

pkg/plugin/server.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@ type KeyManagementServiceServer struct {
2727

2828
// Config is the configuration for the KMS plugin.
2929
type Config struct {
30-
ConfigFilePath string
31-
KeyVaultName string
32-
KeyName string
33-
KeyVersion string
34-
ManagedHSM bool
35-
ProxyMode bool
36-
ProxyAddress string
37-
ProxyPort int
30+
ConfigFilePath string
31+
KeyVaultName string
32+
KeyName string
33+
KeyVersion string
34+
KeyVersionlessEnabled bool
35+
ManagedHSM bool
36+
ProxyMode bool
37+
ProxyAddress string
38+
ProxyPort int
3839
}
3940

4041
// NewKMSv1Server creates an instance of the KMS Service Server.

0 commit comments

Comments
 (0)