Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cmd/atlas/docker: allow provide User via options #3251

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 50 additions & 25 deletions cmd/atlas/internal/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,24 @@ type (
Config struct {
driver string // driver to open connections with.
setup []string // contains statements to execute once the service is up
// User is the user to connect to the database.
User *url.Userinfo
// Internal Port to expose and connect to.
Port string
// Image is the name of the image to pull and run.
Image string
// Env vars to pass to the docker container.
Env []string
// Internal Port to expose anc connect to.
Port string
// Database name to create and connect on init.
Database string
// Out is a custom writer to send docker cli output to.
Out io.Writer
}
// A Container is an instance of a created container.
Container struct {
cfg Config // Config used to create this container
out io.Writer // custom write to log status messages to
Config // Config used to create this container
// ID of the container.
ID string
// Passphrase of the root user.
Passphrase string
// Port on the host this containers service is bound to.
Port string
}
Expand Down Expand Up @@ -210,6 +209,7 @@ func MySQL(version string, opts ...ConfigOption) (*Config, error) {
append(
[]ConfigOption{
Image(hubUser, "mysql:"+version),
Userinfo(url.UserPassword("root", pass)),
Port("3306"),
Env("MYSQL_ROOT_PASSWORD=" + pass),
},
Expand All @@ -229,6 +229,7 @@ func PostgreSQL(version string, opts ...ConfigOption) (*Config, error) {
append(
[]ConfigOption{
Image("postgres:" + version),
Userinfo(url.UserPassword("postgres", pass)),
Port("5432"),
Database("postgres"),
Env("POSTGRES_PASSWORD=" + pass),
Expand All @@ -244,6 +245,7 @@ func SQLServer(version string, opts ...ConfigOption) (*Config, error) {
append(
[]ConfigOption{
Image("mcr.microsoft.com/mssql/server:" + version),
Userinfo(url.UserPassword("sa", passSQLServer)),
Port("1433"),
Database("master"),
Env(
Expand All @@ -263,6 +265,7 @@ func ClickHouse(version string, opts ...ConfigOption) (*Config, error) {
append(
[]ConfigOption{
Image("clickhouse/clickhouse-server:" + version),
Userinfo(url.UserPassword("default", pass)),
Port("9000"),
Env("CLICKHOUSE_PASSWORD=" + pass),
},
Expand Down Expand Up @@ -309,13 +312,25 @@ func Env(env ...string) ConfigOption {
}

// Database sets the database name to connect to. For example:
//
// Database("test")
func Database(name string) ConfigOption {
return func(c *Config) error {
c.Database = name
return nil
}
}

// Userinfo sets the user to connect to the database. For example:
//
// Userinfo(url.User("root", "pass"))
func Userinfo(u *url.Userinfo) ConfigOption {
return func(c *Config) error {
c.User = u
return nil
}
}

// Out sets an io.Writer to use when running docker commands. For example:
//
// buf := new(bytes.Buffer)
Expand Down Expand Up @@ -365,11 +380,9 @@ func (c *Config) Run(ctx context.Context) (*Container, error) {
return nil, err
}
return &Container{
cfg: *c,
ID: strings.TrimSpace(stdout.String()),
Passphrase: pass,
Port: p,
out: c.Out,
Config: *c,
ID: strings.TrimSpace(stdout.String()),
Port: p,
}, nil
}

Expand All @@ -380,7 +393,7 @@ func (c *Container) Close() error {

// Wait waits for this container to be ready.
func (c *Container) Wait(ctx context.Context, timeout time.Duration) error {
fmt.Fprintln(c.out, "Waiting for service to be ready ... ")
fmt.Fprintln(c.Out, "Waiting for service to be ready ... ")
mysql.SetLogger(log.New(io.Discard, "", 1))
defer mysql.SetLogger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
if timeout > time.Minute {
Expand All @@ -407,14 +420,14 @@ func (c *Container) Wait(ctx context.Context, timeout time.Duration) error {
if err = db.PingContext(ctx); err != nil {
continue
}
for _, s := range c.cfg.setup {
for _, s := range c.setup {
if _, err := db.ExecContext(ctx, s); err != nil {
err = errors.Join(err, db.Close())
return fmt.Errorf("%q: %w", s, err)
}
}
_ = db.Close()
fmt.Fprintln(c.out, "Service is ready to connect!")
fmt.Fprintln(c.Out, "Service is ready to connect!")
return nil
case <-ctx.Done():
return ctx.Err()
Expand All @@ -439,23 +452,34 @@ func (c *Container) URL() (*url.URL, error) {
}
host = u.Hostname()
}
switch c.cfg.driver {
case DriverClickHouse:
return url.Parse(fmt.Sprintf("clickhouse://:%s@%s:%s/%s", c.Passphrase, host, c.Port, c.cfg.Database))
u := &url.URL{
Scheme: c.driver,
User: c.User,
Host: fmt.Sprintf("%s:%s", host, c.Port),
}
switch c.driver {
case DriverSQLServer:
return url.Parse(fmt.Sprintf("sqlserver://sa:%s@%s:%s?database=%s", passSQLServer, host, c.Port, c.cfg.Database))
q := u.Query()
q.Set("database", c.Database)
u.RawQuery = q.Encode()
case DriverPostgres:
return url.Parse(fmt.Sprintf("postgres://postgres:%s@%s:%s/%s?sslmode=disable", c.Passphrase, host, c.Port, c.cfg.Database))
case DriverMySQL, DriverMariaDB:
return url.Parse(fmt.Sprintf("%s://root:%s@%s:%s/%s", c.cfg.driver, c.Passphrase, host, c.Port, c.cfg.Database))
q := u.Query()
q.Set("sslmode", "disable")
u.Path, u.RawQuery = c.Database, q.Encode()
if u.Path == "" {
u.Path = "/"
}
case DriverMySQL, DriverMariaDB, DriverClickHouse: // MySQL compatible
u.Path = c.Database
default:
return nil, fmt.Errorf("unknown driver: %q", c.cfg.driver)
return nil, fmt.Errorf("unknown driver: %q", c.driver)
}
return u, nil
}

// PingURL returns a URL to ping the Container.
func (c *Container) PingURL(u url.URL) string {
switch c.cfg.driver {
switch c.driver {
case DriverSQLServer:
q := u.Query()
q.Del("database")
Expand Down Expand Up @@ -495,8 +519,9 @@ func init() {
"docker",
sqlclient.OpenerFunc(Open),
sqlclient.RegisterFlavours(
"docker+postgres", "docker+mysql", "docker+maria", "docker+mariadb",
"docker+sqlserver", "docker+clickhouse",
"docker+postgres",
"docker+mysql", "docker+maria", "docker+mariadb", "docker+clickhouse",
"docker+sqlserver",
),
)
}
Expand Down
25 changes: 24 additions & 1 deletion cmd/atlas/internal/docker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "arigaio/mysql:latest",
User: url.UserPassword("root", pass),
Env: []string{"MYSQL_ROOT_PASSWORD=pass"},
Port: "3306",
Out: io.Discard,
Expand All @@ -35,6 +36,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "arigaio/mariadb:latest",
User: url.UserPassword("root", pass),
Env: []string{"MYSQL_ROOT_PASSWORD=pass"},
Port: "3306",
Out: io.Discard,
Expand All @@ -45,6 +47,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "postgres:latest",
User: url.UserPassword("postgres", pass),
Env: []string{"POSTGRES_PASSWORD=pass"},
Database: "postgres",
Port: "5432",
Expand All @@ -56,6 +59,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "mcr.microsoft.com/mssql/server:2022-latest",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Database: "master",
Out: io.Discard,
Expand All @@ -71,6 +75,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "clickhouse/clickhouse-server:23.11",
User: url.UserPassword("default", pass),
Port: "9000",
Out: io.Discard,
Env: []string{
Expand All @@ -87,6 +92,7 @@ func TestFromURL(t *testing.T) {
require.Equal(t, &Config{
driver: "mysql",
Image: "arigaio/mysql",
User: url.UserPassword("root", pass),
Env: []string{"MYSQL_ROOT_PASSWORD=pass"},
Port: "3306",
Out: io.Discard,
Expand All @@ -99,6 +105,7 @@ func TestFromURL(t *testing.T) {
require.Equal(t, &Config{
driver: "mysql",
Image: "arigaio/mysql:8",
User: url.UserPassword("root", pass),
Env: []string{"MYSQL_ROOT_PASSWORD=pass"},
Port: "3306",
Out: io.Discard,
Expand All @@ -113,6 +120,7 @@ func TestFromURL(t *testing.T) {
Image: "arigaio/mysql:latest",
Database: "test",
Env: []string{"MYSQL_ROOT_PASSWORD=pass", "MYSQL_DATABASE=test"},
User: url.UserPassword("root", pass),
Port: "3306",
Out: io.Discard,
setup: []string{"CREATE DATABASE IF NOT EXISTS `test`"},
Expand All @@ -127,6 +135,7 @@ func TestFromURL(t *testing.T) {
Image: "postgres:13",
Database: "postgres",
Env: []string{"POSTGRES_PASSWORD=pass"},
User: url.UserPassword("postgres", pass),
Port: "5432",
Out: io.Discard,
}, cfg)
Expand All @@ -140,6 +149,7 @@ func TestFromURL(t *testing.T) {
Image: "postgis/postgis:14-3.4",
Database: "postgres",
Env: []string{"POSTGRES_PASSWORD=pass"},
User: url.UserPassword("postgres", pass),
Port: "5432",
Out: io.Discard,
}, cfg)
Expand All @@ -153,6 +163,7 @@ func TestFromURL(t *testing.T) {
Image: "postgis/postgis:14-3.4",
Database: "dev",
Env: []string{"POSTGRES_PASSWORD=pass"},
User: url.UserPassword("postgres", pass),
Port: "5432",
Out: io.Discard,
setup: []string{`CREATE DATABASE "dev"`},
Expand All @@ -167,6 +178,7 @@ func TestFromURL(t *testing.T) {
driver: "sqlserver",
Image: "mcr.microsoft.com/mssql/server",
Database: "master",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Out: io.Discard,
Env: []string{
Expand All @@ -184,6 +196,7 @@ func TestFromURL(t *testing.T) {
driver: "sqlserver",
Image: "mcr.microsoft.com/mssql/server:2022-latest",
Database: "master",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Out: io.Discard,
Env: []string{
Expand All @@ -202,6 +215,7 @@ func TestFromURL(t *testing.T) {
setup: []string{"CREATE DATABASE [foo]"},
Image: "mcr.microsoft.com/mssql/server:2019-latest",
Database: "foo",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Out: io.Discard,
Env: []string{
Expand All @@ -221,6 +235,7 @@ func TestFromURL(t *testing.T) {
setup: []string{"CREATE DATABASE [foo]"},
Image: "mcr.microsoft.com/azure-sql-edge:1.0.7",
Database: "foo",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Out: io.Discard,
Env: []string{
Expand All @@ -239,6 +254,7 @@ func TestFromURL(t *testing.T) {
driver: "clickhouse",
Image: "clickhouse/clickhouse-server",
Env: []string{"CLICKHOUSE_PASSWORD=pass"},
User: url.UserPassword("default", pass),
Port: "9000",
Out: io.Discard,
}, cfg)
Expand All @@ -251,6 +267,7 @@ func TestFromURL(t *testing.T) {
require.Equal(t, &Config{
driver: "clickhouse",
Image: "clickhouse/clickhouse-server:23.11",
User: url.UserPassword("default", pass),
Env: []string{"CLICKHOUSE_PASSWORD=pass"},
Port: "9000",
Out: io.Discard,
Expand Down Expand Up @@ -416,7 +433,13 @@ func TestImageURL(t *testing.T) {
}

func TestContainerURL(t *testing.T) {
c := &Container{cfg: Config{driver: "postgres"}, Passphrase: "pass", Port: "5432"}
c := &Container{
Config: Config{
driver: "postgres",
User: url.UserPassword("postgres", "pass"),
},
Port: "5432",
}
u, err := c.URL()
require.NoError(t, err)
require.Equal(t, "postgres://postgres:pass@localhost:5432/?sslmode=disable", u.String())
Expand Down
Loading