Skip to content

Commit 768d0c0

Browse files
committedJun 12, 2018
migrate: Refactor and add support for mysql.
1 parent 3bde6fe commit 768d0c0

File tree

6 files changed

+300
-181
lines changed

6 files changed

+300
-181
lines changed
 

‎migrate/backend.go

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
package migrate
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
"net/url"
7+
"strings"
8+
)
9+
10+
type backend interface {
11+
Connect() (*sql.DB, error)
12+
Reset() error
13+
Drop() error
14+
Migrate() error
15+
}
16+
17+
// GetBackend selects the datasource identified by dbname in the config file and
18+
// initializes the correct type of backend.
19+
func (m *Module) GetBackend(dbname string) (backend, error) {
20+
ds, ok := m.config.Datasources[dbname]
21+
if !ok {
22+
return nil, fmt.Errorf("migrate: %q not configured", dbname)
23+
}
24+
25+
migrations, err := m.getMigrationSource()
26+
if err != nil {
27+
return nil, err
28+
}
29+
30+
// special case for parallelizing tests: add a suffix to the dbname
31+
if dbname == "test" {
32+
u, err := url.Parse(ds.DSN)
33+
if err != nil {
34+
return nil, err
35+
}
36+
database := strings.Trim(u.Path, "/")
37+
if m.suffixForTest == "" {
38+
m.suffixForTest = randomToken()
39+
}
40+
u.Path = database + "_" + m.suffixForTest
41+
ds.DSN = u.String()
42+
}
43+
switch ds.Driver {
44+
case "postgres":
45+
return &postgresBackend{ds, migrations}, nil
46+
case "mysql":
47+
return &mysqlBackend{ds, migrations}, nil
48+
}
49+
return nil, fmt.Errorf("migrate: unsupported driver %s", ds.Driver)
50+
}

‎migrate/commands.go

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package migrate
2+
3+
import (
4+
"fmt"
5+
"log"
6+
7+
"github.com/octavore/naga/service"
8+
)
9+
10+
func (m *Module) registerCommands(c *service.Config) {
11+
c.AddCommand(&service.Command{
12+
Keyword: "db:migrate <db>",
13+
ShortUsage: "run db migrations",
14+
Run: func(ctx *service.CommandContext) {
15+
if len(ctx.Args) != 1 {
16+
m.printHelp(ctx)
17+
}
18+
b, err := m.GetBackend(ctx.Args[0])
19+
if err != nil {
20+
log.Println("migrate:", err)
21+
return
22+
}
23+
err = b.Migrate()
24+
if err != nil {
25+
log.Println("migrate:", err)
26+
}
27+
},
28+
})
29+
30+
c.AddCommand(&service.Command{
31+
Keyword: "db:reset <db>",
32+
ShortUsage: "reset database",
33+
Run: func(ctx *service.CommandContext) {
34+
if len(ctx.Args) != 1 {
35+
m.printHelp(ctx)
36+
}
37+
dbname := ctx.Args[0]
38+
b, err := m.GetBackend(dbname)
39+
if err != nil {
40+
log.Println("migrate:", err)
41+
return
42+
}
43+
err = b.Reset()
44+
if err != nil {
45+
log.Println("migrate:", err)
46+
}
47+
err = b.Migrate()
48+
if err != nil {
49+
log.Println("migrate:", err)
50+
}
51+
},
52+
})
53+
}
54+
55+
func (m *Module) printHelp(ctx *service.CommandContext) {
56+
if len(ctx.Args) != 1 {
57+
fmt.Println("Please specify a db:")
58+
if len(m.config.Datasources) == 0 {
59+
fmt.Println(" no databases found!")
60+
} else {
61+
for ds := range m.config.Datasources {
62+
fmt.Println(" " + ds)
63+
}
64+
}
65+
ctx.UsageExit()
66+
}
67+
}

‎migrate/migrations.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package migrate
2+
3+
import (
4+
"path/filepath"
5+
6+
migrate "github.com/rubenv/sql-migrate"
7+
)
8+
9+
type (
10+
assetFunc func(path string) ([]byte, error)
11+
assetDirFunc func(path string) ([]string, error)
12+
)
13+
14+
// SetMigrationSource sets the migration source, for compatibility with
15+
// embedded file assets.
16+
func (m *Module) SetMigrationSource(asset assetFunc, assetDir assetDirFunc, dir string) {
17+
m.migrationSource = &migrate.AssetMigrationSource{
18+
Asset: asset,
19+
AssetDir: assetDir,
20+
Dir: dir,
21+
}
22+
}
23+
24+
// getMigrationSource returns the m.migrationSource if set, otherwise
25+
// it defaults by reading from the MigrationsDir specified in
26+
func (m *Module) getMigrationSource() (migrate.MigrationSource, error) {
27+
if m.migrationSource != nil {
28+
return m.migrationSource, nil
29+
}
30+
configPath, err := filepath.Abs(m.Config.ConfigPath)
31+
if err != nil {
32+
return nil, err
33+
}
34+
migrationPath := filepath.Join(filepath.Dir(configPath), m.config.MigrationsDir)
35+
return migrate.FileMigrationSource{Dir: migrationPath}, nil
36+
}

‎migrate/module.go

+15-181
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,6 @@ package migrate
22

33
import (
44
"database/sql"
5-
"fmt"
6-
"log"
7-
"net/url"
8-
"path/filepath"
9-
"strings"
105

116
"github.com/octavore/naga/service"
127
"github.com/rubenv/sql-migrate"
@@ -28,6 +23,7 @@ type Module struct {
2823
migrationSource migrate.MigrationSource
2924

3025
suffixForTest string
26+
env service.Environment
3127
}
3228

3329
// Config for migrate module
@@ -43,203 +39,41 @@ type Datasource struct {
4339
DSN string `json:"dsn"`
4440
}
4541

46-
func (m *Module) printHelp(ctx *service.CommandContext) {
47-
if len(ctx.Args) != 1 {
48-
fmt.Println("Please specify a db:")
49-
if len(m.config.Datasources) == 0 {
50-
fmt.Println(" no databases found!")
51-
} else {
52-
for ds := range m.config.Datasources {
53-
fmt.Println(" " + ds)
54-
}
55-
}
56-
ctx.UsageExit()
57-
}
58-
}
59-
42+
// Init the migrate module
6043
func (m *Module) Init(c *service.Config) {
61-
c.AddCommand(&service.Command{
62-
Keyword: "db:migrate <db>",
63-
ShortUsage: "run db migrations",
64-
Run: func(ctx *service.CommandContext) {
65-
if len(ctx.Args) != 1 {
66-
m.printHelp(ctx)
67-
}
68-
err := m.Migrate(ctx.Args[0])
69-
if err != nil {
70-
log.Println("migrate:", err)
71-
}
72-
},
73-
})
74-
75-
c.AddCommand(&service.Command{
76-
Keyword: "db:reset <db>",
77-
ShortUsage: "reset database",
78-
Run: func(ctx *service.CommandContext) {
79-
if len(ctx.Args) != 1 {
80-
m.printHelp(ctx)
81-
}
82-
dbname := ctx.Args[0]
83-
err := m.Reset(dbname)
84-
if err != nil {
85-
log.Println("migrate:", err)
86-
}
87-
err = m.Migrate(dbname)
88-
if err != nil {
89-
log.Println("migrate:", err)
90-
}
91-
},
92-
})
44+
m.registerCommands(c)
9345

9446
c.Setup = func() error {
47+
m.env = c.Env()
9548
err := m.Config.ReadConfig(&m.config)
9649
if m.config.MigrationsTable != "" {
9750
migrate.SetTable(m.config.MigrationsTable)
9851
}
9952
return err
10053
}
101-
102-
c.SetupTest = func() {
103-
}
104-
}
105-
106-
func (m *Module) getConfig(dbname string) (*Datasource, error) {
107-
ds, ok := m.config.Datasources[dbname]
108-
if !ok {
109-
return nil, fmt.Errorf("migrate: %q not configured", dbname)
110-
}
111-
112-
// special case for parallelizing tests: add a suffix to the dbname
113-
if dbname == "test" {
114-
u, err := url.Parse(ds.DSN)
115-
if err != nil {
116-
return nil, err
117-
}
118-
database := strings.Trim(u.Path, "/")
119-
if m.suffixForTest == "" {
120-
m.suffixForTest = randomToken()
121-
}
122-
u.Path = database + "_" + m.suffixForTest
123-
ds.DSN = u.String()
124-
}
125-
return &ds, nil
12654
}
12755

128-
// Connect to the given DB
129-
func (m *Module) Connect(dbname string) (*sql.DB, error) {
130-
ds, err := m.getConfig(dbname)
56+
// ConnectDefault to the DB with name specified by env
57+
func (m *Module) ConnectDefault() (*sql.DB, error) {
58+
ds, err := m.GetBackend(string(m.env))
13159
if err != nil {
13260
return nil, err
13361
}
134-
return sql.Open(ds.Driver, ds.DSN)
135-
}
136-
137-
func (m *Module) AddMigrations(migrationsDir string) {
138-
panic("todo")
139-
}
140-
141-
type (
142-
assetFunc func(path string) ([]byte, error)
143-
assetDirFunc func(path string) ([]string, error)
144-
)
145-
146-
// SetMigrationSource sets the migration source, for compatibility with
147-
// embedded file assets.
148-
func (m *Module) SetMigrationSource(asset assetFunc, assetDir assetDirFunc, dir string) {
149-
m.migrationSource = &migrate.AssetMigrationSource{
150-
Asset: asset,
151-
AssetDir: assetDir,
152-
Dir: dir,
153-
}
62+
return ds.Connect()
15463
}
15564

156-
// safeConnect connects to template1 so we can create/drop the desired database.
157-
func (m *Module) safeConnect(dbname string) (string, *sql.DB, error) {
158-
ds, err := m.getConfig(dbname)
159-
if err != nil {
160-
return "", nil, err
161-
}
162-
163-
u, err := url.Parse(ds.DSN)
164-
if err != nil {
165-
return "", nil, err
166-
}
167-
168-
database := strings.Trim(u.Path, "/")
169-
u.Path = "template1"
170-
u.RawPath = "template1"
171-
172-
db, err := sql.Open(ds.Driver, u.String())
173-
if err != nil {
174-
return "", nil, err
175-
}
176-
return database, db, nil
65+
// Connect is a helper function to connect to this datasource
66+
func (d *Datasource) Connect() (*sql.DB, error) {
67+
return sql.Open(d.Driver, d.DSN)
17768
}
17869

179-
// Reset drops and recreates the database
180-
func (m *Module) Reset(dbname string) error {
181-
err := m.Drop(dbname)
182-
if err != nil {
183-
return err
184-
}
185-
186-
databaseName, db, err := m.safeConnect(dbname)
187-
if err != nil {
188-
return err
189-
}
190-
defer db.Close()
191-
192-
// using template0 in order to support test parallelism
193-
// cf http://stackoverflow.com/questions/4977171/pgerror-error-source-database-template1-is-being-accessed-by-other-users
194-
// you may be able to hack around by creating some kind of global lock to protect
195-
// connections to the template1 database?
196-
// or maybe drop the connection as soon as possible?
197-
_, err = db.Exec(`CREATE DATABASE ` + databaseName + ` TEMPLATE template0`)
198-
return err
199-
}
200-
201-
// Drop the database `dbname`
202-
func (m *Module) Drop(dbname string) error {
203-
databaseName, db, err := m.safeConnect(dbname)
70+
// Migrate runs migrations in m
71+
func (d *Datasource) migrate(m migrate.MigrationSource) error {
72+
db, err := d.Connect()
20473
if err != nil {
20574
return err
20675
}
20776
defer db.Close()
208-
_, err = db.Exec(`DROP DATABASE IF EXISTS ` + databaseName)
209-
return err
210-
}
211-
212-
// getMigrationSource returns the m.migrationSource if set, otherwise
213-
// it defaults by reading from the MigrationsDir specified in
214-
func (m *Module) getMigrationSource() (migrate.MigrationSource, error) {
215-
if m.migrationSource != nil {
216-
return m.migrationSource, nil
217-
}
218-
configPath, err := filepath.Abs(m.Config.ConfigPath)
219-
if err != nil {
220-
return nil, err
221-
}
222-
migrationPath := filepath.Join(filepath.Dir(configPath), m.config.MigrationsDir)
223-
return migrate.FileMigrationSource{Dir: migrationPath}, nil
224-
}
225-
226-
// Migrate the given db
227-
func (m *Module) Migrate(dbname string) error {
228-
ds, err := m.getConfig(dbname)
229-
if err != nil {
230-
return err
231-
}
232-
233-
db, err := sql.Open(ds.Driver, ds.DSN)
234-
if err != nil {
235-
return err
236-
}
237-
defer db.Close()
238-
migrations, err := m.getMigrationSource()
239-
if err != nil {
240-
return err
241-
}
242-
243-
_, err = migrate.Exec(db, ds.Driver, migrations, migrate.Up)
77+
_, err = migrate.Exec(db, d.Driver, m, migrate.Up)
24478
return err
24579
}

‎migrate/mysql.go

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package migrate
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
"regexp"
7+
8+
migrate "github.com/rubenv/sql-migrate"
9+
)
10+
11+
type mysqlBackend struct {
12+
Datasource
13+
migrate.MigrationSource
14+
}
15+
16+
// connectWithoutDB connects without a db so we can create/drop the desired database.
17+
func (m *mysqlBackend) connectWithoutDB() (string, *sql.DB, error) {
18+
re := regexp.MustCompile("(.*)/([^?]+)")
19+
matches := re.FindStringSubmatch(m.Datasource.DSN)
20+
if len(matches) != 3 {
21+
return "", nil, fmt.Errorf("migrate: error parsing mysql dsn")
22+
}
23+
dsn := matches[1] + "/"
24+
database := matches[2]
25+
26+
db, err := sql.Open("mysql", dsn)
27+
if err != nil {
28+
return "", nil, err
29+
}
30+
return database, db, err
31+
}
32+
33+
// Reset drops and recreates the database
34+
func (m *mysqlBackend) Reset() error {
35+
err := m.Drop()
36+
if err != nil {
37+
return err
38+
}
39+
databaseName, db, err := m.connectWithoutDB()
40+
if err != nil {
41+
return err
42+
}
43+
defer db.Close()
44+
_, err = db.Exec(`CREATE DATABASE ` + databaseName)
45+
return err
46+
}
47+
48+
// Drop the database `dbname`
49+
func (m *mysqlBackend) Drop() error {
50+
databaseName, db, err := m.connectWithoutDB()
51+
if err != nil {
52+
return err
53+
}
54+
defer db.Close()
55+
_, err = db.Exec(`DROP DATABASE IF EXISTS ` + databaseName)
56+
return err
57+
}
58+
59+
func (m *mysqlBackend) Migrate() error {
60+
return m.Datasource.migrate(m.MigrationSource)
61+
}

‎migrate/postgres.go

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package migrate
2+
3+
import (
4+
"database/sql"
5+
"net/url"
6+
"strings"
7+
8+
migrate "github.com/rubenv/sql-migrate"
9+
)
10+
11+
type postgresBackend struct {
12+
Datasource
13+
migrate.MigrationSource
14+
}
15+
16+
// safeConnect connects to template1 so we can create/drop the desired database.
17+
func (p *postgresBackend) safeConnect() (string, *sql.DB, error) {
18+
u, err := url.Parse(p.Datasource.DSN)
19+
if err != nil {
20+
return "", nil, err
21+
}
22+
23+
database := strings.Trim(u.Path, "/")
24+
u.Path = "template1"
25+
u.RawPath = "template1"
26+
27+
db, err := sql.Open(p.Datasource.Driver, u.String())
28+
if err != nil {
29+
return "", nil, err
30+
}
31+
return database, db, nil
32+
}
33+
34+
// Reset drops and recreates the database
35+
func (p *postgresBackend) Reset() error {
36+
err := p.Drop()
37+
if err != nil {
38+
return err
39+
}
40+
41+
databaseName, db, err := p.safeConnect()
42+
if err != nil {
43+
return err
44+
}
45+
defer db.Close()
46+
47+
// using template0 in order to support test parallelism
48+
// cf http://stackoverflow.com/questions/4977171/pgerror-error-source-database-template1-is-being-accessed-by-other-users
49+
// you may be able to hack around by creating some kind of global lock to protect
50+
// connections to the template1 database?
51+
// or maybe drop the connection as soon as possible?
52+
_, err = db.Exec(`CREATE DATABASE ` + databaseName + ` TEMPLATE template0`)
53+
return err
54+
}
55+
56+
// Drop the database `dbname`
57+
func (p *postgresBackend) Drop() error {
58+
databaseName, db, err := p.safeConnect()
59+
if err != nil {
60+
return err
61+
}
62+
defer db.Close()
63+
_, err = db.Exec(`DROP DATABASE IF EXISTS ` + databaseName)
64+
return err
65+
}
66+
67+
// Migrate the given db
68+
69+
func (p *postgresBackend) Migrate() error {
70+
return p.Datasource.migrate(p.MigrationSource)
71+
}

0 commit comments

Comments
 (0)
Please sign in to comment.