Skip to content

Commit 00eea54

Browse files
committed
add driver check and improve run method
1 parent e1294ca commit 00eea54

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed

Migration.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ import "fmt"
55
// Migration is a complete, raw SQL command that can be ran
66
// against a database.
77
type Migration struct {
8-
SQL string
9-
Batch int
8+
SQL string
109
}
1110

1211
func (m Migration) String() string {

Migrator.go

+32-11
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,21 @@ var supportedDrivers = []string{
1616
// Migrator is responsible for receiving the incoming migrations
1717
// and running their SQL.
1818
type Migrator struct {
19-
DB *sql.DB
19+
DB *sql.DB
20+
Batch int
2021
}
2122

2223
// NewMigrator creates a new instance of a migrator.
23-
func NewMigrator(db *sql.DB) *Migrator {
24-
return &Migrator{
24+
func NewMigrator(db *sql.DB) (*Migrator, error) {
25+
m := Migrator{
2526
DB: db,
2627
}
28+
29+
if !m.driverIsSupported(m.getDriverName()) {
30+
return nil, fmt.Errorf("the %s driver is currently unsupported", m.getDriverName())
31+
}
32+
33+
return &m, nil
2734
}
2835

2936
// TableExists determines if a table exists on the database.
@@ -121,27 +128,31 @@ func (m *Migrator) lastBatchNumber() int {
121128
return num
122129
}
123130

124-
// Run uses the passed migration to update the passed database.
125-
func (m *Migrator) Run(migration MigrationInterface) error {
131+
// TODO: Wrap this in a transaction and reverse it
132+
func (m *Migrator) Run(migrations ...MigrationInterface) error {
126133
m.verifyMigrationsTable()
127134

128-
if _, err := m.DB.Exec(migration.Up().String()); err != nil {
129-
return err
130-
}
135+
batch := m.nextBatchNumber()
136+
137+
for _, migration := range migrations {
138+
if _, err := m.DB.Exec(migration.Up().String()); err != nil {
139+
return err
140+
}
131141

132-
m.addBatchToMigrationsTable(migration)
142+
m.addBatchToMigrationsTable(migration, batch)
143+
}
133144

134145
return nil
135146
}
136147

137-
func (m *Migrator) addBatchToMigrationsTable(migration MigrationInterface) {
148+
func (m *Migrator) addBatchToMigrationsTable(migration MigrationInterface, batch int) {
138149
stmt, err := m.DB.Prepare("INSERT INTO migrations (migration, batch) VALUES ( ?, ? )")
139150
if err != nil {
140151
log.Fatalln("Cannot create `migrations` batch statement. ")
141152
}
142153
defer stmt.Close()
143154

144-
if _, err = stmt.Exec(reflect.TypeOf(migration).String(), m.nextBatchNumber()); err != nil {
155+
if _, err = stmt.Exec(reflect.TypeOf(migration).String(), batch); err != nil {
145156
log.Fatalln(err)
146157
}
147158
}
@@ -156,6 +167,16 @@ func (m *Migrator) verifyMigrationsTable() {
156167
}
157168
}
158169

170+
func (m *Migrator) driverIsSupported(driver string) bool {
171+
for _, d := range supportedDrivers {
172+
if d == driver {
173+
return true
174+
}
175+
}
176+
177+
return false
178+
}
179+
159180
// getDriverName returns the name of the SQL driver currently
160181
// associated with the Migrator.
161182
func (m *Migrator) getDriverName() string {

0 commit comments

Comments
 (0)