Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue identified on websocket endpoint URL type set to http #3579

Merged
merged 3 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions adapter/internal/oasparser/model/api_env.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func retrieveEndpointsFromEnv(apiHashValue string) ([]Endpoint, []Endpoint) {
break
}

productionEndpoint, err := getHostandBasepathandPort(productionEndpointURL)
productionEndpoint, err := getHTTPEndpoint(productionEndpointURL)
if err != nil {
loggers.LoggerAPI.Errorf("error while reading production endpoint : %v in env variables, %v", productionEndpointURL, err.Error())
} else if productionEndpoint != nil {
Expand All @@ -52,7 +52,7 @@ func retrieveEndpointsFromEnv(apiHashValue string) ([]Endpoint, []Endpoint) {
break
}

sandboxEndpoint, err := getHostandBasepathandPort(sandboxEndpointURL)
sandboxEndpoint, err := getHTTPEndpoint(sandboxEndpointURL)
if err != nil {
loggers.LoggerAPI.Errorf("error while reading sandbox endpoint : %v in env variables, %v", sandboxEndpointURL, err.Error())
} else if sandboxEndpoint != nil {
Expand Down
2 changes: 1 addition & 1 deletion adapter/internal/oasparser/model/async_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func (swagger *MgwSwagger) SetInfoAsyncAPI(asyncAPI AsyncAPI) error {
swagger.apiType = WS

if asyncAPI.Servers.Production.URL != "" {
endpoint, err := getEndpointForWebsocketURL(asyncAPI.Servers.Production.URL)
endpoint, err := getWebSocketEndpoint(asyncAPI.Servers.Production.URL)
if err == nil {
productionEndpoints := append([]Endpoint{}, *endpoint)
swagger.productionEndpoints = generateEndpointCluster("clusterProd",
Expand Down
2 changes: 2 additions & 0 deletions adapter/internal/oasparser/model/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ const (
WS string = "WS"
// WEBHOOK - API type for WEBHOOK APIs
WEBHOOK string = "WEBHOOK"
// GRAPHQL - API type for GRAPHQL APIs
GRAPHQL string = "GRAPHQL"
)

// Constants to represent errors
Expand Down
20 changes: 10 additions & 10 deletions adapter/internal/oasparser/model/mgw_swagger.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ func (swagger *MgwSwagger) SetEnvLabelProperties(envProps synchronizer.APIEnvPro
if envProps.APIConfigs.SandboxEndpointChoreo != "" && !conf.ControlPlane.DynamicEnvironments.Enabled {
logger.LoggerOasparser.Infof("SandboxEndpointChoreo is found in env properties for %v : %v",
swagger.title, swagger.version)
endpoint, err := getHostandBasepathandPort(envProps.APIConfigs.SandboxEndpointChoreo)
endpoint, err := getHTTPEndpoint(envProps.APIConfigs.SandboxEndpointChoreo)
if err == nil {
productionUrls = append(productionUrls, *endpoint)
} else {
Expand All @@ -487,7 +487,7 @@ func (swagger *MgwSwagger) SetEnvLabelProperties(envProps synchronizer.APIEnvPro
if envProps.APIConfigs.ProductionEndpoint != "" {
logger.LoggerOasparser.Infof("Production endpoints are found in env properties for %v : %v",
swagger.title, swagger.version)
endpoint, err := getHostandBasepathandPort(envProps.APIConfigs.ProductionEndpoint)
endpoint, err := getHTTPEndpoint(envProps.APIConfigs.ProductionEndpoint)
if err == nil {
productionUrls = append(productionUrls, *endpoint)
} else {
Expand All @@ -503,7 +503,7 @@ func (swagger *MgwSwagger) SetEnvLabelProperties(envProps synchronizer.APIEnvPro

if envProps.APIConfigs.SandBoxEndpoint != "" {
logger.LoggerOasparser.Infof("Sandbox endpoints are found in env properties %v : %v", swagger.title, swagger.version)
endpoint, err := getHostandBasepathandPort(envProps.APIConfigs.SandBoxEndpoint)
endpoint, err := getHTTPEndpoint(envProps.APIConfigs.SandBoxEndpoint)
if err == nil {
sandboxUrls = append(sandboxUrls, *endpoint)
} else {
Expand Down Expand Up @@ -931,15 +931,15 @@ func processEndpointUrls(urlsArray []interface{}) ([]Endpoint, error) {
logger.LoggerOasparser.Error("Consul syntax parse error ", err)
continue
}
endpoint, err := getHostandBasepathandPort(defHost)
endpoint, err := getHTTPEndpoint(defHost)
if err == nil {
endpoint.ServiceDiscoveryString = queryString
endpoints = append(endpoints, *endpoint)
} else {
return nil, err
}
} else {
endpoint, err := getHostandBasepathandPort(v.(string))
endpoint, err := getHTTPEndpoint(v.(string))
if err == nil {
endpoints = append(endpoints, *endpoint)
} else {
Expand Down Expand Up @@ -1112,7 +1112,7 @@ func (swagger *MgwSwagger) GetInterceptor(vendorExtensions map[string]interface{
//serviceURL mandatory
if v, found := val[serviceURL]; found {
serviceURLV := v.(string)
endpoint, err := getHostandBasepathandPort(serviceURLV)
endpoint, err := getHTTPEndpoint(serviceURLV)
if err != nil {
logger.LoggerOasparser.Error("Error reading interceptors service url value", err)
return InterceptEndpoint{}, errors.New("error reading interceptors service url value")
Expand Down Expand Up @@ -1305,7 +1305,7 @@ func (swagger *MgwSwagger) PopulateSwaggerFromAPIYaml(apiData APIYaml, apiType s
var unProcessedURLs []interface{}
for _, endpointConfig := range endpointConfig.ProductionEndpoints {
if apiType == WS {
prodEndpoint, err := getEndpointForWebsocketURL(endpointConfig.Endpoint)
prodEndpoint, err := getWebSocketEndpoint(endpointConfig.Endpoint)
if err == nil {
endpoints = append(endpoints, *prodEndpoint)
} else {
Expand All @@ -1319,7 +1319,7 @@ func (swagger *MgwSwagger) PopulateSwaggerFromAPIYaml(apiData APIYaml, apiType s
endpointType = FailOver
for _, endpointConfig := range endpointConfig.ProductionFailoverEndpoints {
if apiType == WS {
failoverEndpoint, err := getEndpointForWebsocketURL(endpointConfig.Endpoint)
failoverEndpoint, err := getWebSocketEndpoint(endpointConfig.Endpoint)
if err == nil {
endpoints = append(endpoints, *failoverEndpoint)
} else {
Expand Down Expand Up @@ -1347,7 +1347,7 @@ func (swagger *MgwSwagger) PopulateSwaggerFromAPIYaml(apiData APIYaml, apiType s
var unProcessedURLs []interface{}
for _, endpointConfig := range endpointConfig.SandBoxEndpoints {
if apiType == WS {
sandBoxEndpoint, err := getEndpointForWebsocketURL(endpointConfig.Endpoint)
sandBoxEndpoint, err := getWebSocketEndpoint(endpointConfig.Endpoint)
if err == nil {
endpoints = append(endpoints, *sandBoxEndpoint)
} else {
Expand All @@ -1361,7 +1361,7 @@ func (swagger *MgwSwagger) PopulateSwaggerFromAPIYaml(apiData APIYaml, apiType s
endpointType = FailOver
for _, endpointConfig := range endpointConfig.SandboxFailoverEndpoints {
if apiType == WS {
failoverEndpoint, err := getEndpointForWebsocketURL(endpointConfig.Endpoint)
failoverEndpoint, err := getWebSocketEndpoint(endpointConfig.Endpoint)
if err == nil {
endpoints = append(endpoints, *failoverEndpoint)
} else {
Expand Down
83 changes: 26 additions & 57 deletions adapter/internal/oasparser/model/open_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (swagger *MgwSwagger) SetInfoOpenAPI(swagger3 openapi3.Swagger) error {
if len(serverEntry.URL) == 0 || strings.HasPrefix(serverEntry.URL, "/") {
continue
}
endpoint, err := getHostandBasepathandPort(serverEntry.URL)
endpoint, err := getHTTPEndpoint(serverEntry.URL)
if err == nil {
productionUrls = append(productionUrls, *endpoint)
swagger.xWso2Basepath = endpoint.Basepath
Expand Down Expand Up @@ -160,7 +160,7 @@ func setResourcesOpenAPI(openAPI openapi3.Swagger) ([]*Resource, error) {
if len(serverEntry.URL) == 0 || strings.HasPrefix(serverEntry.URL, "/") {
continue
}
endpoint, err := getHostandBasepathandPort(serverEntry.URL)
endpoint, err := getHTTPEndpoint(serverEntry.URL)
if err == nil {
productionUrls = append(productionUrls, *endpoint)

Expand Down Expand Up @@ -207,21 +207,36 @@ func getOperationLevelDetails(operation *openapi3.Operation, method string) *Ope

}

func getHTTPEndpoint(rawURL string) (*Endpoint, error) {
return getHostandBasepathandPort(HTTP, rawURL)
}

func getWebSocketEndpoint(rawURL string) (*Endpoint, error) {
return getHostandBasepathandPort(WS, rawURL)
}

// getHostandBasepathandPort retrieves host, basepath and port from the endpoint defintion
// from of the production endpoints url entry, combination of schemes and host (in openapi v2)
// or server property.
//
// if no scheme is mentioned before the hostname, urlType would be assigned as http
func getHostandBasepathandPort(rawURL string) (*Endpoint, error) {
func getHostandBasepathandPort(apiType string, rawURL string) (*Endpoint, error) {
var (
basepath string
host string
port uint32
urlType string
)

// Remove leading and trailing spaces of rawURL
rawURL = strings.TrimSpace(rawURL)

if !strings.Contains(rawURL, "://") {
rawURL = "http://" + rawURL
if apiType == HTTP || apiType == GRAPHQL || apiType == WEBHOOK {
rawURL = "http://" + rawURL
} else if apiType == WS {
rawURL = "ws://" + rawURL
}
}
parsedURL, err := url.Parse(rawURL)
if err != nil {
Expand All @@ -244,7 +259,7 @@ func getHostandBasepathandPort(rawURL string) (*Endpoint, error) {
}
port = uint32(u32)
} else {
if strings.HasPrefix(rawURL, "https://") {
if strings.HasPrefix(rawURL, "https://") || strings.HasPrefix(rawURL, "wss://") {
port = uint32(443)
} else {
port = uint32(80)
Expand All @@ -254,8 +269,12 @@ func getHostandBasepathandPort(rawURL string) (*Endpoint, error) {
urlType = "http"
if strings.HasPrefix(rawURL, "https://") {
urlType = "https"
} else if !strings.HasPrefix(rawURL, "http://") {
rawURL = "http://" + rawURL
} else if strings.HasPrefix(rawURL, "http://") {
urlType = "http"
} else if strings.HasPrefix(rawURL, "wss://") {
urlType = "wss"
} else if strings.HasPrefix(rawURL, "ws://") {
urlType = "ws"
}

return &Endpoint{Host: host, Basepath: basepath, Port: port, URLType: urlType, RawURL: rawURL}, nil
Expand Down Expand Up @@ -329,53 +348,3 @@ func GetXWso2Label(vendorExtensions openapi3.ExtensionProps) []string {
}
return []string{"default"}
}

func getEndpointForWebsocketURL(rawURL string) (*Endpoint, error) {
var (
basepath string
host string
port uint32
urlType string
)
if !strings.Contains(rawURL, "://") {
rawURL = "ws://" + rawURL
}
parsedURL, err := url.Parse(rawURL)
if err != nil {
logger.LoggerOasparser.Errorf("Failed to parse the malformed endpoint %v. Error message: %v", rawURL, err)
return nil, err
}

// Hostname validation
if !regexp.MustCompile(hostNameValidator).MatchString(parsedURL.Hostname()) {
logger.LoggerOasparser.Error("Malformed endpoint detected (Invalid host name) : ", rawURL)
return nil, errors.New("malformed endpoint detected (Invalid host name) : " + rawURL)
}

host = parsedURL.Hostname()
if parsedURL.Path == "" {
basepath = "/"
} else {
basepath = parsedURL.Path
}
if parsedURL.Port() != "" {
u32, err := strconv.ParseUint(parsedURL.Port(), 10, 32)
if err != nil {
logger.LoggerOasparser.Error("Error passing port value to mgwSwagger", err)
}
port = uint32(u32)
} else {
if strings.HasPrefix(rawURL, "wss://") {
port = uint32(443)
} else {
port = uint32(80)
}
}
urlType = "ws"
if strings.HasPrefix(rawURL, "wss://") {
urlType = "wss"
} else if !strings.HasPrefix(rawURL, "ws://") {
rawURL = "ws://" + rawURL
}
return &Endpoint{Host: host, Basepath: basepath, Port: port, URLType: urlType, RawURL: rawURL}, nil
}
4 changes: 2 additions & 2 deletions adapter/internal/oasparser/model/open_api_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ func TestGetHostandBasepathandPort(t *testing.T) {
},
}
for _, item := range dataItems {
resultResources, err := getHostandBasepathandPort(item.input)
resultResources, err := getHTTPEndpoint(item.input)
assert.Equal(t, item.result, resultResources, item.message)
if resultResources != nil {
assert.Nil(t, err, "Error encountered when processing the endpoint")
Expand Down Expand Up @@ -219,7 +219,7 @@ func TestMalformedUrl(t *testing.T) {
}

for index := range suspectedRawUrls {
response, _ := getHostandBasepathandPort(suspectedRawUrls[index])
response, _ := getHTTPEndpoint(suspectedRawUrls[index])
assert.Nil(t, response)
}

Expand Down
2 changes: 1 addition & 1 deletion adapter/internal/oasparser/model/swagger.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (swagger *MgwSwagger) SetInfoSwagger(swagger2 spec.Swagger) error {
swagger2.Info.Title, swagger2.Info.Version)
}
}
endpoint, err := getHostandBasepathandPort(urlScheme + swagger2.Host + swagger2.BasePath)
endpoint, err := getHTTPEndpoint(urlScheme + swagger2.Host + swagger2.BasePath)
if err == nil {
productionEndpoints := append([]Endpoint{}, *endpoint)
swagger.productionEndpoints = generateEndpointCluster(prodClustersConfigNamePrefix, productionEndpoints, LoadBalance)
Expand Down
16 changes: 8 additions & 8 deletions adapter/internal/oasparser/operator/operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ func TestMgwSwaggerWebSocketSand(t *testing.T) {
testGetMgwSwaggerWebSocket(t, apiYamlFilePath)
}

//Test execution for GetOpenAPIVersionAndJSONContent
// Test execution for GetOpenAPIVersionAndJSONContent
func TestGetOpenAPIVersionAndJSONContent(t *testing.T) {

apiYamlFilePath := config.GetMgwHome() + "/../adapter/test-resources/envoycodegen"
Expand All @@ -164,7 +164,7 @@ func TestGetOpenAPIVersionAndJSONContent(t *testing.T) {
}
}

//helper function to test GetOpenAPIVersionAndJSONContent
// helper function to test GetOpenAPIVersionAndJSONContent
func testGetOpenAPIVersionAndJSONContent(t *testing.T, apiYamlFilePath string) {

apiYamlByteArr, err := ioutil.ReadFile(apiYamlFilePath)
Expand Down Expand Up @@ -195,7 +195,7 @@ func testGetOpenAPIVersionAndJSONContent(t *testing.T, apiYamlFilePath string) {

}

//Test execution for TestGetOpenAPIV3Struct
// Test execution for TestGetOpenAPIV3Struct
func TestGetOpenAPIV3Struct(t *testing.T) {
apiYamlFilePath := config.GetMgwHome() + "/../adapter/test-resources/envoycodegen"
files, err := ioutil.ReadDir(apiYamlFilePath)
Expand All @@ -211,7 +211,7 @@ func TestGetOpenAPIV3Struct(t *testing.T) {
}
}

//helper function for TestGetOpenAPIV3Struct
// helper function for TestGetOpenAPIV3Struct
func testGetOpenAPIV3Struct(t *testing.T, apiYamlFilePath string) {
apiYamlByteArr, err := ioutil.ReadFile(apiYamlFilePath)
assert.Nil(t, err, "Error while reading the openapi.yaml file : %v"+apiYamlFilePath)
Expand Down Expand Up @@ -245,14 +245,14 @@ func testGetMgwSwaggerWebSocket(t *testing.T, apiYamlFilePath string) {
productionEndpoints := mgwSwagger.GetProdEndpoints().Endpoints
productionEndpoint := productionEndpoints[0]
assert.Equal(t, productionEndpoint.Host, "echo.websocket.org", "mgwSwagger production endpoint host mismatch")
assert.Equal(t, productionEndpoint.Basepath, "/", "mgwSwagger production endpoint basepath mistmatch")
assert.Equal(t, productionEndpoint.Basepath, "", "mgwSwagger production endpoint basepath mistmatch")
assert.Equal(t, productionEndpoint.URLType, "ws", "mgwSwagger production endpoint URLType mismatch")
var port uint32 = 80
assert.Equal(t, productionEndpoint.Port, port, "mgwSwagger production endpoint port mismatch")
sandboxEndpoints := mgwSwagger.GetSandEndpoints().Endpoints
sandboxEndpoint := sandboxEndpoints[0]
assert.Equal(t, sandboxEndpoint.Host, "echo.websocket.org", "mgwSwagger sandbox endpoint host mismatch")
assert.Equal(t, sandboxEndpoint.Basepath, "/", "mgwSwagger sandbox endpoint basepath mistmatch")
assert.Equal(t, sandboxEndpoint.Basepath, "", "mgwSwagger sandbox endpoint basepath mistmatch")
assert.Equal(t, sandboxEndpoint.URLType, "ws", "mgwSwagger sandbox endpoint URLType mismatch")
assert.Equal(t, sandboxEndpoint.Port, port, "mgwSwagger sandbox endpoint port mismatch")
}
Expand All @@ -265,7 +265,7 @@ func testGetMgwSwaggerWebSocket(t *testing.T, apiYamlFilePath string) {
productionEndpoint := productionEndpoints[0]
var port uint32 = 80
assert.Equal(t, productionEndpoint.Host, "echo.websocket.org", "mgwSwagger production endpoint host mismatch")
assert.Equal(t, productionEndpoint.Basepath, "/", "mgwSwagger production endpoint basepath mistmatch")
assert.Equal(t, productionEndpoint.Basepath, "", "mgwSwagger production endpoint basepath mistmatch")
assert.Equal(t, productionEndpoint.URLType, "ws", "mgwSwagger production endpoint URLType mismatch")
assert.Equal(t, productionEndpoint.Port, port, "mgwSwagger production endpoint port mismatch")
sandboxEndpoints := mgwSwagger.GetSandEndpoints()
Expand All @@ -281,7 +281,7 @@ func testGetMgwSwaggerWebSocket(t *testing.T, apiYamlFilePath string) {
sandboxEndpoints := mgwSwagger.GetSandEndpoints().Endpoints
sandboxEndpoint := sandboxEndpoints[0]
assert.Equal(t, sandboxEndpoint.Host, "echo.websocket.org", "mgwSwagger sandbox endpoint host mismatch")
assert.Equal(t, sandboxEndpoint.Basepath, "/", "mgwSwagger sandbox endpoint basepath mistmatch")
assert.Equal(t, sandboxEndpoint.Basepath, "", "mgwSwagger sandbox endpoint basepath mistmatch")
assert.Equal(t, sandboxEndpoint.URLType, "ws", "mgwSwagger sandbox endpoint URLType mismatch")
assert.Equal(t, sandboxEndpoint.Port, port, "mgwSwagger sandbox endpoint port mismatch")
productionEndpoints := mgwSwagger.GetProdEndpoints()
Expand Down
Loading