Skip to content

Commit 54766aa

Browse files
committed
updates
1 parent a4788b5 commit 54766aa

File tree

2 files changed

+131
-239
lines changed

2 files changed

+131
-239
lines changed

internal/integration/unified/client_entity.go

Lines changed: 34 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
package unified
88

99
import (
10-
"bytes"
1110
"context"
12-
"encoding/base64"
11+
"crypto/tls"
1312
"fmt"
14-
"os"
1513
"strings"
1614
"sync"
1715
"sync/atomic"
@@ -35,26 +33,11 @@ import (
3533
// exceed the default truncation length.
3634
const defaultMaxDocumentLen = 10_000
3735

38-
var (
39-
// Security-sensitive commands that should be ignored in command monitoring by default.
40-
securitySensitiveCommands = []string{
41-
"authenticate", "saslStart", "saslContinue", "getnonce",
42-
"createUser", "updateUser", "copydbgetnonce", "copydbsaslstart", "copydb",
43-
}
44-
45-
awsAccessKeyID = os.Getenv("FLE_AWS_KEY")
46-
awsSecretAccessKey = os.Getenv("FLE_AWS_SECRET")
47-
awsTempAccessKeyID = os.Getenv("CSFLE_AWS_TEMP_ACCESS_KEY_ID")
48-
awsTempSecretAccessKey = os.Getenv("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY")
49-
awsTempSessionToken = os.Getenv("CSFLE_AWS_TEMP_SESSION_TOKEN")
50-
azureTenantID = os.Getenv("FLE_AZURE_TENANTID")
51-
azureClientID = os.Getenv("FLE_AZURE_CLIENTID")
52-
azureClientSecret = os.Getenv("FLE_AZURE_CLIENTSECRET")
53-
gcpEmail = os.Getenv("FLE_GCP_EMAIL")
54-
gcpPrivateKey = os.Getenv("FLE_GCP_PRIVATEKEY")
55-
56-
placeholderDoc = bsoncore.NewDocumentBuilder().AppendInt32("$$placeholder", 1).Build()
57-
)
36+
// Security-sensitive commands that should be ignored in command monitoring by default.
37+
var securitySensitiveCommands = []string{
38+
"authenticate", "saslStart", "saslContinue", "getnonce",
39+
"createUser", "updateUser", "copydbgetnonce", "copydbsaslstart", "copydb",
40+
}
5841

5942
// clientEntity is a wrapper for a mongo.Client object that also holds additional information required during test
6043
// execution.
@@ -295,128 +278,52 @@ func createAutoEncryptionOptions(opts bson.Raw) (*options.AutoEncryptionOptions,
295278
if err != nil {
296279
return nil, err
297280
}
298-
retrieveProviderData := func(doc bson.Raw, key, defaultVal string) (any, error) {
299-
e := doc.Lookup(key)
300-
if e.IsZero() {
301-
return nil, nil
281+
for _, elem := range elems {
282+
key := elem.Key()
283+
opt := elem.Value().Document()
284+
provider, err := getKmsProvider(key, opt)
285+
if err != nil {
286+
return nil, err
302287
}
303-
switch e.Type {
304-
case bson.TypeString:
305-
return e.StringValue(), nil
306-
case bson.TypeEmbeddedDocument:
307-
if bytes.Equal(e.Document(), placeholderDoc) {
308-
return defaultVal, nil
309-
}
288+
if provider == nil {
289+
continue
310290
}
311-
return nil, fmt.Errorf("unexpected %s in kms provider: %v", key, e)
312-
}
313-
for _, elem := range elems {
314-
provider := make(map[string]any)
315-
providerT := elem.Key()
316-
providerOpt := elem.Value().Document()
317-
switch providerT {
318-
case "aws":
319-
accessKeyID := awsAccessKeyID
320-
secretAccessKey := awsSecretAccessKey
321-
322-
// replace with temporary access, if sessionToken placeholder exists
323-
v, err := retrieveProviderData(providerOpt, "sessionToken", "$$placeholder")
324-
if err != nil {
325-
return nil, err
326-
}
327-
if v == "$$placeholder" {
328-
provider["sessionToken"] = awsTempSessionToken
329-
accessKeyID = awsTempAccessKeyID
330-
secretAccessKey = awsTempSecretAccessKey
331-
} else if v != nil {
332-
provider["sessionToken"] = v
333-
}
334-
335-
for _, e := range []struct {
336-
key string
337-
defaultVal string
338-
}{
339-
{"accessKeyId", accessKeyID},
340-
{"secretAccessKey", secretAccessKey},
341-
} {
342-
v, err = retrieveProviderData(providerOpt, e.key, e.defaultVal)
343-
if err != nil {
344-
return nil, err
345-
}
346-
if v != nil {
347-
provider[e.key] = v
348-
}
349-
}
350-
case "azure":
351-
for _, e := range []struct {
352-
key string
353-
defaultVal string
354-
}{
355-
{"tenantId", azureTenantID},
356-
{"clientId", azureClientID},
357-
{"clientSecret", azureClientSecret},
358-
} {
359-
v, err := retrieveProviderData(providerOpt, e.key, e.defaultVal)
360-
if err != nil {
361-
return nil, err
362-
}
363-
if v != nil {
364-
provider[e.key] = v
365-
}
366-
}
367-
case "gcp":
368-
for _, e := range []struct {
369-
key string
370-
defaultVal string
371-
}{
372-
{"email", gcpEmail},
373-
{"privateKey", gcpPrivateKey},
374-
} {
375-
v, err := retrieveProviderData(providerOpt, e.key, e.defaultVal)
376-
if err != nil {
377-
return nil, err
378-
}
379-
if v != nil {
380-
provider[e.key] = v
381-
}
382-
}
383-
case "kmip":
384-
v, err := retrieveProviderData(providerOpt, "endpoint", "localhost:5698")
385-
if err != nil {
386-
return nil, err
387-
}
388-
if v != nil {
389-
provider["endpoint"] = v
390-
}
391-
case "local", "local:name2":
392-
str := providerOpt.Lookup("key").StringValue()
393-
key, err := base64.StdEncoding.DecodeString(str)
291+
providers[key] = provider
292+
if key == "kmip" && tlsClientCertificateKeyFile != "" && tlsCAFile != "" {
293+
cfg, err := options.BuildTLSConfig(map[string]any{
294+
"tlsCertificateKeyFile": tlsClientCertificateKeyFile,
295+
"tlsCAFile": tlsCAFile,
296+
})
394297
if err != nil {
395-
return nil, err
298+
return nil, fmt.Errorf("error constructing tls config: %w", err)
396299
}
397-
provider["key"] = key
398-
default:
399-
return nil, fmt.Errorf("unrecognized KMS provider: %v", provider)
400-
}
401-
if len(provider) > 0 {
402-
providers[providerT] = provider
300+
aeo.SetTLSConfig(map[string]*tls.Config{
301+
"kmip": cfg,
302+
})
403303
}
404304
}
405305
aeo.SetKmsProviders(providers)
406306
case "schemaMap":
407307
var schemaMap map[string]any
408308
err := bson.Unmarshal(opt.Document(), &schemaMap)
409309
if err != nil {
410-
return nil, err
310+
return nil, fmt.Errorf("error creating schema map: %v", err)
411311
}
412312
aeo.SetSchemaMap(schemaMap)
413313
case "keyVaultNamespace":
414314
kvnsFound = true
415315
aeo.SetKeyVaultNamespace(opt.StringValue())
416-
case "bypassQueryAnalysis":
417-
aeo.SetBypassQueryAnalysis(opt.Boolean())
418316
case "bypassAutoEncryption":
419317
aeo.SetBypassAutoEncryption(opt.Boolean())
318+
case "encryptedFieldsMap":
319+
var encryptedFieldsMap map[string]any
320+
err := bson.Unmarshal(opt.Document(), &encryptedFieldsMap)
321+
if err != nil {
322+
return nil, fmt.Errorf("error creating encryptedFieldsMap: %v", err)
323+
}
324+
aeo.SetEncryptedFieldsMap(encryptedFieldsMap)
325+
case "bypassQueryAnalysis":
326+
aeo.SetBypassQueryAnalysis(opt.Boolean())
420327
default:
421328
return nil, fmt.Errorf("unrecognized option: %v", name)
422329
}

0 commit comments

Comments
 (0)