Skip to content

Commit

Permalink
Refactor postgres scaler config (#6262)
Browse files Browse the repository at this point in the history
Signed-off-by: Rushen Wang <[email protected]>
  • Loading branch information
dovics authored Oct 24, 2024
1 parent 2cf3c4c commit b2ce95d
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 99 deletions.
158 changes: 63 additions & 95 deletions pkg/scalers/postgresql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -42,12 +41,46 @@ type postgreSQLScaler struct {
}

type postgreSQLMetadata struct {
targetQueryValue float64
activationTargetQueryValue float64
connection string
query string
TargetQueryValue float64 `keda:"name=targetQueryValue, order=triggerMetadata, optional"`
ActivationTargetQueryValue float64 `keda:"name=activationTargetQueryValue, order=triggerMetadata, optional"`
Connection string `keda:"name=connection, order=authParams;resolvedEnv, optional"`
Query string `keda:"name=query, order=triggerMetadata"`
triggerIndex int
azureAuthContext azureAuthContext

Host string `keda:"name=host, order=authParams;triggerMetadata, optional"`
Port string `keda:"name=port, order=authParams;triggerMetadata, optional"`
UserName string `keda:"name=userName, order=authParams;triggerMetadata, optional"`
DBName string `keda:"name=dbName, order=authParams;triggerMetadata, optional"`
SslMode string `keda:"name=sslmode, order=authParams;triggerMetadata, optional"`

Password string `keda:"name=password, order=authParams;resolvedEnv, optional"`
}

func (p *postgreSQLMetadata) Validate() error {
if p.Connection == "" {
if p.Host == "" {
return fmt.Errorf("no host given")
}

if p.Port == "" {
return fmt.Errorf("no port given")
}

if p.UserName == "" {
return fmt.Errorf("no userName given")
}

if p.DBName == "" {
return fmt.Errorf("no dbName given")
}

if p.SslMode == "" {
return fmt.Errorf("no sslmode given")
}
}

return nil
}

type azureAuthContext struct {
Expand Down Expand Up @@ -83,66 +116,26 @@ func NewPostgreSQLScaler(ctx context.Context, config *scalersconfig.ScalerConfig
}

func parsePostgreSQLMetadata(logger logr.Logger, config *scalersconfig.ScalerConfig) (*postgreSQLMetadata, kedav1alpha1.AuthPodIdentity, error) {
meta := postgreSQLMetadata{}

meta := &postgreSQLMetadata{}
authPodIdentity := kedav1alpha1.AuthPodIdentity{}

if val, ok := config.TriggerMetadata["query"]; ok {
meta.query = val
} else {
return nil, authPodIdentity, fmt.Errorf("no query given")
}

if val, ok := config.TriggerMetadata["targetQueryValue"]; ok {
targetQueryValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, authPodIdentity, fmt.Errorf("queryValue parsing error %w", err)
}
meta.targetQueryValue = targetQueryValue
} else {
if config.AsMetricSource {
meta.targetQueryValue = 0
} else {
return nil, authPodIdentity, fmt.Errorf("no targetQueryValue given")
}
meta.triggerIndex = config.TriggerIndex
if err := config.TypedConfig(meta); err != nil {
return nil, authPodIdentity, fmt.Errorf("error parsing postgresql metadata: %w", err)
}

meta.activationTargetQueryValue = 0
if val, ok := config.TriggerMetadata["activationTargetQueryValue"]; ok {
activationTargetQueryValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, authPodIdentity, fmt.Errorf("activationTargetQueryValue parsing error %w", err)
}
meta.activationTargetQueryValue = activationTargetQueryValue
if !config.AsMetricSource && meta.TargetQueryValue == 0 {
return nil, authPodIdentity, fmt.Errorf("no targetQueryValue given")
}

switch config.PodIdentity.Provider {
case "", kedav1alpha1.PodIdentityProviderNone:
switch {
case config.AuthParams["connection"] != "":
meta.connection = config.AuthParams["connection"]
case config.TriggerMetadata["connectionFromEnv"] != "":
meta.connection = config.ResolvedEnv[config.TriggerMetadata["connectionFromEnv"]]
default:
params, err := buildConnArray(config)
if err != nil {
return nil, authPodIdentity, fmt.Errorf("failed to parse fields related to the connection")
}

var password string
if config.AuthParams["password"] != "" {
password = config.AuthParams["password"]
} else if config.TriggerMetadata["passwordFromEnv"] != "" {
password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]]
}
params = append(params, "password="+escapePostgreConnectionParameter(password))
meta.connection = strings.Join(params, " ")
if meta.Connection == "" {
params := buildConnArray(meta)
params = append(params, "password="+escapePostgreConnectionParameter(meta.Password))
meta.Connection = strings.Join(params, " ")
}
case kedav1alpha1.PodIdentityProviderAzureWorkload:
params, err := buildConnArray(config)
if err != nil {
return nil, authPodIdentity, fmt.Errorf("failed to parse fields related to the connection")
}
params := buildConnArray(meta)

cred, err := azure.NewChainedCredential(logger, config.PodIdentity)
if err != nil {
Expand All @@ -152,59 +145,34 @@ func parsePostgreSQLMetadata(logger logr.Logger, config *scalersconfig.ScalerCon
authPodIdentity = kedav1alpha1.AuthPodIdentity{Provider: config.PodIdentity.Provider}

params = append(params, "%PASSWORD%")
meta.connection = strings.Join(params, " ")
meta.Connection = strings.Join(params, " ")
}
meta.triggerIndex = config.TriggerIndex

return &meta, authPodIdentity, nil
return meta, authPodIdentity, nil
}

func buildConnArray(config *scalersconfig.ScalerConfig) ([]string, error) {
func buildConnArray(meta *postgreSQLMetadata) []string {
var params []string
params = append(params, "host="+escapePostgreConnectionParameter(meta.Host))
params = append(params, "port="+escapePostgreConnectionParameter(meta.Port))
params = append(params, "user="+escapePostgreConnectionParameter(meta.UserName))
params = append(params, "dbname="+escapePostgreConnectionParameter(meta.DBName))
params = append(params, "sslmode="+escapePostgreConnectionParameter(meta.SslMode))

host, err := GetFromAuthOrMeta(config, "host")
if err != nil {
return nil, err
}

port, err := GetFromAuthOrMeta(config, "port")
if err != nil {
return nil, err
}

userName, err := GetFromAuthOrMeta(config, "userName")
if err != nil {
return nil, err
}

dbName, err := GetFromAuthOrMeta(config, "dbName")
if err != nil {
return nil, err
}

sslmode, err := GetFromAuthOrMeta(config, "sslmode")
if err != nil {
return nil, err
}
params = append(params, "host="+escapePostgreConnectionParameter(host))
params = append(params, "port="+escapePostgreConnectionParameter(port))
params = append(params, "user="+escapePostgreConnectionParameter(userName))
params = append(params, "dbname="+escapePostgreConnectionParameter(dbName))
params = append(params, "sslmode="+escapePostgreConnectionParameter(sslmode))

return params, nil
return params
}

func getConnection(ctx context.Context, meta *postgreSQLMetadata, podIdentity kedav1alpha1.AuthPodIdentity, logger logr.Logger) (*sql.DB, error) {
connectionString := meta.connection
connectionString := meta.Connection

if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAzureWorkload {
accessToken, err := getAzureAccessToken(ctx, meta, azureDatabasePostgresResource)
if err != nil {
return nil, err
}
newPasswordField := "password=" + escapePostgreConnectionParameter(accessToken)
connectionString = passwordConnPattern.ReplaceAllString(meta.connection, newPasswordField)
connectionString = passwordConnPattern.ReplaceAllString(meta.Connection, newPasswordField)
}

db, err := sql.Open("pgx", connectionString)
Expand Down Expand Up @@ -245,7 +213,7 @@ func (s *postgreSQLScaler) getActiveNumber(ctx context.Context) (float64, error)
}
}

err := s.connection.QueryRowContext(ctx, s.metadata.query).Scan(&id)
err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&id)
if err != nil {
s.logger.Error(err, fmt.Sprintf("could not query postgreSQL: %s", err))
return 0, fmt.Errorf("could not query postgreSQL: %w", err)
Expand All @@ -259,7 +227,7 @@ func (s *postgreSQLScaler) GetMetricSpecForScaling(context.Context) []v2.MetricS
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString("postgresql")),
},
Target: GetMetricTargetMili(s.metricType, s.metadata.targetQueryValue),
Target: GetMetricTargetMili(s.metricType, s.metadata.TargetQueryValue),
}
metricSpec := v2.MetricSpec{
External: externalMetric, Type: externalMetricType,
Expand All @@ -276,7 +244,7 @@ func (s *postgreSQLScaler) GetMetricsAndActivity(ctx context.Context, metricName

metric := GenerateMetricInMili(metricName, num)

return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationTargetQueryValue, nil
return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationTargetQueryValue, nil
}

func escapePostgreConnectionParameter(str string) string {
Expand Down
8 changes: 4 additions & 4 deletions pkg/scalers/postgresql_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func TestPosgresSQLConnectionStringGeneration(t *testing.T) {
t.Fatal("Could not parse metadata:", err)
}

if meta.connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.connection)
if meta.Connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.Connection)
}
}
}
Expand All @@ -104,8 +104,8 @@ func TestPodIdentityAzureWorkloadPosgresSQLConnectionStringGeneration(t *testing
t.Fatal("Could not parse metadata:", err)
}

if meta.connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.connection)
if meta.Connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.Connection)
}
}
}
Expand Down

0 comments on commit b2ce95d

Please sign in to comment.