@@ -35,7 +35,6 @@ func NewDriver(datasource string) (*mysqlDriver, error) {
35
35
}
36
36
37
37
func (d mysqlDriver ) Init () error {
38
- d .fresh ()
39
38
// Before running migrations, make sure that the migrations table exists on
40
39
// the underlying database. This table is used to track which migrations
41
40
// have already been ran. If it doesn't exist, then create it.
@@ -52,7 +51,18 @@ func (d mysqlDriver) Close() error {
52
51
return d .db .Close ()
53
52
}
54
53
55
- func (d mysqlDriver ) Run (migrations []exodus.Migration ) error {
54
+ func (d mysqlDriver ) Run (opts exodus.Options , migrations []exodus.Migration ) error {
55
+ switch opts .Direction () {
56
+ case exodus .Up :
57
+ return d .runUp (opts , migrations )
58
+ case exodus .Down :
59
+ return d .runDown (opts , migrations )
60
+ default :
61
+ return errors .New ("not running migrations, direction not specified" )
62
+ }
63
+ }
64
+
65
+ func (d mysqlDriver ) runUp (opts exodus.Options , migrations []exodus.Migration ) error {
56
66
// First, retrieve the list of migrations that have previously been ran. These
57
67
// migrations are then used to determine which of the incoming migrations
58
68
// should be ran against the database.
@@ -78,6 +88,29 @@ func (d mysqlDriver) Run(migrations []exodus.Migration) error {
78
88
return nil
79
89
}
80
90
91
+ func (d mysqlDriver ) runDown (opts exodus.Options , migrations []exodus.Migration ) error {
92
+ // First, retrieve the list of migrations that have previously been ran. These
93
+ // will be used to retrieve the Down method of each migration to run.
94
+ previous , err := d .getLastBatchRan ()
95
+ if err != nil {
96
+ return fmt .Errorf ("unable to process migrations: %w" , err )
97
+ }
98
+ migrations = filterRanMigrations (migrations , previous )
99
+
100
+ if len (migrations ) == 0 {
101
+ d .log .Info ().Msg ("Nothing to rollback." )
102
+ return nil
103
+ }
104
+
105
+ for _ , migration := range migrations {
106
+ if err := d .processDown (migration ); err != nil {
107
+ return fmt .Errorf ("unable to execute SQL: %w" , err )
108
+ }
109
+ }
110
+
111
+ return nil
112
+ }
113
+
81
114
func (d mysqlDriver ) getRan () ([]string , error ) {
82
115
rows , err := d .db .Query ("SELECT migration FROM migrations" )
83
116
if err != nil {
@@ -99,6 +132,27 @@ func (d mysqlDriver) getRan() ([]string, error) {
99
132
return ran , nil
100
133
}
101
134
135
+ func (d mysqlDriver ) getLastBatchRan () ([]string , error ) {
136
+ rows , err := d .db .Query ("SELECT migration FROM migrations WHERE batch = (SELECT MAX(batch) FROM migrations)" )
137
+ if err != nil {
138
+ return []string {}, fmt .Errorf ("unable to get previous migrations from database: %w" , err )
139
+ }
140
+
141
+ var ran []string
142
+ for rows .Next () {
143
+ var migration string
144
+ if err := rows .Scan (& migration ); err != nil {
145
+ return []string {}, fmt .Errorf ("unable to get previous migrations from database: %w" , err )
146
+ }
147
+ ran = append (ran , migration )
148
+ }
149
+ if err := rows .Err (); err != nil {
150
+ return []string {}, fmt .Errorf ("unable to get previous migrations from database: %w" , err )
151
+ }
152
+
153
+ return ran , nil
154
+ }
155
+
102
156
func (d mysqlDriver ) process (migration exodus.Migration , batch int ) error {
103
157
builder := & exodus.MigrationPayload {}
104
158
migration .Up (builder )
@@ -125,6 +179,36 @@ func (d mysqlDriver) process(migration exodus.Migration, batch int) error {
125
179
return nil
126
180
}
127
181
182
+ func (d mysqlDriver ) processDown (migration exodus.Migration ) error {
183
+ builder := & exodus.MigrationPayload {}
184
+ migration .Down (builder )
185
+ start := time .Now ()
186
+ d .log .Info ().Msgf ("Rolling back: %s" , reflect .TypeOf (migration ).String ())
187
+ for _ , p := range builder .Operations () {
188
+ switch p .Operation () {
189
+ case exodus .CREATE_TABLE :
190
+ if err := d .createTable (p ); err != nil {
191
+ return err
192
+ }
193
+ case exodus .RENAME_TABLE :
194
+ if err := d .renameTable (p ); err != nil {
195
+ return err
196
+ }
197
+ case exodus .DROP_TABLE :
198
+ if err := d .dropTable (p ); err != nil {
199
+ return err
200
+ }
201
+ default :
202
+ return errors .New ("operation not supported" )
203
+ }
204
+ }
205
+
206
+ d .removeMigrationLog (migration )
207
+
208
+ d .log .Info ().Msgf ("Rolled back: %s in %v" , reflect .TypeOf (migration ).String (), time .Since (start ))
209
+ return nil
210
+ }
211
+
128
212
func (d mysqlDriver ) logMigration (migration exodus.Migration , batch int ) error {
129
213
stmt , err := d .db .Prepare ("INSERT INTO migrations (migration, batch) VALUES ( ?, ? )" )
130
214
if err != nil {
@@ -139,6 +223,20 @@ func (d mysqlDriver) logMigration(migration exodus.Migration, batch int) error {
139
223
return nil
140
224
}
141
225
226
+ func (d mysqlDriver ) removeMigrationLog (migration exodus.Migration ) error {
227
+ stmt , err := d .db .Prepare ("DELETE FROM migrations WHERE migration = ?" )
228
+ if err != nil {
229
+ log .Fatalln ("Cannot create `migrations` remove statement. " )
230
+ }
231
+ defer stmt .Close ()
232
+
233
+ if _ , err = stmt .Exec (reflect .TypeOf (migration ).String ()); err != nil {
234
+ return err
235
+ }
236
+
237
+ return nil
238
+ }
239
+
142
240
// nextBatchNumber retreives the highest batch number from the
143
241
// migrations table and increments it by one.
144
242
func (d mysqlDriver ) nextBatchNumber () int {
@@ -154,6 +252,17 @@ func (d mysqlDriver) lastBatchNumber() int {
154
252
return num
155
253
}
156
254
255
+ func (d mysqlDriver ) dropTable (payload * exodus.MigrationOperation ) error {
256
+ table := payload .Table ()
257
+ sql := fmt .Sprintf (dropTableSchema , table )
258
+
259
+ if _ , err := d .db .Exec (sql ); err != nil {
260
+ return fmt .Errorf ("unable to drop table `%s`: %w" , table , err )
261
+ }
262
+
263
+ return nil
264
+ }
265
+
157
266
func (d mysqlDriver ) renameTable (payload * exodus.MigrationOperation ) error {
158
267
from := payload .Table ()
159
268
to := payload .Payload ().(string )
@@ -254,6 +363,18 @@ func filterPendingMigrations(migrations []exodus.Migration, existing []string) [
254
363
return response
255
364
}
256
365
366
+ func filterRanMigrations (migrations []exodus.Migration , existing []string ) []exodus.Migration {
367
+ var response []exodus.Migration
368
+
369
+ for _ , migration := range migrations {
370
+ if exists (migration , existing ) {
371
+ response = append (response , migration )
372
+ }
373
+ }
374
+
375
+ return response
376
+ }
377
+
257
378
func exists (migration exodus.Migration , existing []string ) bool {
258
379
for _ , ex := range existing {
259
380
if reflect .TypeOf (migration ).String () == ex {
0 commit comments