Skip to content

Commit

Permalink
cmd/atlas/docker: allow provide User via options
Browse files Browse the repository at this point in the history
  • Loading branch information
giautm committed Dec 10, 2024
1 parent 0bee868 commit 476b8b9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 28 deletions.
73 changes: 48 additions & 25 deletions cmd/atlas/internal/docker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,24 @@ type (
Config struct {
driver string // driver to open connections with.
setup []string // contains statements to execute once the service is up
// User is the user to connect to the database.
User *url.Userinfo
// Internal Port to expose and connect to.
Port string
// Image is the name of the image to pull and run.
Image string
// Env vars to pass to the docker container.
Env []string
// Internal Port to expose anc connect to.
Port string
// Database name to create and connect on init.
Database string
// Out is a custom writer to send docker cli output to.
Out io.Writer
}
// A Container is an instance of a created container.
Container struct {
cfg Config // Config used to create this container
out io.Writer // custom write to log status messages to
Config // Config used to create this container
// ID of the container.
ID string
// Passphrase of the root user.
Passphrase string
// Port on the host this containers service is bound to.
Port string
}
Expand All @@ -72,6 +71,7 @@ func NewConfig(opts ...ConfigOption) (*Config, error) {
return c, nil
}

// Supported drivers.
const (
DriverMySQL = "mysql"
DriverMariaDB = "mariadb"
Expand Down Expand Up @@ -211,6 +211,7 @@ func MySQL(version string, opts ...ConfigOption) (*Config, error) {
append(
[]ConfigOption{
Image(hubUser, "mysql:"+version),
Userinfo(url.UserPassword("root", pass)),
Port("3306"),
Env("MYSQL_ROOT_PASSWORD=" + pass),
},
Expand All @@ -230,6 +231,7 @@ func PostgreSQL(version string, opts ...ConfigOption) (*Config, error) {
append(
[]ConfigOption{
Image("postgres:" + version),
Userinfo(url.UserPassword("postgres", pass)),
Port("5432"),
Database("postgres"),
Env("POSTGRES_PASSWORD=" + pass),
Expand All @@ -245,6 +247,7 @@ func SQLServer(version string, opts ...ConfigOption) (*Config, error) {
append(
[]ConfigOption{
Image("mcr.microsoft.com/mssql/server:" + version),
Userinfo(url.UserPassword("sa", passSQLServer)),
Port("1433"),
Database("master"),
Env(
Expand All @@ -264,6 +267,7 @@ func ClickHouse(version string, opts ...ConfigOption) (*Config, error) {
append(
[]ConfigOption{
Image("clickhouse/clickhouse-server:" + version),
Userinfo(url.UserPassword("default", pass)),
Port("9000"),
Env("CLICKHOUSE_PASSWORD=" + pass),
},
Expand Down Expand Up @@ -310,13 +314,25 @@ func Env(env ...string) ConfigOption {
}

// Database sets the database name to connect to. For example:
//
// Database("test")
func Database(name string) ConfigOption {
return func(c *Config) error {
c.Database = name
return nil
}
}

// Userinfo sets the user to connect to the database. For example:
//
// Userinfo(url.User("root", "pass"))
func Userinfo(u *url.Userinfo) ConfigOption {
return func(c *Config) error {
c.User = u
return nil
}
}

// Out sets an io.Writer to use when running docker commands. For example:
//
// buf := new(bytes.Buffer)
Expand Down Expand Up @@ -366,11 +382,9 @@ func (c *Config) Run(ctx context.Context) (*Container, error) {
return nil, err
}
return &Container{
cfg: *c,
ID: strings.TrimSpace(stdout.String()),
Passphrase: pass,
Port: p,
out: c.Out,
Config: *c,
ID: strings.TrimSpace(stdout.String()),
Port: p,
}, nil
}

Expand All @@ -381,7 +395,7 @@ func (c *Container) Close() error {

// Wait waits for this container to be ready.
func (c *Container) Wait(ctx context.Context, timeout time.Duration) error {
fmt.Fprintln(c.out, "Waiting for service to be ready ... ")
fmt.Fprintln(c.Out, "Waiting for service to be ready ... ")
mysql.SetLogger(log.New(io.Discard, "", 1))
defer mysql.SetLogger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
if timeout > time.Minute {
Expand All @@ -408,14 +422,14 @@ func (c *Container) Wait(ctx context.Context, timeout time.Duration) error {
if err = db.PingContext(ctx); err != nil {
continue
}
for _, s := range c.cfg.setup {
for _, s := range c.setup {
if _, err := db.ExecContext(ctx, s); err != nil {
err = errors.Join(err, db.Close())
return fmt.Errorf("%q: %w", s, err)
}
}
_ = db.Close()
fmt.Fprintln(c.out, "Service is ready to connect!")
fmt.Fprintln(c.Out, "Service is ready to connect!")
return nil
case <-ctx.Done():
return ctx.Err()
Expand All @@ -440,23 +454,31 @@ func (c *Container) URL() (*url.URL, error) {
}
host = u.Hostname()
}
switch c.cfg.driver {
case DriverClickHouse:
return url.Parse(fmt.Sprintf("clickhouse://:%s@%s:%s/%s", c.Passphrase, host, c.Port, c.cfg.Database))
u := &url.URL{
Scheme: c.driver,
User: c.User,
Host: fmt.Sprintf("%s:%s", host, c.Port),
}
switch c.driver {
case DriverSQLServer:
return url.Parse(fmt.Sprintf("sqlserver://sa:%s@%s:%s?database=%s", passSQLServer, host, c.Port, c.cfg.Database))
q := u.Query()
q.Set("database", c.Database)
u.RawQuery = q.Encode()
case DriverPostgres:
return url.Parse(fmt.Sprintf("postgres://postgres:%s@%s:%s/%s?sslmode=disable", c.Passphrase, host, c.Port, c.cfg.Database))
case DriverMySQL, DriverMariaDB:
return url.Parse(fmt.Sprintf("%s://root:%s@%s:%s/%s", c.cfg.driver, c.Passphrase, host, c.Port, c.cfg.Database))
q := u.Query()
q.Set("sslmode", "disable")
u.Path, u.RawQuery = c.Database, q.Encode()
case DriverMySQL, DriverMariaDB, DriverClickHouse: // MySQL compatible
u.Path = c.Database
default:
return nil, fmt.Errorf("unknown driver: %q", c.cfg.driver)
return nil, fmt.Errorf("unknown driver: %q", c.driver)
}
return u, nil
}

// PingURL returns a URL to ping the Container.
func (c *Container) PingURL(u url.URL) string {
switch c.cfg.driver {
switch c.driver {
case DriverSQLServer:
q := u.Query()
q.Del("database")
Expand Down Expand Up @@ -496,8 +518,9 @@ func init() {
"docker",
sqlclient.OpenerFunc(Open),
sqlclient.RegisterFlavours(
"docker+postgres", "docker+mysql", "docker+maria", "docker+mariadb",
"docker+sqlserver", "docker+clickhouse",
"docker+postgres",
"docker+mysql", "docker+maria", "docker+mariadb", "docker+clickhouse",
"docker+sqlserver",
),
)
}
Expand Down
29 changes: 26 additions & 3 deletions cmd/atlas/internal/docker/docker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "arigaio/mysql:latest",
User: url.UserPassword("root", pass),
Env: []string{"MYSQL_ROOT_PASSWORD=pass"},
Port: "3306",
Out: io.Discard,
Expand All @@ -35,6 +36,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "arigaio/mariadb:latest",
User: url.UserPassword("root", pass),
Env: []string{"MYSQL_ROOT_PASSWORD=pass"},
Port: "3306",
Out: io.Discard,
Expand All @@ -45,6 +47,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "postgres:latest",
User: url.UserPassword("postgres", pass),
Env: []string{"POSTGRES_PASSWORD=pass"},
Database: "postgres",
Port: "5432",
Expand All @@ -56,6 +59,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "mcr.microsoft.com/mssql/server:2022-latest",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Database: "master",
Out: io.Discard,
Expand All @@ -71,6 +75,7 @@ func TestDockerConfig(t *testing.T) {
require.NoError(t, err)
require.Equal(t, &Config{
Image: "clickhouse/clickhouse-server:23.11",
User: url.UserPassword("default", pass),
Port: "9000",
Out: io.Discard,
Env: []string{
Expand All @@ -87,6 +92,7 @@ func TestFromURL(t *testing.T) {
require.Equal(t, &Config{
driver: "mysql",
Image: "arigaio/mysql",
User: url.UserPassword("root", pass),
Env: []string{"MYSQL_ROOT_PASSWORD=pass"},
Port: "3306",
Out: io.Discard,
Expand All @@ -99,6 +105,7 @@ func TestFromURL(t *testing.T) {
require.Equal(t, &Config{
driver: "mysql",
Image: "arigaio/mysql:8",
User: url.UserPassword("root", pass),
Env: []string{"MYSQL_ROOT_PASSWORD=pass"},
Port: "3306",
Out: io.Discard,
Expand All @@ -113,6 +120,7 @@ func TestFromURL(t *testing.T) {
Image: "arigaio/mysql:latest",
Database: "test",
Env: []string{"MYSQL_ROOT_PASSWORD=pass", "MYSQL_DATABASE=test"},
User: url.UserPassword("root", pass),
Port: "3306",
Out: io.Discard,
setup: []string{"CREATE DATABASE IF NOT EXISTS `test`"},
Expand All @@ -127,6 +135,7 @@ func TestFromURL(t *testing.T) {
Image: "postgres:13",
Database: "postgres",
Env: []string{"POSTGRES_PASSWORD=pass"},
User: url.UserPassword("postgres", pass),
Port: "5432",
Out: io.Discard,
}, cfg)
Expand All @@ -140,6 +149,7 @@ func TestFromURL(t *testing.T) {
Image: "postgis/postgis:14-3.4",
Database: "postgres",
Env: []string{"POSTGRES_PASSWORD=pass"},
User: url.UserPassword("postgres", pass),
Port: "5432",
Out: io.Discard,
}, cfg)
Expand All @@ -153,6 +163,7 @@ func TestFromURL(t *testing.T) {
Image: "postgis/postgis:14-3.4",
Database: "dev",
Env: []string{"POSTGRES_PASSWORD=pass"},
User: url.UserPassword("postgres", pass),
Port: "5432",
Out: io.Discard,
setup: []string{`CREATE DATABASE "dev"`},
Expand All @@ -167,6 +178,7 @@ func TestFromURL(t *testing.T) {
driver: "sqlserver",
Image: "mcr.microsoft.com/mssql/server",
Database: "master",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Out: io.Discard,
Env: []string{
Expand All @@ -184,6 +196,7 @@ func TestFromURL(t *testing.T) {
driver: "sqlserver",
Image: "mcr.microsoft.com/mssql/server:2022-latest",
Database: "master",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Out: io.Discard,
Env: []string{
Expand All @@ -202,6 +215,7 @@ func TestFromURL(t *testing.T) {
setup: []string{"CREATE DATABASE [foo]"},
Image: "mcr.microsoft.com/mssql/server:2019-latest",
Database: "foo",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Out: io.Discard,
Env: []string{
Expand All @@ -221,6 +235,7 @@ func TestFromURL(t *testing.T) {
setup: []string{"CREATE DATABASE [foo]"},
Image: "mcr.microsoft.com/azure-sql-edge:1.0.7",
Database: "foo",
User: url.UserPassword("sa", passSQLServer),
Port: "1433",
Out: io.Discard,
Env: []string{
Expand All @@ -239,6 +254,7 @@ func TestFromURL(t *testing.T) {
driver: "clickhouse",
Image: "clickhouse/clickhouse-server",
Env: []string{"CLICKHOUSE_PASSWORD=pass"},
User: url.UserPassword("default", pass),
Port: "9000",
Out: io.Discard,
}, cfg)
Expand All @@ -251,6 +267,7 @@ func TestFromURL(t *testing.T) {
require.Equal(t, &Config{
driver: "clickhouse",
Image: "clickhouse/clickhouse-server:23.11",
User: url.UserPassword("default", pass),
Env: []string{"CLICKHOUSE_PASSWORD=pass"},
Port: "9000",
Out: io.Discard,
Expand Down Expand Up @@ -416,14 +433,20 @@ func TestImageURL(t *testing.T) {
}

func TestContainerURL(t *testing.T) {
c := &Container{cfg: Config{driver: "postgres"}, Passphrase: "pass", Port: "5432"}
c := &Container{
Config: Config{
driver: "postgres",
User: url.UserPassword("postgres", "pass"),
},
Port: "5432",
}
u, err := c.URL()
require.NoError(t, err)
require.Equal(t, "postgres://postgres:pass@localhost:5432/?sslmode=disable", u.String())
require.Equal(t, "postgres://postgres:pass@localhost:5432?sslmode=disable", u.String())

// With DOCKER_HOST
t.Setenv("DOCKER_HOST", "tcp://host.docker.internal:2375")
u, err = c.URL()
require.NoError(t, err)
require.Equal(t, "postgres://postgres:[email protected]:5432/?sslmode=disable", u.String())
require.Equal(t, "postgres://postgres:[email protected]:5432?sslmode=disable", u.String())
}

0 comments on commit 476b8b9

Please sign in to comment.