diff --git a/pkg/cache/redis/redis.go b/pkg/cache/redis/redis.go new file mode 100644 index 00000000..4ff9ffd6 --- /dev/null +++ b/pkg/cache/redis/redis.go @@ -0,0 +1,161 @@ +package redis + +import ( + "crypto/tls" + "crypto/x509" + + "github.com/redis/go-redis/v9" + + "github.com/scribd/go-sdk/pkg/cache" +) + +func New(cfg *cache.Redis) (redis.UniversalClient, error) { + opts, err := cfgToRedisClientOptions(cfg) + if err != nil { + return nil, err + } + + return redis.NewUniversalClient(opts), nil +} + +func cfgToRedisClientOptions(cfg *cache.Redis) (*redis.UniversalOptions, error) { + var err error + var clusterOptions *redis.ClusterOptions + if cfg.URL != "" { + clusterOptions, err = redis.ParseClusterURL(cfg.URL) + if err != nil { + return nil, err + } + } + + opts := &redis.UniversalOptions{ + Addrs: cfg.Addrs, + DB: cfg.DB, + ClientName: cfg.ClientName, + + Protocol: cfg.Protocol, + Username: cfg.Username, + Password: cfg.Password, + + SentinelUsername: cfg.SentinelUsername, + SentinelPassword: cfg.SentinelPassword, + + MaxRetries: cfg.MaxRetries, + MinRetryBackoff: cfg.MinRetryBackoff, + MaxRetryBackoff: cfg.MaxRetryBackoff, + + DialTimeout: cfg.DialTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + ContextTimeoutEnabled: cfg.ContextTimeoutEnabled, + + PoolSize: cfg.PoolSize, + PoolTimeout: cfg.PoolTimeout, + MaxIdleConns: cfg.MaxIdleConns, + MinIdleConns: cfg.MinIdleConns, + MaxActiveConns: cfg.MaxActiveConns, + ConnMaxIdleTime: cfg.ConnMaxIdleTime, + ConnMaxLifetime: cfg.ConnMaxLifetime, + + MaxRedirects: cfg.MaxRedirects, + ReadOnly: cfg.ReadOnly, + RouteByLatency: cfg.RouteByLatency, + RouteRandomly: cfg.RouteRandomly, + + MasterName: cfg.MasterName, + DisableIndentity: cfg.DisableIndentity, + IdentitySuffix: cfg.IdentitySuffix, + } + if clusterOptions != nil { + opts.Addrs = clusterOptions.Addrs + opts.ClientName = clusterOptions.ClientName + + opts.Protocol = clusterOptions.Protocol + opts.Username = clusterOptions.Username + opts.Password = clusterOptions.Password + + if clusterOptions.MaxRetries != 0 { + opts.MaxRetries = clusterOptions.MaxRetries + } + if clusterOptions.MinRetryBackoff != 0 { + opts.MinRetryBackoff = clusterOptions.MinRetryBackoff + } + if clusterOptions.MaxRetryBackoff != 0 { + opts.MaxRetryBackoff = clusterOptions.MaxRetryBackoff + } + + if clusterOptions.DialTimeout != 0 { + opts.DialTimeout = clusterOptions.DialTimeout + } + if clusterOptions.ReadTimeout != 0 { + opts.ReadTimeout = clusterOptions.ReadTimeout + } + if clusterOptions.WriteTimeout != 0 { + opts.WriteTimeout = clusterOptions.WriteTimeout + } + if clusterOptions.ContextTimeoutEnabled { + opts.ContextTimeoutEnabled = clusterOptions.ContextTimeoutEnabled + } + + if clusterOptions.PoolSize != 0 { + opts.PoolSize = clusterOptions.PoolSize + } + if clusterOptions.PoolTimeout != 0 { + opts.PoolTimeout = clusterOptions.PoolTimeout + } + if clusterOptions.MaxIdleConns != 0 { + opts.MaxIdleConns = clusterOptions.MaxIdleConns + } + if clusterOptions.MinIdleConns != 0 { + opts.MinIdleConns = clusterOptions.MinIdleConns + } + if clusterOptions.MaxActiveConns != 0 { + opts.MaxActiveConns = clusterOptions.MaxActiveConns + } + if clusterOptions.ConnMaxIdleTime != 0 { + opts.ConnMaxIdleTime = clusterOptions.ConnMaxIdleTime + } + if clusterOptions.ConnMaxLifetime != 0 { + opts.ConnMaxLifetime = clusterOptions.ConnMaxLifetime + } + + if clusterOptions.MaxRedirects != 0 { + opts.MaxRedirects = clusterOptions.MaxRedirects + } + if clusterOptions.ReadOnly { + opts.ReadOnly = clusterOptions.ReadOnly + } + if clusterOptions.RouteByLatency { + opts.RouteByLatency = clusterOptions.RouteByLatency + } + if clusterOptions.RouteRandomly { + opts.RouteRandomly = clusterOptions.RouteRandomly + } + } + + if cfg.TLS.Enabled { + var caCertPool *x509.CertPool + + if cfg.TLS.Ca != "" { + caCertPool = x509.NewCertPool() + caCertPool.AppendCertsFromPEM([]byte(cfg.TLS.Ca)) + } + + var certificates []tls.Certificate + if cfg.TLS.Cert != "" && cfg.TLS.CertKey != "" { + cert, err := tls.X509KeyPair([]byte(cfg.TLS.Cert), []byte(cfg.TLS.CertKey)) + if err != nil { + return nil, err + } + certificates = []tls.Certificate{cert} + } + + opts.TLSConfig = &tls.Config{ + InsecureSkipVerify: cfg.TLS.InsecureSkipVerify, + Certificates: certificates, + RootCAs: caCertPool, + } + } + + return opts, nil +} diff --git a/pkg/cache/redis/redis_test.go b/pkg/cache/redis/redis_test.go new file mode 100644 index 00000000..a289daeb --- /dev/null +++ b/pkg/cache/redis/redis_test.go @@ -0,0 +1,130 @@ +package redis + +import ( + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + + "github.com/scribd/go-sdk/pkg/cache" +) + +func TestNew(t *testing.T) { + tests := []struct { + name string + cfg cache.Redis + wantErr bool + }{ + { + name: "Config without URL set", + cfg: cache.Redis{ + Addrs: []string{"localhost:6379"}, + }, + }, + { + name: "Config with URL set", + cfg: cache.Redis{ + URL: "redis://localhost:6379", + }, + }, + { + name: "Config with URL set to cluster URL", + cfg: cache.Redis{ + URL: "redis://user:password@localhost:6789?dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791", + }, + }, + { + name: "Config with invalid URL", + cfg: cache.Redis{ + URL: "localhost:6379", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := New(&tt.cfg) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestCfgToRedisClientOptions(t *testing.T) { + tests := []struct { + name string + cfg cache.Redis + check func(t *testing.T, opts *redis.UniversalOptions) + wantErr bool + }{ + { + name: "Config without URL set", + cfg: cache.Redis{ + Addrs: []string{"localhost:6379"}, + }, + check: func(t *testing.T, opts *redis.UniversalOptions) { + assert.Equal(t, []string{"localhost:6379"}, opts.Addrs) + }, + }, + { + name: "Config with URL set", + cfg: cache.Redis{ + URL: "redis://localhost:6379", + }, + check: func(t *testing.T, opts *redis.UniversalOptions) { + assert.Equal(t, []string{"localhost:6379"}, opts.Addrs) + }, + }, + { + name: "Config with TLS enabled", + cfg: cache.Redis{ + URL: "rediss://localhost:6379", + TLS: cache.TLS{ + Enabled: true, + }, + }, + check: func(t *testing.T, opts *redis.UniversalOptions) { + assert.NotNil(t, opts.TLSConfig) + assert.False(t, opts.TLSConfig.InsecureSkipVerify) + }, + }, + { + name: "Config with URL set to cluster URL", + cfg: cache.Redis{ + URL: "redis://user:password@localhost:6789?dial_timeout=3&read_timeout=6s&addr=localhost:6790&addr=localhost:6791", + }, + check: func(t *testing.T, opts *redis.UniversalOptions) { + assert.Equal(t, []string{"localhost:6789", "localhost:6790", "localhost:6791"}, opts.Addrs) + assert.Equal(t, 3*time.Second, opts.DialTimeout) + assert.Equal(t, 6*time.Second, opts.ReadTimeout) + assert.Equal(t, "user", opts.Username) + assert.Equal(t, "password", opts.Password) + }, + }, + { + name: "Config with invalid URL", + cfg: cache.Redis{ + URL: "localhost:6379", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts, err := cfgToRedisClientOptions(&tt.cfg) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + + tt.check(t, opts) + } + }) + } +} diff --git a/pkg/instrumentation/redis.go b/pkg/instrumentation/redis.go new file mode 100644 index 00000000..a82fc34c --- /dev/null +++ b/pkg/instrumentation/redis.go @@ -0,0 +1,18 @@ +package instrumentation + +import ( + "fmt" + + "github.com/redis/go-redis/v9" + redistrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/redis/go-redis.v9" +) + +const ( + redisServiceNameSuffix = "cache-redis" +) + +func InstrumentRedis(client redis.UniversalClient, applicationName string) { + serviceName := fmt.Sprintf("%s-%s", applicationName, redisServiceNameSuffix) + + redistrace.WrapClient(client, redistrace.WithServiceName(serviceName)) +} diff --git a/pkg/instrumentation/redis_test.go b/pkg/instrumentation/redis_test.go new file mode 100644 index 00000000..15162547 --- /dev/null +++ b/pkg/instrumentation/redis_test.go @@ -0,0 +1,78 @@ +package instrumentation + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "gopkg.in/DataDog/dd-trace-go.v1/ddtrace/mocktracer" + + "github.com/scribd/go-sdk/pkg/cache" + "github.com/scribd/go-sdk/pkg/cache/redis" +) + +func TestInstrumentRedis(t *testing.T) { + expectedSpans := []struct { + name string + tags map[string]interface{} + }{ + { + name: "redis.dial", + tags: map[string]interface{}{ + "component": "redis/go-redis.v9", + "db.system": "redis", + "out.db": "0", + "out.host": "localhost", + "out.port": "6379", + "resource.name": "redis.dial", + "service.name": "test-app-cache-redis", + "span.kind": "client", + "span.type": "redis", + }, + }, + { + name: "redis.command", + tags: map[string]interface{}{ + "component": "redis/go-redis.v9", + "db.system": "redis", + "out.db": "0", + "out.host": "localhost", + "out.port": "6379", + "redis.args_length": "2", + "redis.raw_command": "get test-key: ", + "resource.name": "get", + "service.name": "test-app-cache-redis", + "span.kind": "client", + "span.type": "redis", + }, + }, + } + + mt := mocktracer.Start() + defer mt.Stop() + + client, err := redis.New(&cache.Redis{ + Addrs: []string{"localhost:6379"}, + }) + assert.NoError(t, err) + + InstrumentRedis(client, "test-app") + + client.Get(context.Background(), "test-key") + spans := mt.FinishedSpans() + assert.Len(t, spans, 2) + for i := range spans { + actualName := spans[i].OperationName() + actualTags := spans[i].Tags() + + expectedName := expectedSpans[i].name + expectedTags := expectedSpans[i].tags + + if actualName != expectedName { + t.Errorf("Got span: %s, expected: %s", actualName, expectedName) + } + + assert.Equal(t, expectedTags, actualTags, "database tags didn't match") + } + +} diff --git a/pkg/logger/redis.go b/pkg/logger/redis.go new file mode 100644 index 00000000..391e34c5 --- /dev/null +++ b/pkg/logger/redis.go @@ -0,0 +1,34 @@ +package logger + +import ( + "context" + + "github.com/redis/go-redis/v9" + + "github.com/scribd/go-sdk/pkg/instrumentation" +) + +type ( + RedisLogger struct { + logger Logger + } +) + +func NewRedisLogger(l Logger) *RedisLogger { + return &RedisLogger{l} +} + +func (r *RedisLogger) Printf(ctx context.Context, format string, v ...interface{}) { + logContext := instrumentation.TraceLogs(ctx) + + r.logger.WithFields(Fields{ + "dd": Fields{ + "trace_id": logContext.TraceID, + "span_id": logContext.SpanID, + }, + }).Errorf(format, v...) +} + +func SetRedisLogger(logger Logger) { + redis.SetLogger(NewRedisLogger(logger)) +}