Skip to content

Commit e2c769d

Browse files
committed
refactor code to allow testability
1 parent aef1fc9 commit e2c769d

File tree

1 file changed

+19
-20
lines changed

1 file changed

+19
-20
lines changed

cmd/flagsValidation.go

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"runtime"
2727

2828
"github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/file"
29+
"github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/share"
2930
"github.com/Azure/azure-storage-azcopy/v10/common"
3031
"github.com/spf13/cobra"
3132
)
@@ -320,7 +321,22 @@ func validateShareProtocolCompatibility(
320321
) error {
321322

322323
// We can ignore the error if we fail to get the share properties.
323-
shareProtocol, _ := getShareProtocolType(ctx, serviceClient, resource, protocol)
324+
fileURLParts, err := file.ParseURL(resource.Value)
325+
if err != nil {
326+
return fmt.Errorf("failed to parse resource URL: %w", err)
327+
}
328+
shareName := fileURLParts.ShareName
329+
330+
if serviceClient == nil {
331+
return fmt.Errorf("service client is nil")
332+
}
333+
334+
fileServiceClient, err := serviceClient.FileServiceClient()
335+
if err != nil {
336+
return fmt.Errorf("failed to create file service client: %w", err)
337+
}
338+
shareClient := fileServiceClient.NewShareClient(shareName)
339+
shareProtocol, _ := getShareProtocolType(ctx, shareName, shareClient, protocol)
324340

325341
if shareProtocol == common.ELocation.File() {
326342
if isSource && fromTo.From() != common.ELocation.File() {
@@ -350,27 +366,10 @@ func validateShareProtocolCompatibility(
350366
// If retrieval fails, it logs a warning and returns the fallback givenValue ("SMB" or "NFS").
351367
func getShareProtocolType(
352368
ctx context.Context,
353-
serviceClient *common.ServiceClient,
354-
resource common.ResourceString,
369+
shareName string,
370+
shareClient *share.Client,
355371
givenValue common.Location,
356372
) (common.Location, error) {
357-
358-
fileURLParts, err := file.ParseURL(resource.Value)
359-
if err != nil {
360-
return common.ELocation.Unknown(), fmt.Errorf("failed to parse resource URL: %w", err)
361-
}
362-
shareName := fileURLParts.ShareName
363-
364-
if serviceClient == nil {
365-
return common.ELocation.Unknown(), fmt.Errorf("service client is nil")
366-
}
367-
368-
fileServiceClient, err := serviceClient.FileServiceClient()
369-
if err != nil {
370-
return common.ELocation.Unknown(), fmt.Errorf("failed to create file service client: %w", err)
371-
}
372-
373-
shareClient := fileServiceClient.NewShareClient(shareName)
374373
properties, err := shareClient.GetProperties(ctx, nil)
375374
if err != nil {
376375
glcm.Info(fmt.Sprintf("Warning: Failed to fetch share properties for '%s'. Assuming the share uses '%s' protocol based on --from-to flag.", shareName, givenValue))

0 commit comments

Comments
 (0)