Skip to content

Commit

Permalink
add global connection pool for mysql
Browse files Browse the repository at this point in the history
Signed-off-by: Rob Pickerill <[email protected]>
  • Loading branch information
robpickerill committed Nov 10, 2024
1 parent 5c52d03 commit 218e3c9
Show file tree
Hide file tree
Showing 4 changed files with 499 additions and 7 deletions.
112 changes: 105 additions & 7 deletions pkg/scalers/mysql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net"
"strings"
"time"

"github.com/go-logr/logr"
"github.com/go-sql-driver/mysql"
Expand All @@ -16,6 +17,16 @@ import (
kedautil "github.com/kedacore/keda/v2/pkg/util"
)

var (
// A map that holds MySQL connection pools, keyed by connection string
connectionPools *kedautil.RefMap[string, *sql.DB]
)

func init() {
// Initialize the global connectionPools map
connectionPools = kedautil.NewRefMap[string, *sql.DB]()
}

type mySQLScaler struct {
metricType v2.MetricTargetType
metadata *mySQLMetadata
Expand All @@ -34,6 +45,12 @@ type mySQLMetadata struct {
QueryValue float64 `keda:"name=queryValue, order=triggerMetadata"`
ActivationQueryValue float64 `keda:"name=activationQueryValue, order=triggerMetadata, default=0"`
MetricName string `keda:"name=metricName, order=triggerMetadata, optional"`

// Connection pool settings
UseGlobalConnPools bool `keda:"name=useGlobalConnPools, order=triggerMetadata, optional"`
MaxOpenConns int `keda:"name=maxOpenConns, order=triggerMetadata, optional"`
MaxIdleConns int `keda:"name=maxIdleConns, order=triggerMetadata, optional"`
ConnMaxIdleTime int `keda:"name=connMaxIdleTime, order=triggerMetadata, optional"` // seconds
}

// NewMySQLScaler creates a new MySQL scaler
Expand All @@ -50,10 +67,19 @@ func NewMySQLScaler(config *scalersconfig.ScalerConfig) (Scaler, error) {
return nil, fmt.Errorf("error parsing MySQL metadata: %w", err)
}

conn, err := newMySQLConnection(meta, logger)
// Create MySQL connection, if useGlobalConnPools is set to true, it will use
// the global connection pool for the given connection string, otherwise it
// will create a new local connection pool for the given connection string
var conn *sql.DB
if meta.UseGlobalConnPools {
conn, err = getConnectionPool(meta, logger)
} else {
conn, err = newMySQLConnection(meta, logger)
}
if err != nil {
return nil, fmt.Errorf("error establishing MySQL connection: %w", err)
return nil, fmt.Errorf("error creating MySQL connection: %w", err)
}

return &mySQLScaler{
metricType: metricType,
metadata: meta,
Expand Down Expand Up @@ -96,6 +122,40 @@ func metadataToConnectionStr(meta *mySQLMetadata) string {
return connStr
}

// getConnectionPool will check if the connection pool has already been
// created for the given connection string and return it. If it has not
// been created, it will create a new connection pool and store it in the
// connectionPools map.
func getConnectionPool(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error) {
connStr := metadataToConnectionStr(meta)
// Try to load an existing pool and increment its reference count if found
if pool, ok := connectionPools.Load(connStr); ok {
err := connectionPools.AddRef(connStr)
if err != nil {
logger.Error(err, "Error increasing connection pool reference count")
return nil, err
}

return pool, nil
}

// If pool does not exist, create a new one and store it in RefMap
newPool, err := newMySQLConnection(meta, logger)
if err != nil {
return nil, err
}
err = connectionPools.Store(connStr, newPool, func(db *sql.DB) error {
logger.Info("Closing MySQL connection pool", "connectionString", connStr)
return db.Close()
})
if err != nil {
logger.Error(err, "Error storing connection pool in RefMap")
return nil, err
}

return newPool, nil
}

// newMySQLConnection creates MySQL db connection
func newMySQLConnection(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error) {
connStr := metadataToConnectionStr(meta)
Expand All @@ -104,14 +164,35 @@ func newMySQLConnection(meta *mySQLMetadata, logger logr.Logger) (*sql.DB, error
logger.Error(err, fmt.Sprintf("Found error when opening connection: %s", err))
return nil, err
}

err = db.Ping()
if err != nil {
logger.Error(err, fmt.Sprintf("Found error when pinging database: %s", err))
return nil, err
}

setConnectionPoolConfiguration(meta, db)

return db, nil
}

// setConnectionPoolConfiguration configures the MySQL connection pool settings
// based on the parameters provided in mySQLMetadata. If a setting is zero, it
// is left at its default value.
func setConnectionPoolConfiguration(meta *mySQLMetadata, db *sql.DB) {
if meta.MaxOpenConns > 0 {
db.SetMaxOpenConns(meta.MaxOpenConns)
}

if meta.MaxIdleConns > 0 {
db.SetMaxIdleConns(meta.MaxIdleConns)
}

if meta.ConnMaxIdleTime > 0 {
db.SetConnMaxIdleTime(time.Duration(meta.ConnMaxIdleTime) * time.Second)
}
}

// parseMySQLDbNameFromConnectionStr returns dbname from connection string
// in it is not able to parse it, it returns "dbname" string
func parseMySQLDbNameFromConnectionStr(connectionString string) string {
Expand All @@ -123,13 +204,30 @@ func parseMySQLDbNameFromConnectionStr(connectionString string) string {
return "dbname"
}

// Close disposes of MySQL connections
func (s *mySQLScaler) Close(context.Context) error {
err := s.connection.Close()
if err != nil {
s.logger.Error(err, "Error closing MySQL connection")
// Close disposes of MySQL connections, closing either the global pool if used
// or the local connection pool
func (s *mySQLScaler) Close(ctx context.Context) error {
if s.metadata.UseGlobalConnPools {
if err := s.closeGlobalPool(ctx); err != nil {
return fmt.Errorf("error closing MySQL connection: %w", err)
}
} else {
if err := s.connection.Close(); err != nil {
return fmt.Errorf("error closing MySQL connection: %w", err)
}
}

return nil
}

// closeGlobalPool closes all MySQL connections in the global pool
func (s *mySQLScaler) closeGlobalPool(_ context.Context) error {
connStr := metadataToConnectionStr(s.metadata)
if err := connectionPools.RemoveRef(connStr); err != nil {
s.logger.Error(err, "Error decreasing connection pool reference count")
return err
}

return nil
}

Expand Down
21 changes: 21 additions & 0 deletions pkg/scalers/mysql_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,27 @@ var testMySQLMetadata = []parseMySQLMetadataTestData{
resolvedEnv: map[string]string{},
raisesError: true,
},
// use global pool
{
metadata: map[string]string{"query": "query", "queryValue": "12", "useGlobalConnPools": "true"},
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
resolvedEnv: testMySQLResolvedEnv,
raisesError: false,
},
// use connection pool settings
{
metadata: map[string]string{"query": "query", "queryValue": "12", "maxOpenConns": "10", "maxIdleConns": "5", "connMaxIdleTime": "10"},
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
resolvedEnv: testMySQLResolvedEnv,
raisesError: false,
},
// use connection pool settings and global pool
{
metadata: map[string]string{"query": "query", "queryValue": "12", "maxOpenConns": "10", "maxIdleConns": "5", "connMaxIdleTime": "10", "useGlobalConnPools": "true"},
authParams: map[string]string{"host": "test_host", "port": "test_port", "username": "test_username", "password": "MYSQL_PASSWORD", "dbName": "test_dbname"},
resolvedEnv: testMySQLResolvedEnv,
raisesError: false,
},
}

var mySQLMetricIdentifiers = []mySQLMetricIdentifier{
Expand Down
123 changes: 123 additions & 0 deletions pkg/util/refmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package util

//nolint:depguard // sync/atomic
import (
"fmt"
"sync"
"sync/atomic"
)

// refCountedValue manages a reference-counted value with a cleanup function.
type refCountedValue[V any] struct {
value V
refCount atomic.Int64
closeFunc func(V) error // Cleanup function to call when count reaches zero
}

// Add increments the reference count.
func (r *refCountedValue[V]) Add() {
r.refCount.Add(1)
}

// Remove decrements the reference count and invokes closeFunc if the count
// reaches zero.
func (r *refCountedValue[V]) Remove() error {
if r.refCount.Add(-1) == 0 {
return r.closeFunc(r.value)
}

return nil
}

// Value returns the underlying value.
func (r *refCountedValue[V]) Value() V {
return r.value
}

// RefMap manages reference-counted items in a concurrent-safe map.
type RefMap[K comparable, V any] struct {
data map[K]*refCountedValue[V]
mu sync.RWMutex
}

// NewRefMap initializes a new RefMap. A RefMap is an atomic reference-counted
// concurrent hashmap. The general usage pattern is to Store a value with a
// close function, once the value is contained within the RefMap, it can be
// accessed via the Load method. The AddRef method signals ownership of the
// value and increments the reference count. The RemoveRef method decrements
// the reference count. When the reference count reaches zero, the close
// function is called and the value is removed from the map.
func NewRefMap[K comparable, V any]() *RefMap[K, V] {
return &RefMap[K, V]{
data: make(map[K]*refCountedValue[V]),
}
}

// Store adds a new item with an initial reference count of 1 and a close
// function.
func (r *RefMap[K, V]) Store(key K, value V, closeFunc func(V) error) error {
r.mu.Lock()
defer r.mu.Unlock()

if _, exists := r.data[key]; exists {
return fmt.Errorf("key already exists: %v", key)
}

r.data[key] = &refCountedValue[V]{value: value, refCount: atomic.Int64{}, closeFunc: closeFunc}
r.data[key].Add() // Set initial reference count to 1

return nil
}

// Load retrieves a value by key without modifying the reference count,
// returning the value and a boolean indicating if it was found. The reference
// count not being modified means that a check for the existence of a key
// can be performed with signalling ownership of the value. If the value is used
// after this method, it is recommended to call AddRef to increment the
// reference
func (r *RefMap[K, V]) Load(key K) (V, bool) {
r.mu.RLock()
defer r.mu.RUnlock()

if refValue, found := r.data[key]; found {
return refValue.Value(), true
}
var zero V

return zero, false
}

// AddRef increments the reference count for a key if it exists. Ensure
// to call RemoveRef when done with the value to prevent memory leaks.
func (r *RefMap[K, V]) AddRef(key K) error {
r.mu.RLock()
defer r.mu.RUnlock()

refValue, found := r.data[key]
if !found {
return fmt.Errorf("key not found: %v", key)
}

refValue.Add()
return nil
}

// RemoveRef decrements the reference count and deletes the entry if count
// reaches zero.
func (r *RefMap[K, V]) RemoveRef(key K) error {
r.mu.Lock()
defer r.mu.Unlock()

refValue, found := r.data[key]
if !found {
return fmt.Errorf("key not found: %v", key)
}

err := refValue.Remove()

if refValue.refCount.Load() == 0 {
delete(r.data, key)
}

return err // returns the error from closeFunc
}
Loading

0 comments on commit 218e3c9

Please sign in to comment.