Skip to content

Commit 10e2e3d

Browse files
- optimized prepareparams for volume request to handle region
1 parent b1caf89 commit 10e2e3d

File tree

3 files changed

+81
-59
lines changed

3 files changed

+81
-59
lines changed

internal/driver/controllerserver.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol
7575

7676
// Prepare the volume parameters such as name and SizeGB from the request.
7777
// This step may involve calculations or adjustments based on the request's content.
78-
volName, sizeGB, size, encryptionStatus, err := cs.prepareVolumeParams(ctx, req)
78+
params, err := cs.prepareVolumeParams(ctx, req)
7979
if err != nil {
8080
metrics.RecordMetrics(metrics.ControllerCreateVolumeTotal, metrics.ControllerCreateVolumeDuration, metrics.Failed, functionStartTime)
8181
return &csi.CreateVolumeResponse{}, err
@@ -93,7 +93,7 @@ func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol
9393
}
9494

9595
// Create the volume
96-
vol, err := cs.createAndWaitForVolume(ctx, volName, req.GetParameters(), encryptionStatus, sizeGB, sourceVolInfo, accessibilityRequirements)
96+
vol, err := cs.createAndWaitForVolume(ctx, params.VolumeName, req.GetParameters(), params.EncryptionStatus, params.TargetSizeGB, sourceVolInfo, params.Region)
9797
if err != nil {
9898
metrics.RecordMetrics(metrics.ControllerCreateVolumeTotal, metrics.ControllerCreateVolumeDuration, metrics.Failed, functionStartTime)
9999
return &csi.CreateVolumeResponse{}, err
@@ -103,7 +103,7 @@ func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol
103103
volContext := cs.createVolumeContext(ctx, req, vol)
104104

105105
// Prepare and return response
106-
resp := cs.prepareCreateVolumeResponse(ctx, vol, size, volContext, sourceVolInfo, contentSource)
106+
resp := cs.prepareCreateVolumeResponse(ctx, vol, params.Size, volContext, sourceVolInfo, contentSource)
107107

108108
// Record function completion
109109
metrics.RecordMetrics(metrics.ControllerCreateVolumeTotal, metrics.ControllerCreateVolumeDuration, metrics.Completed, functionStartTime)

internal/driver/controllerserver_helper.go

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ const (
8585
VolumeEncryption = Name + "/encrypted"
8686
)
8787

88+
// Struct to return volume parameters when prepareVolumeParams is called
89+
90+
type VolumeParams struct {
91+
VolumeName string
92+
TargetSizeGB int
93+
Size int64
94+
EncryptionStatus string
95+
Region string
96+
}
97+
8898
// canAttach indicates whether or not another volume can be attached to the
8999
// Linode with the given ID.
90100
//
@@ -189,7 +199,7 @@ func (cs *ControllerServer) getContentSourceVolume(ctx context.Context, contentS
189199
// attemptCreateLinodeVolume creates a Linode volume while ensuring idempotency.
190200
// It checks for existing volumes with the same label and either returns the existing
191201
// volume or creates a new one, optionally cloning from a source volume.
192-
func (cs *ControllerServer) attemptCreateLinodeVolume(ctx context.Context, label, tags, volumeEncryption string, sizeGB int, sourceVolume *linodevolumes.LinodeVolumeKey, accessibilityRequirements *csi.TopologyRequirement) (*linodego.Volume, error) {
202+
func (cs *ControllerServer) attemptCreateLinodeVolume(ctx context.Context, label, tags, volumeEncryption string, sizeGB int, sourceVolume *linodevolumes.LinodeVolumeKey, region string) (*linodego.Volume, error) {
193203
log := logger.GetLogger(ctx)
194204
log.V(4).Info("Attempting to create Linode volume", "label", label, "sizeGB", sizeGB, "tags", tags)
195205

@@ -219,7 +229,7 @@ func (cs *ControllerServer) attemptCreateLinodeVolume(ctx context.Context, label
219229
return cs.cloneLinodeVolume(ctx, label, sourceVolume.VolumeID)
220230
}
221231

222-
return cs.createLinodeVolume(ctx, label, tags, volumeEncryption, sizeGB, accessibilityRequirements)
232+
return cs.createLinodeVolume(ctx, label, tags, volumeEncryption, sizeGB, region)
223233
}
224234

225235
// Helper function to extract region from topology
@@ -240,19 +250,10 @@ func getRegionFromTopology(requirements *csi.TopologyRequirement) string {
240250

241251
// createLinodeVolume creates a new Linode volume with the specified label, size, and tags.
242252
// It returns the created volume or an error if the creation fails.
243-
func (cs *ControllerServer) createLinodeVolume(ctx context.Context, label, tags, encryptionStatus string, sizeGB int, accessibilityRequirements *csi.TopologyRequirement) (*linodego.Volume, error) {
253+
func (cs *ControllerServer) createLinodeVolume(ctx context.Context, label, tags, encryptionStatus string, sizeGB int, region string) (*linodego.Volume, error) {
244254
log := logger.GetLogger(ctx)
245255
log.V(4).Info("Creating Linode volume", "label", label, "sizeGB", sizeGB, "tags", tags)
246256

247-
// Get the region from req.AccessibilityRequirements if it exists. Fall back to the controller's metadata region if not specified.
248-
region := cs.metadata.Region
249-
if accessibilityRequirements != nil {
250-
if topologyRegion := getRegionFromTopology(accessibilityRequirements); topologyRegion != "" {
251-
log.V(4).Info("Using region from topology", "region", topologyRegion)
252-
region = topologyRegion
253-
}
254-
}
255-
256257
// Prepare the volume creation request with region, label, and size.
257258
volumeReq := linodego.VolumeCreateOptions{
258259
Region: region,
@@ -431,38 +432,54 @@ func (cs *ControllerServer) validateCreateVolumeRequest(ctx context.Context, req
431432
// prepareVolumeParams prepares the volume parameters for creation.
432433
// It extracts the capacity range from the request, calculates the size,
433434
// and generates a normalized volume name. Returns the volume name and size in GB.
434-
func (cs *ControllerServer) prepareVolumeParams(ctx context.Context, req *csi.CreateVolumeRequest) (volumeName string, targetSizeGB int, size int64, encryptionStatus string, err error) {
435+
func (cs *ControllerServer) prepareVolumeParams(ctx context.Context, req *csi.CreateVolumeRequest) (*VolumeParams, error) {
435436
log := logger.GetLogger(ctx)
436437
log.V(4).Info("Entering prepareVolumeParams()", "req", req)
437438
defer log.V(4).Info("Exiting prepareVolumeParams()")
438-
439-
// by default encryption is disabled
440-
encryptionStatus = "disabled"
439+
// By default, encryption is disabled
440+
encryptionStatus := "disabled"
441441
// Retrieve the capacity range from the request to determine the size limits for the volume.
442442
capRange := req.GetCapacityRange()
443443
// Get the requested size in bytes, handling any potential errors.
444-
size, err = getRequestCapacitySize(capRange)
444+
size, err := getRequestCapacitySize(capRange)
445445
if err != nil {
446-
return "", 0, 0, "", err
446+
return nil, err
447+
}
448+
449+
// Get the region from req.AccessibilityRequirements if it exists. Fall back to the controller's metadata region if not specified.
450+
accessibilityRequirements := req.GetAccessibilityRequirements()
451+
region := cs.metadata.Region
452+
if accessibilityRequirements != nil {
453+
if topologyRegion := getRegionFromTopology(accessibilityRequirements); topologyRegion != "" {
454+
log.V(4).Info("Using region from topology", "region", topologyRegion)
455+
region = topologyRegion
456+
}
447457
}
448458

449459
preKey := linodevolumes.CreateLinodeVolumeKey(0, req.GetName())
450-
volumeName = preKey.GetNormalizedLabelWithPrefix(cs.driver.volumeLabelPrefix)
451-
targetSizeGB = bytesToGB(size)
460+
volumeName := preKey.GetNormalizedLabelWithPrefix(cs.driver.volumeLabelPrefix)
461+
targetSizeGB := bytesToGB(size)
462+
452463
// Check if encryption should be enabled
453464
if req.GetParameters()[VolumeEncryption] == True {
454465
supported, err := cs.isEncryptionSupported(ctx, cs.metadata.Region)
455466
if err != nil {
456-
return volumeName, targetSizeGB, size, encryptionStatus, err
467+
return nil, err
457468
}
458469
if !supported {
459-
return volumeName, targetSizeGB, size, encryptionStatus, errInternal("Volume encryption is not supported in the %s region", cs.metadata.Region)
470+
return nil, errInternal("Volume encryption is not supported in the %s region", cs.metadata.Region)
460471
}
461472
encryptionStatus = "enabled"
462473
}
463474

464475
log.V(4).Info("Volume parameters prepared", "volumeName", volumeName, "targetSizeGB", targetSizeGB)
465-
return volumeName, targetSizeGB, size, encryptionStatus, nil
476+
return &VolumeParams{
477+
VolumeName: volumeName,
478+
TargetSizeGB: targetSizeGB,
479+
Size: size,
480+
EncryptionStatus: encryptionStatus,
481+
Region: region,
482+
}, nil
466483
}
467484

468485
// createVolumeContext creates a context map for the volume based on the request parameters.
@@ -489,12 +506,12 @@ func (cs *ControllerServer) createVolumeContext(ctx context.Context, req *csi.Cr
489506

490507
// createAndWaitForVolume attempts to create a new volume and waits for it to become active.
491508
// It logs the process and handles any errors that occur during creation or waiting.
492-
func (cs *ControllerServer) createAndWaitForVolume(ctx context.Context, name string, parameters map[string]string, encryptionStatus string, sizeGB int, sourceInfo *linodevolumes.LinodeVolumeKey, accessibilityRequirements *csi.TopologyRequirement) (*linodego.Volume, error) {
509+
func (cs *ControllerServer) createAndWaitForVolume(ctx context.Context, name string, parameters map[string]string, encryptionStatus string, sizeGB int, sourceInfo *linodevolumes.LinodeVolumeKey, region string) (*linodego.Volume, error) {
493510
log := logger.GetLogger(ctx)
494511
log.V(4).Info("Entering createAndWaitForVolume()", "name", name, "sizeGB", sizeGB, "tags", parameters[VolumeTags], "encryptionStatus", encryptionStatus)
495512
defer log.V(4).Info("Exiting createAndWaitForVolume()")
496513

497-
vol, err := cs.attemptCreateLinodeVolume(ctx, name, parameters[VolumeTags], encryptionStatus, sizeGB, sourceInfo, accessibilityRequirements)
514+
vol, err := cs.attemptCreateLinodeVolume(ctx, name, parameters[VolumeTags], encryptionStatus, sizeGB, sourceInfo, region)
498515
if err != nil {
499516
return nil, err
500517
}

internal/driver/controllerserver_helper_test.go

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -273,16 +273,6 @@ func TestCreateAndWaitForVolume(t *testing.T) {
273273
client: mockClient,
274274
}
275275

276-
topology := &csi.TopologyRequirement{
277-
Preferred: []*csi.Topology{
278-
{
279-
Segments: map[string]string{
280-
VolumeTopologyRegion: "us-east",
281-
},
282-
},
283-
},
284-
}
285-
286276
testCases := []struct {
287277
name string
288278
volumeName string
@@ -378,7 +368,7 @@ func TestCreateAndWaitForVolume(t *testing.T) {
378368
t.Run(tc.name, func(t *testing.T) {
379369
tc.setupMocks()
380370
encryptionStatus := "disabled"
381-
volume, err := cs.createAndWaitForVolume(context.Background(), tc.volumeName, tc.parameters, encryptionStatus, tc.sizeGB, tc.sourceInfo, topology)
371+
volume, err := cs.createAndWaitForVolume(context.Background(), tc.volumeName, tc.parameters, encryptionStatus, tc.sizeGB, tc.sourceInfo, "us-east")
382372

383373
if err != nil && !reflect.DeepEqual(tc.expectedError, err) {
384374
if tc.expectedError != nil {
@@ -477,26 +467,31 @@ func TestPrepareVolumeParams(t *testing.T) {
477467
}
478468
ctx := context.Background()
479469

480-
volumeName, sizeGB, size, _, err := cs.prepareVolumeParams(ctx, tt.req)
470+
params, err := cs.prepareVolumeParams(ctx, tt.req)
481471

482-
if err != nil && !reflect.DeepEqual(tt.expectedError, err) {
483-
if tt.expectedError != nil {
484-
t.Errorf("expected error %v, got %v", tt.expectedError, err)
485-
} else {
486-
t.Errorf("expected no error but got %v", err)
487-
}
472+
// First, verify that the error matches the expectation
473+
if (err != nil && tt.expectedError == nil) || (err == nil && tt.expectedError != nil) {
474+
t.Errorf("expected error %v, got %v", tt.expectedError, err)
475+
} else if err != nil && tt.expectedError != nil && err.Error() != tt.expectedError.Error() {
476+
t.Errorf("expected error message %v, got %v", tt.expectedError.Error(), err.Error())
488477
}
489478

490-
if !reflect.DeepEqual(volumeName, tt.expectedName) {
491-
t.Errorf("Expected volume name: %s, but got: %s", tt.expectedName, volumeName)
492-
}
479+
// Only check params fields if params is not nil
480+
if params != nil {
481+
if params.VolumeName != tt.expectedName {
482+
t.Errorf("Expected volume name: %s, but got: %s", tt.expectedName, params.VolumeName)
483+
}
493484

494-
if !reflect.DeepEqual(sizeGB, tt.expectedSizeGB) {
495-
t.Errorf("Expected size in GB: %d, but got: %d", tt.expectedSizeGB, sizeGB)
496-
}
485+
if params.TargetSizeGB != tt.expectedSizeGB {
486+
t.Errorf("Expected size in GB: %d, but got: %d", tt.expectedSizeGB, params.TargetSizeGB)
487+
}
497488

498-
if !reflect.DeepEqual(size, tt.expectedSize) {
499-
t.Errorf("Expected size in bytes: %d, but got: %d", tt.expectedSize, size)
489+
if params.Size != tt.expectedSize {
490+
t.Errorf("Expected size in bytes: %d, but got: %d", tt.expectedSize, params.Size)
491+
}
492+
} else if err == nil {
493+
// If params is nil and no error was expected, the test should fail
494+
t.Errorf("expected non-nil params, got nil")
500495
}
501496
})
502497
}
@@ -582,14 +577,24 @@ func TestPrepareVolumeParams_Encryption(t *testing.T) {
582577
t.Run(tt.name, func(t *testing.T) {
583578
tt.setupMocks()
584579

585-
_, _, _, encryptionStatus, err := cs.prepareVolumeParams(ctx, tt.req)
580+
// Call prepareVolumeParams and capture the result and error
581+
params, err := cs.prepareVolumeParams(ctx, tt.req)
586582

587-
if err != nil && !reflect.DeepEqual(err, tt.expectedError) {
588-
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
583+
// Verify that the error matches the expected error
584+
if (err != nil && tt.expectedError == nil) || (err == nil && tt.expectedError != nil) {
585+
t.Errorf("expected error %v, got %v", tt.expectedError, err)
586+
} else if err != nil && tt.expectedError != nil && err.Error() != tt.expectedError.Error() {
587+
t.Errorf("expected error message %v, got %v", tt.expectedError.Error(), err.Error())
589588
}
590589

591-
if encryptionStatus != tt.expectedEncrypt {
592-
t.Errorf("Expected encryption status %v, got %v", tt.expectedEncrypt, encryptionStatus)
590+
// Only proceed to check params fields if params is non-nil and no error was expected
591+
if params != nil && err == nil {
592+
if params.EncryptionStatus != tt.expectedEncrypt {
593+
t.Errorf("Expected encryption status %v, got %v", tt.expectedEncrypt, params.EncryptionStatus)
594+
}
595+
} else if params == nil && err == nil {
596+
// Fail the test if params is nil but no error was expected
597+
t.Errorf("expected non-nil params, got nil")
593598
}
594599
})
595600
}

0 commit comments

Comments
 (0)