|
| 1 | +package exodus |
| 2 | + |
| 3 | +import ( |
| 4 | + "database/sql" |
| 5 | + "fmt" |
| 6 | + "log" |
| 7 | + "reflect" |
| 8 | +) |
| 9 | + |
| 10 | +// supportedDrivers lists the drivers that currently work with |
| 11 | +// the migration framework. |
| 12 | +var supportedDrivers = []string{ |
| 13 | + "sqlite3", |
| 14 | +} |
| 15 | + |
| 16 | +// Migrator is responsible for receiving the incoming migrations |
| 17 | +// and running their SQL. |
| 18 | +type Migrator struct { |
| 19 | + DB *sql.DB |
| 20 | + Batch int |
| 21 | +} |
| 22 | + |
| 23 | +// NewMigrator creates a new instance of a migrator. |
| 24 | +func NewMigrator(db *sql.DB) (*Migrator, error) { |
| 25 | + m := Migrator{ |
| 26 | + DB: db, |
| 27 | + } |
| 28 | + |
| 29 | + if !m.driverIsSupported(m.getDriverName()) { |
| 30 | + return nil, fmt.Errorf("the %s driver is currently unsupported", m.getDriverName()) |
| 31 | + } |
| 32 | + |
| 33 | + return &m, nil |
| 34 | +} |
| 35 | + |
| 36 | +// TableExists determines if a table exists on the database. |
| 37 | +// TODO: Probably a better way of doing this. |
| 38 | +func (m *Migrator) TableExists(table string, database *sql.DB) bool { |
| 39 | + sql := fmt.Sprintf("SELECT * FROM %s LIMIT 1", table) |
| 40 | + if _, err := database.Exec(sql); err != nil { |
| 41 | + return false |
| 42 | + } |
| 43 | + |
| 44 | + return true |
| 45 | +} |
| 46 | + |
| 47 | +// Fresh drops all tables in the database. |
| 48 | +func (m *Migrator) Fresh(database *sql.DB) { |
| 49 | + if err := m.dropAllTables(database); err != nil { |
| 50 | + log.Fatalln(err) |
| 51 | + } |
| 52 | +} |
| 53 | + |
| 54 | +// dropAllTables grabs the tables from the database and drops |
| 55 | +// them in turn, stopping if there is an error. |
| 56 | +// TODO: Wrap this in a transaction, so it is cancelled if any |
| 57 | +// of the drops fail? |
| 58 | +func (m *Migrator) dropAllTables(database *sql.DB) error { |
| 59 | + // Get the SQL command to drop all tables for the current |
| 60 | + // SQL driver provided in the database connection. |
| 61 | + dropSQL, err := m.getDropSQLForDriver(m.getDriverName()) |
| 62 | + if err != nil { |
| 63 | + // If support for the driver does not exist, log a |
| 64 | + // fatal error. |
| 65 | + log.Fatalln("Unable to drop tables:", err) |
| 66 | + } |
| 67 | + |
| 68 | + rows, err := database.Query(dropSQL) |
| 69 | + if err != nil { |
| 70 | + return err |
| 71 | + } |
| 72 | + defer rows.Close() |
| 73 | + |
| 74 | + // tables is the list of tables returned from the database. |
| 75 | + var tables []string |
| 76 | + |
| 77 | + // for each row returned, add the name of it to the |
| 78 | + // tables slice. |
| 79 | + for rows.Next() { |
| 80 | + var name string |
| 81 | + if err := rows.Scan(&name); err != nil { |
| 82 | + return err |
| 83 | + } |
| 84 | + if name == "sqlite_sequence" { |
| 85 | + continue |
| 86 | + } |
| 87 | + tables = append(tables, name) |
| 88 | + } |
| 89 | + if err := rows.Err(); err != nil { |
| 90 | + return err |
| 91 | + } |
| 92 | + |
| 93 | + for _, table := range tables { |
| 94 | + if _, err := database.Exec("DROP TABLE IF EXISTS " + table); err != nil { |
| 95 | + return err |
| 96 | + } |
| 97 | + } |
| 98 | + |
| 99 | + return nil |
| 100 | +} |
| 101 | + |
| 102 | +func (m *Migrator) getDropSQLForDriver(d string) (string, error) { |
| 103 | + // TODO: Add more driver support. |
| 104 | + // Postgres? Then that'll do. |
| 105 | + if d == "sqlite3" { |
| 106 | + return "SELECT name FROM sqlite_master WHERE type='table'", nil |
| 107 | + } |
| 108 | + |
| 109 | + if d == "mysql" { |
| 110 | + return "SHOW FULL TABLES WHERE table_type = 'BASE TABLE'", nil |
| 111 | + } |
| 112 | + |
| 113 | + return "", fmt.Errorf("`%s` driver is not yet supported", d) |
| 114 | +} |
| 115 | + |
| 116 | +// nextBatchNumber retreives the highest batch number from the |
| 117 | +// migrations table and increments it by one. |
| 118 | +func (m *Migrator) nextBatchNumber() int { |
| 119 | + return m.lastBatchNumber() + 1 |
| 120 | +} |
| 121 | + |
| 122 | +// lastBatchNumber retrieves the number of the last batch ran |
| 123 | +// on the migrations table. |
| 124 | +func (m *Migrator) lastBatchNumber() int { |
| 125 | + r := m.DB.QueryRow("SELECT MAX(batch) FROM migrations") |
| 126 | + var num int |
| 127 | + r.Scan(&num) |
| 128 | + return num |
| 129 | +} |
| 130 | + |
| 131 | +// TODO: Wrap this in a transaction and reverse it |
| 132 | +func (m *Migrator) Run(migrations ...MigrationInterface) error { |
| 133 | + m.verifyMigrationsTable() |
| 134 | + |
| 135 | + batch := m.nextBatchNumber() |
| 136 | + |
| 137 | + for _, migration := range migrations { |
| 138 | + if _, err := m.DB.Exec(string(migration.Up())); err != nil { |
| 139 | + return err |
| 140 | + } |
| 141 | + |
| 142 | + m.addBatchToMigrationsTable(migration, batch) |
| 143 | + } |
| 144 | + |
| 145 | + return nil |
| 146 | +} |
| 147 | + |
| 148 | +func (m *Migrator) addBatchToMigrationsTable(migration MigrationInterface, batch int) { |
| 149 | + stmt, err := m.DB.Prepare("INSERT INTO migrations (migration, batch) VALUES ( ?, ? )") |
| 150 | + if err != nil { |
| 151 | + log.Fatalln("Cannot create `migrations` batch statement. ") |
| 152 | + } |
| 153 | + defer stmt.Close() |
| 154 | + |
| 155 | + if _, err = stmt.Exec(reflect.TypeOf(migration).String(), batch); err != nil { |
| 156 | + log.Fatalln(err) |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +// prepMigrations ensures that the migrations are ready to |
| 161 | +// be ran. |
| 162 | +func (m *Migrator) verifyMigrationsTable() { |
| 163 | + if !m.TableExists("migrations", m.DB) { |
| 164 | + if err := m.createMigrationsTable(); err != nil { |
| 165 | + log.Fatalln("Could not create `migrations` table: ", err) |
| 166 | + } |
| 167 | + } |
| 168 | +} |
| 169 | + |
| 170 | +func (m *Migrator) driverIsSupported(driver string) bool { |
| 171 | + for _, d := range supportedDrivers { |
| 172 | + if d == driver { |
| 173 | + return true |
| 174 | + } |
| 175 | + } |
| 176 | + |
| 177 | + return false |
| 178 | +} |
| 179 | + |
| 180 | +// getDriverName returns the name of the SQL driver currently |
| 181 | +// associated with the Migrator. |
| 182 | +func (m *Migrator) getDriverName() string { |
| 183 | + sqlDriverNamesByType := map[reflect.Type]string{} |
| 184 | + |
| 185 | + for _, driverName := range sql.Drivers() { |
| 186 | + // Tested empty string DSN with MySQL, PostgreSQL, and SQLite3 drivers. |
| 187 | + db, _ := sql.Open(driverName, "") |
| 188 | + |
| 189 | + if db != nil { |
| 190 | + driverType := reflect.TypeOf(db.Driver()) |
| 191 | + sqlDriverNamesByType[driverType] = driverName |
| 192 | + } |
| 193 | + } |
| 194 | + |
| 195 | + driverType := reflect.TypeOf(m.DB.Driver()) |
| 196 | + if driverName, found := sqlDriverNamesByType[driverType]; found { |
| 197 | + return driverName |
| 198 | + } |
| 199 | + |
| 200 | + return "" |
| 201 | +} |
| 202 | + |
| 203 | +// createMigrationsTable makes a table to hold migrations and |
| 204 | +// the order that they were executed. |
| 205 | +func (m *Migrator) createMigrationsTable() error { |
| 206 | + migrationSchema := fmt.Sprintf( |
| 207 | + "CREATE TABLE migrations ( %s, %s, %s )", |
| 208 | + "id integer not null primary key autoincrement", |
| 209 | + "migration varchar not null", |
| 210 | + "batch integer not null", |
| 211 | + ) |
| 212 | + |
| 213 | + if _, err := m.DB.Exec(migrationSchema); err != nil { |
| 214 | + return fmt.Errorf("error creating migrations table: %s", err) |
| 215 | + } |
| 216 | + |
| 217 | + return nil |
| 218 | +} |
0 commit comments