Skip to content

Commit

Permalink
- optimized prepareparams for volume request to handle region
Browse files Browse the repository at this point in the history
  • Loading branch information
prajwalvathreya committed Nov 13, 2024
1 parent b1caf89 commit 10e2e3d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 59 deletions.
6 changes: 3 additions & 3 deletions internal/driver/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol

// Prepare the volume parameters such as name and SizeGB from the request.
// This step may involve calculations or adjustments based on the request's content.
volName, sizeGB, size, encryptionStatus, err := cs.prepareVolumeParams(ctx, req)
params, err := cs.prepareVolumeParams(ctx, req)
if err != nil {
metrics.RecordMetrics(metrics.ControllerCreateVolumeTotal, metrics.ControllerCreateVolumeDuration, metrics.Failed, functionStartTime)
return &csi.CreateVolumeResponse{}, err
Expand All @@ -93,7 +93,7 @@ func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol
}

// Create the volume
vol, err := cs.createAndWaitForVolume(ctx, volName, req.GetParameters(), encryptionStatus, sizeGB, sourceVolInfo, accessibilityRequirements)
vol, err := cs.createAndWaitForVolume(ctx, params.VolumeName, req.GetParameters(), params.EncryptionStatus, params.TargetSizeGB, sourceVolInfo, params.Region)
if err != nil {
metrics.RecordMetrics(metrics.ControllerCreateVolumeTotal, metrics.ControllerCreateVolumeDuration, metrics.Failed, functionStartTime)
return &csi.CreateVolumeResponse{}, err
Expand All @@ -103,7 +103,7 @@ func (cs *ControllerServer) CreateVolume(ctx context.Context, req *csi.CreateVol
volContext := cs.createVolumeContext(ctx, req, vol)

// Prepare and return response
resp := cs.prepareCreateVolumeResponse(ctx, vol, size, volContext, sourceVolInfo, contentSource)
resp := cs.prepareCreateVolumeResponse(ctx, vol, params.Size, volContext, sourceVolInfo, contentSource)

// Record function completion
metrics.RecordMetrics(metrics.ControllerCreateVolumeTotal, metrics.ControllerCreateVolumeDuration, metrics.Completed, functionStartTime)
Expand Down
67 changes: 42 additions & 25 deletions internal/driver/controllerserver_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ const (
VolumeEncryption = Name + "/encrypted"
)

// Struct to return volume parameters when prepareVolumeParams is called

type VolumeParams struct {
VolumeName string
TargetSizeGB int
Size int64
EncryptionStatus string
Region string
}

// canAttach indicates whether or not another volume can be attached to the
// Linode with the given ID.
//
Expand Down Expand Up @@ -189,7 +199,7 @@ func (cs *ControllerServer) getContentSourceVolume(ctx context.Context, contentS
// attemptCreateLinodeVolume creates a Linode volume while ensuring idempotency.
// It checks for existing volumes with the same label and either returns the existing
// volume or creates a new one, optionally cloning from a source volume.
func (cs *ControllerServer) attemptCreateLinodeVolume(ctx context.Context, label, tags, volumeEncryption string, sizeGB int, sourceVolume *linodevolumes.LinodeVolumeKey, accessibilityRequirements *csi.TopologyRequirement) (*linodego.Volume, error) {
func (cs *ControllerServer) attemptCreateLinodeVolume(ctx context.Context, label, tags, volumeEncryption string, sizeGB int, sourceVolume *linodevolumes.LinodeVolumeKey, region string) (*linodego.Volume, error) {
log := logger.GetLogger(ctx)
log.V(4).Info("Attempting to create Linode volume", "label", label, "sizeGB", sizeGB, "tags", tags)

Expand Down Expand Up @@ -219,7 +229,7 @@ func (cs *ControllerServer) attemptCreateLinodeVolume(ctx context.Context, label
return cs.cloneLinodeVolume(ctx, label, sourceVolume.VolumeID)
}

return cs.createLinodeVolume(ctx, label, tags, volumeEncryption, sizeGB, accessibilityRequirements)
return cs.createLinodeVolume(ctx, label, tags, volumeEncryption, sizeGB, region)
}

// Helper function to extract region from topology
Expand All @@ -240,19 +250,10 @@ func getRegionFromTopology(requirements *csi.TopologyRequirement) string {

// createLinodeVolume creates a new Linode volume with the specified label, size, and tags.
// It returns the created volume or an error if the creation fails.
func (cs *ControllerServer) createLinodeVolume(ctx context.Context, label, tags, encryptionStatus string, sizeGB int, accessibilityRequirements *csi.TopologyRequirement) (*linodego.Volume, error) {
func (cs *ControllerServer) createLinodeVolume(ctx context.Context, label, tags, encryptionStatus string, sizeGB int, region string) (*linodego.Volume, error) {
log := logger.GetLogger(ctx)
log.V(4).Info("Creating Linode volume", "label", label, "sizeGB", sizeGB, "tags", tags)

// Get the region from req.AccessibilityRequirements if it exists. Fall back to the controller's metadata region if not specified.
region := cs.metadata.Region
if accessibilityRequirements != nil {
if topologyRegion := getRegionFromTopology(accessibilityRequirements); topologyRegion != "" {
log.V(4).Info("Using region from topology", "region", topologyRegion)
region = topologyRegion
}
}

// Prepare the volume creation request with region, label, and size.
volumeReq := linodego.VolumeCreateOptions{
Region: region,
Expand Down Expand Up @@ -431,38 +432,54 @@ func (cs *ControllerServer) validateCreateVolumeRequest(ctx context.Context, req
// prepareVolumeParams prepares the volume parameters for creation.
// It extracts the capacity range from the request, calculates the size,
// and generates a normalized volume name. Returns the volume name and size in GB.
func (cs *ControllerServer) prepareVolumeParams(ctx context.Context, req *csi.CreateVolumeRequest) (volumeName string, targetSizeGB int, size int64, encryptionStatus string, err error) {
func (cs *ControllerServer) prepareVolumeParams(ctx context.Context, req *csi.CreateVolumeRequest) (*VolumeParams, error) {
log := logger.GetLogger(ctx)
log.V(4).Info("Entering prepareVolumeParams()", "req", req)
defer log.V(4).Info("Exiting prepareVolumeParams()")

// by default encryption is disabled
encryptionStatus = "disabled"
// By default, encryption is disabled
encryptionStatus := "disabled"
// Retrieve the capacity range from the request to determine the size limits for the volume.
capRange := req.GetCapacityRange()
// Get the requested size in bytes, handling any potential errors.
size, err = getRequestCapacitySize(capRange)
size, err := getRequestCapacitySize(capRange)
if err != nil {
return "", 0, 0, "", err
return nil, err
}

// Get the region from req.AccessibilityRequirements if it exists. Fall back to the controller's metadata region if not specified.
accessibilityRequirements := req.GetAccessibilityRequirements()
region := cs.metadata.Region
if accessibilityRequirements != nil {
if topologyRegion := getRegionFromTopology(accessibilityRequirements); topologyRegion != "" {
log.V(4).Info("Using region from topology", "region", topologyRegion)
region = topologyRegion
}
}

preKey := linodevolumes.CreateLinodeVolumeKey(0, req.GetName())
volumeName = preKey.GetNormalizedLabelWithPrefix(cs.driver.volumeLabelPrefix)
targetSizeGB = bytesToGB(size)
volumeName := preKey.GetNormalizedLabelWithPrefix(cs.driver.volumeLabelPrefix)
targetSizeGB := bytesToGB(size)

// Check if encryption should be enabled
if req.GetParameters()[VolumeEncryption] == True {
supported, err := cs.isEncryptionSupported(ctx, cs.metadata.Region)
if err != nil {
return volumeName, targetSizeGB, size, encryptionStatus, err
return nil, err
}
if !supported {
return volumeName, targetSizeGB, size, encryptionStatus, errInternal("Volume encryption is not supported in the %s region", cs.metadata.Region)
return nil, errInternal("Volume encryption is not supported in the %s region", cs.metadata.Region)
}
encryptionStatus = "enabled"
}

log.V(4).Info("Volume parameters prepared", "volumeName", volumeName, "targetSizeGB", targetSizeGB)
return volumeName, targetSizeGB, size, encryptionStatus, nil
return &VolumeParams{
VolumeName: volumeName,
TargetSizeGB: targetSizeGB,
Size: size,
EncryptionStatus: encryptionStatus,
Region: region,
}, nil
}

// createVolumeContext creates a context map for the volume based on the request parameters.
Expand All @@ -489,12 +506,12 @@ func (cs *ControllerServer) createVolumeContext(ctx context.Context, req *csi.Cr

// createAndWaitForVolume attempts to create a new volume and waits for it to become active.
// It logs the process and handles any errors that occur during creation or waiting.
func (cs *ControllerServer) createAndWaitForVolume(ctx context.Context, name string, parameters map[string]string, encryptionStatus string, sizeGB int, sourceInfo *linodevolumes.LinodeVolumeKey, accessibilityRequirements *csi.TopologyRequirement) (*linodego.Volume, error) {
func (cs *ControllerServer) createAndWaitForVolume(ctx context.Context, name string, parameters map[string]string, encryptionStatus string, sizeGB int, sourceInfo *linodevolumes.LinodeVolumeKey, region string) (*linodego.Volume, error) {
log := logger.GetLogger(ctx)
log.V(4).Info("Entering createAndWaitForVolume()", "name", name, "sizeGB", sizeGB, "tags", parameters[VolumeTags], "encryptionStatus", encryptionStatus)
defer log.V(4).Info("Exiting createAndWaitForVolume()")

vol, err := cs.attemptCreateLinodeVolume(ctx, name, parameters[VolumeTags], encryptionStatus, sizeGB, sourceInfo, accessibilityRequirements)
vol, err := cs.attemptCreateLinodeVolume(ctx, name, parameters[VolumeTags], encryptionStatus, sizeGB, sourceInfo, region)
if err != nil {
return nil, err
}
Expand Down
67 changes: 36 additions & 31 deletions internal/driver/controllerserver_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,6 @@ func TestCreateAndWaitForVolume(t *testing.T) {
client: mockClient,
}

topology := &csi.TopologyRequirement{
Preferred: []*csi.Topology{
{
Segments: map[string]string{
VolumeTopologyRegion: "us-east",
},
},
},
}

testCases := []struct {
name string
volumeName string
Expand Down Expand Up @@ -378,7 +368,7 @@ func TestCreateAndWaitForVolume(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
tc.setupMocks()
encryptionStatus := "disabled"
volume, err := cs.createAndWaitForVolume(context.Background(), tc.volumeName, tc.parameters, encryptionStatus, tc.sizeGB, tc.sourceInfo, topology)
volume, err := cs.createAndWaitForVolume(context.Background(), tc.volumeName, tc.parameters, encryptionStatus, tc.sizeGB, tc.sourceInfo, "us-east")

if err != nil && !reflect.DeepEqual(tc.expectedError, err) {
if tc.expectedError != nil {
Expand Down Expand Up @@ -477,26 +467,31 @@ func TestPrepareVolumeParams(t *testing.T) {
}
ctx := context.Background()

volumeName, sizeGB, size, _, err := cs.prepareVolumeParams(ctx, tt.req)
params, err := cs.prepareVolumeParams(ctx, tt.req)

if err != nil && !reflect.DeepEqual(tt.expectedError, err) {
if tt.expectedError != nil {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
} else {
t.Errorf("expected no error but got %v", err)
}
// First, verify that the error matches the expectation
if (err != nil && tt.expectedError == nil) || (err == nil && tt.expectedError != nil) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
} else if err != nil && tt.expectedError != nil && err.Error() != tt.expectedError.Error() {
t.Errorf("expected error message %v, got %v", tt.expectedError.Error(), err.Error())
}

if !reflect.DeepEqual(volumeName, tt.expectedName) {
t.Errorf("Expected volume name: %s, but got: %s", tt.expectedName, volumeName)
}
// Only check params fields if params is not nil
if params != nil {
if params.VolumeName != tt.expectedName {
t.Errorf("Expected volume name: %s, but got: %s", tt.expectedName, params.VolumeName)
}

if !reflect.DeepEqual(sizeGB, tt.expectedSizeGB) {
t.Errorf("Expected size in GB: %d, but got: %d", tt.expectedSizeGB, sizeGB)
}
if params.TargetSizeGB != tt.expectedSizeGB {
t.Errorf("Expected size in GB: %d, but got: %d", tt.expectedSizeGB, params.TargetSizeGB)
}

if !reflect.DeepEqual(size, tt.expectedSize) {
t.Errorf("Expected size in bytes: %d, but got: %d", tt.expectedSize, size)
if params.Size != tt.expectedSize {
t.Errorf("Expected size in bytes: %d, but got: %d", tt.expectedSize, params.Size)
}
} else if err == nil {
// If params is nil and no error was expected, the test should fail
t.Errorf("expected non-nil params, got nil")
}
})
}
Expand Down Expand Up @@ -582,14 +577,24 @@ func TestPrepareVolumeParams_Encryption(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
tt.setupMocks()

_, _, _, encryptionStatus, err := cs.prepareVolumeParams(ctx, tt.req)
// Call prepareVolumeParams and capture the result and error
params, err := cs.prepareVolumeParams(ctx, tt.req)

if err != nil && !reflect.DeepEqual(err, tt.expectedError) {
t.Errorf("Expected error %v, got %v", tt.expectedError, err)
// Verify that the error matches the expected error
if (err != nil && tt.expectedError == nil) || (err == nil && tt.expectedError != nil) {
t.Errorf("expected error %v, got %v", tt.expectedError, err)
} else if err != nil && tt.expectedError != nil && err.Error() != tt.expectedError.Error() {
t.Errorf("expected error message %v, got %v", tt.expectedError.Error(), err.Error())
}

if encryptionStatus != tt.expectedEncrypt {
t.Errorf("Expected encryption status %v, got %v", tt.expectedEncrypt, encryptionStatus)
// Only proceed to check params fields if params is non-nil and no error was expected
if params != nil && err == nil {
if params.EncryptionStatus != tt.expectedEncrypt {
t.Errorf("Expected encryption status %v, got %v", tt.expectedEncrypt, params.EncryptionStatus)
}
} else if params == nil && err == nil {
// Fail the test if params is nil but no error was expected
t.Errorf("expected non-nil params, got nil")
}
})
}
Expand Down

0 comments on commit 10e2e3d

Please sign in to comment.