Skip to content

Commit 5a7ad77

Browse files
gapra-msftCopilot
andauthored
Fixed panic for s3 and gcp transfers (#3276)
* Fixed seg fault for s3 and gcp transfers * remove condition to call method * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * Apply suggestion from @Copilot Co-authored-by: Copilot <[email protected]> * addressed comments * refactor code to allow testability --------- Co-authored-by: Copilot <[email protected]>
1 parent 6e46576 commit 5a7ad77

File tree

3 files changed

+182
-55
lines changed

3 files changed

+182
-55
lines changed

cmd/flagsValidation.go

Lines changed: 24 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"runtime"
2727

2828
"github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/file"
29+
"github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/share"
2930
"github.com/Azure/azure-storage-azcopy/v10/common"
3031
"github.com/spf13/cobra"
3132
)
@@ -320,7 +321,22 @@ func validateShareProtocolCompatibility(
320321
) error {
321322

322323
// We can ignore the error if we fail to get the share properties.
323-
shareProtocol, _ := getShareProtocolType(ctx, serviceClient, resource, protocol)
324+
fileURLParts, err := file.ParseURL(resource.Value)
325+
if err != nil {
326+
return fmt.Errorf("failed to parse resource URL: %w", err)
327+
}
328+
shareName := fileURLParts.ShareName
329+
330+
if serviceClient == nil {
331+
return fmt.Errorf("service client is nil")
332+
}
333+
334+
fileServiceClient, err := serviceClient.FileServiceClient()
335+
if err != nil {
336+
return fmt.Errorf("failed to create file service client: %w", err)
337+
}
338+
shareClient := fileServiceClient.NewShareClient(shareName)
339+
shareProtocol, _ := getShareProtocolType(ctx, shareName, shareClient, protocol)
324340

325341
if shareProtocol == common.ELocation.File() {
326342
if isSource && fromTo.From() != common.ELocation.File() {
@@ -350,23 +366,10 @@ func validateShareProtocolCompatibility(
350366
// If retrieval fails, it logs a warning and returns the fallback givenValue ("SMB" or "NFS").
351367
func getShareProtocolType(
352368
ctx context.Context,
353-
serviceClient *common.ServiceClient,
354-
resource common.ResourceString,
369+
shareName string,
370+
shareClient *share.Client,
355371
givenValue common.Location,
356372
) (common.Location, error) {
357-
358-
fileURLParts, err := file.ParseURL(resource.Value)
359-
if err != nil {
360-
return common.ELocation.Unknown(), fmt.Errorf("failed to parse resource URL: %w", err)
361-
}
362-
shareName := fileURLParts.ShareName
363-
364-
fileServiceClient, err := serviceClient.FileServiceClient()
365-
if err != nil {
366-
return common.ELocation.Unknown(), fmt.Errorf("failed to create file service client: %w", err)
367-
}
368-
369-
shareClient := fileServiceClient.NewShareClient(shareName)
370373
properties, err := shareClient.GetProperties(ctx, nil)
371374
if err != nil {
372375
glcm.Info(fmt.Sprintf("Warning: Failed to fetch share properties for '%s'. Assuming the share uses '%s' protocol based on --from-to flag.", shareName, givenValue))
@@ -383,49 +386,16 @@ func getShareProtocolType(
383386
// Protocol compatibility validation for SMB and NFS transfers
384387
func validateProtocolCompatibility(ctx context.Context, fromTo common.FromTo, src, dst common.ResourceString, srcClient, dstClient *common.ServiceClient) error {
385388

386-
getUploadDownloadProtocol := func(fromTo common.FromTo) common.Location {
387-
switch fromTo {
388-
case common.EFromTo.LocalFile(), common.EFromTo.FileLocal():
389-
return common.ELocation.File()
390-
case common.EFromTo.LocalFileNFS(), common.EFromTo.FileNFSLocal():
391-
return common.ELocation.FileNFS()
392-
default:
393-
return common.ELocation.Unknown()
389+
if fromTo.From().IsFile() {
390+
if err := validateShareProtocolCompatibility(ctx, src, srcClient, true, fromTo.From(), fromTo); err != nil {
391+
return err
394392
}
395393
}
396394

397-
var srcProtocol, dstProtocol common.Location
398-
399-
// S2S Transfers
400-
if fromTo.IsS2S() {
401-
switch fromTo {
402-
case common.EFromTo.FileFile():
403-
srcProtocol, dstProtocol = common.ELocation.File(), common.ELocation.File()
404-
case common.EFromTo.FileNFSFileNFS():
405-
srcProtocol, dstProtocol = common.ELocation.FileNFS(), common.ELocation.FileNFS()
406-
case common.EFromTo.FileNFSFileSMB():
407-
srcProtocol, dstProtocol = common.ELocation.FileNFS(), common.ELocation.File()
408-
case common.EFromTo.FileSMBFileNFS():
409-
srcProtocol, dstProtocol = common.ELocation.File(), common.ELocation.FileNFS()
410-
}
411-
412-
// Validate both source and destination
413-
if err := validateShareProtocolCompatibility(ctx, src, srcClient, true, srcProtocol, fromTo); err != nil {
395+
if fromTo.To().IsFile() {
396+
if err := validateShareProtocolCompatibility(ctx, dst, dstClient, false, fromTo.To(), fromTo); err != nil {
414397
return err
415398
}
416-
return validateShareProtocolCompatibility(ctx, dst, dstClient, false, dstProtocol, fromTo)
417-
}
418-
419-
// Uploads to File Shares
420-
if fromTo.IsUpload() {
421-
dstProtocol = getUploadDownloadProtocol(fromTo)
422-
return validateShareProtocolCompatibility(ctx, dst, dstClient, false, dstProtocol, fromTo)
423-
}
424-
425-
// Downloads from File Shares
426-
if fromTo.IsDownload() {
427-
srcProtocol = getUploadDownloadProtocol(fromTo)
428-
return validateShareProtocolCompatibility(ctx, src, srcClient, true, srcProtocol, fromTo)
429399
}
430400

431401
return nil

cmd/flagsValidation_test.go

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package cmd
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/Azure/azure-storage-azcopy/v10/common"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestValidateProtocolCompatibility(t *testing.T) {
12+
a := assert.New(t)
13+
ctx := context.Background()
14+
15+
// Test cases where validation should NOT be called (no File locations involved)
16+
testCases := []struct {
17+
name string
18+
fromTo common.FromTo
19+
shouldValidate bool
20+
description string
21+
}{
22+
{
23+
name: "S3ToBlob",
24+
fromTo: common.EFromTo.S3Blob(),
25+
shouldValidate: false,
26+
description: "S3 to Blob should not validate (neither side is File)",
27+
},
28+
{
29+
name: "GCPToBlob",
30+
fromTo: common.EFromTo.GCPBlob(),
31+
shouldValidate: false,
32+
description: "GCP to Blob should not validate (neither side is File)",
33+
},
34+
{
35+
name: "LocalToBlob",
36+
fromTo: common.EFromTo.LocalBlob(),
37+
shouldValidate: false,
38+
description: "Local to Blob should not validate (neither side is File)",
39+
},
40+
{
41+
name: "BlobToLocal",
42+
fromTo: common.EFromTo.BlobLocal(),
43+
shouldValidate: false,
44+
description: "Blob to Local should not validate (neither side is File)",
45+
},
46+
{
47+
name: "BlobToBlob",
48+
fromTo: common.EFromTo.BlobBlob(),
49+
shouldValidate: false,
50+
description: "Blob to Blob should not validate (neither side is File)",
51+
},
52+
{
53+
name: "LocalToBlobFS",
54+
fromTo: common.EFromTo.LocalBlobFS(),
55+
shouldValidate: false,
56+
description: "Local to BlobFS should not validate (neither side is File)",
57+
},
58+
{
59+
name: "LocalToFile",
60+
fromTo: common.EFromTo.LocalFile(),
61+
shouldValidate: true,
62+
description: "Local to File should validate (destination is File)",
63+
},
64+
{
65+
name: "FileToLocal",
66+
fromTo: common.EFromTo.FileLocal(),
67+
shouldValidate: true,
68+
description: "File to Local should validate (source is File)",
69+
},
70+
{
71+
name: "LocalToFileNFS",
72+
fromTo: common.EFromTo.LocalFileNFS(),
73+
shouldValidate: true,
74+
description: "Local to FileNFS should validate (destination is FileNFS)",
75+
},
76+
{
77+
name: "FileNFSToLocal",
78+
fromTo: common.EFromTo.FileNFSLocal(),
79+
shouldValidate: true,
80+
description: "FileNFS to Local should validate (source is FileNFS)",
81+
},
82+
{
83+
name: "FileToFile",
84+
fromTo: common.EFromTo.FileFile(),
85+
shouldValidate: true,
86+
description: "File to File should validate (both sides are File)",
87+
},
88+
{
89+
name: "FileNFSToFileNFS",
90+
fromTo: common.EFromTo.FileNFSFileNFS(),
91+
shouldValidate: true,
92+
description: "FileNFS to FileNFS should validate (both sides are FileNFS)",
93+
},
94+
{
95+
name: "FileToBlob",
96+
fromTo: common.EFromTo.FileBlob(),
97+
shouldValidate: true,
98+
description: "File to Blob should validate (source is File)",
99+
},
100+
{
101+
name: "BlobToFile",
102+
fromTo: common.EFromTo.BlobFile(),
103+
shouldValidate: true,
104+
description: "Blob to File should validate (destination is File)",
105+
},
106+
}
107+
108+
for _, tc := range testCases {
109+
t.Run(tc.name, func(t *testing.T) {
110+
// Create dummy resource strings
111+
src := common.ResourceString{Value: "https://source.example.com/path"}
112+
dst := common.ResourceString{Value: "https://dest.example.com/path"}
113+
114+
// For non-File transfers, we can pass nil service clients since validation should be skipped
115+
// For File transfers, we would need proper service clients, but we're testing the conditional logic
116+
var srcClient, dstClient *common.ServiceClient
117+
118+
if !tc.shouldValidate {
119+
// Test that validation is skipped when no File locations are involved
120+
// This should not panic even with nil service clients
121+
err := validateProtocolCompatibility(ctx, tc.fromTo, src, dst, srcClient, dstClient)
122+
a.NoError(err, "validateProtocolCompatibility should not fail for %s: %s", tc.name, tc.description)
123+
} else {
124+
// For File transfers, we expect the function to attempt validation
125+
// Since we're passing nil service clients, we expect it to fail gracefully
126+
// This tests that the conditional logic correctly identifies File transfers
127+
err := validateProtocolCompatibility(ctx, tc.fromTo, src, dst, srcClient, dstClient)
128+
// We expect an error here because we're passing nil service clients for File transfers
129+
// The important thing is that it doesn't panic and attempts validation
130+
if tc.fromTo.From().IsFile() || tc.fromTo.To().IsFile() {
131+
a.Error(err, "validateProtocolCompatibility should attempt validation for %s and fail with nil clients: %s", tc.name, tc.description)
132+
}
133+
}
134+
})
135+
}
136+
}
137+
138+
func TestValidateProtocolCompatibility_ConditionalLogic(t *testing.T) {
139+
a := assert.New(t)
140+
ctx := context.Background()
141+
142+
// Test the specific conditional logic
143+
src := common.ResourceString{Value: "https://source.example.com/path"}
144+
dst := common.ResourceString{Value: "https://dest.example.com/path"}
145+
146+
// Test that S3->Blob doesn't call validation (should not panic with nil clients)
147+
err := validateProtocolCompatibility(ctx, common.EFromTo.S3Blob(), src, dst, nil, nil)
148+
a.NoError(err, "S3->Blob should skip validation and not panic with nil service clients")
149+
150+
// Test that GCP->Blob doesn't call validation (should not panic with nil clients)
151+
err = validateProtocolCompatibility(ctx, common.EFromTo.GCPBlob(), src, dst, nil, nil)
152+
a.NoError(err, "GCP->Blob should skip validation and not panic with nil service clients")
153+
154+
// Test that Local->Blob doesn't call validation (should not panic with nil clients)
155+
err = validateProtocolCompatibility(ctx, common.EFromTo.LocalBlob(), src, dst, nil, nil)
156+
a.NoError(err, "Local->Blob should skip validation and not panic with nil service clients")
157+
}

common/util.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ func GetServiceClientForLocation(loc Location,
223223
return ret, nil
224224

225225
default:
226-
return nil, nil
226+
return ret, nil
227227
}
228228
}
229229

0 commit comments

Comments
 (0)