@@ -16,14 +16,21 @@ var supportedDrivers = []string{
16
16
// Migrator is responsible for receiving the incoming migrations
17
17
// and running their SQL.
18
18
type Migrator struct {
19
- DB * sql.DB
19
+ DB * sql.DB
20
+ Batch int
20
21
}
21
22
22
23
// NewMigrator creates a new instance of a migrator.
23
- func NewMigrator (db * sql.DB ) * Migrator {
24
- return & Migrator {
24
+ func NewMigrator (db * sql.DB ) ( * Migrator , error ) {
25
+ m := Migrator {
25
26
DB : db ,
26
27
}
28
+
29
+ if ! m .driverIsSupported (m .getDriverName ()) {
30
+ return nil , fmt .Errorf ("the %s driver is currently unsupported" , m .getDriverName ())
31
+ }
32
+
33
+ return & m , nil
27
34
}
28
35
29
36
// TableExists determines if a table exists on the database.
@@ -121,27 +128,31 @@ func (m *Migrator) lastBatchNumber() int {
121
128
return num
122
129
}
123
130
124
- // Run uses the passed migration to update the passed database.
125
- func (m * Migrator ) Run (migration MigrationInterface ) error {
131
+ // TODO: Wrap this in a transaction and reverse it
132
+ func (m * Migrator ) Run (migrations ... MigrationInterface ) error {
126
133
m .verifyMigrationsTable ()
127
134
128
- if _ , err := m .DB .Exec (migration .Up ().String ()); err != nil {
129
- return err
130
- }
135
+ batch := m .nextBatchNumber ()
136
+
137
+ for _ , migration := range migrations {
138
+ if _ , err := m .DB .Exec (migration .Up ().String ()); err != nil {
139
+ return err
140
+ }
131
141
132
- m .addBatchToMigrationsTable (migration )
142
+ m .addBatchToMigrationsTable (migration , batch )
143
+ }
133
144
134
145
return nil
135
146
}
136
147
137
- func (m * Migrator ) addBatchToMigrationsTable (migration MigrationInterface ) {
148
+ func (m * Migrator ) addBatchToMigrationsTable (migration MigrationInterface , batch int ) {
138
149
stmt , err := m .DB .Prepare ("INSERT INTO migrations (migration, batch) VALUES ( ?, ? )" )
139
150
if err != nil {
140
151
log .Fatalln ("Cannot create `migrations` batch statement. " )
141
152
}
142
153
defer stmt .Close ()
143
154
144
- if _ , err = stmt .Exec (reflect .TypeOf (migration ).String (), m . nextBatchNumber () ); err != nil {
155
+ if _ , err = stmt .Exec (reflect .TypeOf (migration ).String (), batch ); err != nil {
145
156
log .Fatalln (err )
146
157
}
147
158
}
@@ -156,6 +167,16 @@ func (m *Migrator) verifyMigrationsTable() {
156
167
}
157
168
}
158
169
170
+ func (m * Migrator ) driverIsSupported (driver string ) bool {
171
+ for _ , d := range supportedDrivers {
172
+ if d == driver {
173
+ return true
174
+ }
175
+ }
176
+
177
+ return false
178
+ }
179
+
159
180
// getDriverName returns the name of the SQL driver currently
160
181
// associated with the Migrator.
161
182
func (m * Migrator ) getDriverName () string {
0 commit comments