From 9551e8a744983c51b0230388cdb2c056da1e2585 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Sun, 28 May 2017 16:34:46 +1000 Subject: [PATCH] add redoall command --- cmd/mig/redo.go | 37 +++++++++++++++++++++++++++++++++++-- migrate_test.go | 26 +++++++++++++------------- migration_test.go | 5 ++++- 3 files changed, 52 insertions(+), 16 deletions(-) diff --git a/cmd/mig/redo.go b/cmd/mig/redo.go index ed80c7c93..4ac6e0fee 100644 --- a/cmd/mig/redo.go +++ b/cmd/mig/redo.go @@ -10,12 +10,20 @@ import ( var redoCmd = &cobra.Command{ Use: "redo", - Short: "Re-run the latest migration", - Long: "Re-run the latest migration", + Short: "Down then up the latest migration", + Long: "Down then up the latest migration", Example: `mig redo postgres "user=postgres dbname=postgres sslmode=disable"`, RunE: redoRunE, } +var redoAllCmd = &cobra.Command{ + Use: "redo", + Short: "Down then up all migrations", + Long: "Down then up all migrations", + Example: `mig redoall postgres "user=postgres dbname=postgres sslmode=disable"`, + RunE: redoAllRunE, +} + func init() { redoCmd.Flags().StringP("dir", "d", ".", "directory with migration files") @@ -42,3 +50,28 @@ func redoRunE(cmd *cobra.Command, args []string) error { fmt.Printf("Success %v\n", name) return nil } + +func redoAllRunE(cmd *cobra.Command, args []string) error { + driver, conn, err := getConnArgs(args) + if err != nil { + return err + } + + _, err = mig.DownAll(driver, conn, viper.GetString("dir")) + if err != nil { + return err + } + + count, err := mig.Up(driver, conn, viper.GetString("dir")) + if err != nil { + return err + } + + if count == 0 { + fmt.Printf("No migrations to run") + } else { + fmt.Printf("Success %d migrations\n", count) + } + + return nil +} diff --git a/migrate_test.go b/migrate_test.go index daf700499..a372f561d 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -4,13 +4,13 @@ import ( "testing" ) -func newMigration(v int64, src string) *Migration { - return &Migration{Version: v, Previous: -1, Next: -1, Source: src} +func newMigration(v int64, src string) *migration { + return &migration{version: v, previous: -1, next: -1, source: src} } func TestMigrationSort(t *testing.T) { - ms := Migrations{} + ms := migrations{} // insert in any order ms = append(ms, newMigration(20120000, "test")) @@ -25,10 +25,10 @@ func TestMigrationSort(t *testing.T) { validateMigrationSort(t, ms, sorted) } -func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) { +func validateMigrationSort(t *testing.T, ms migrations, sorted []int64) { for i, m := range ms { - if sorted[i] != m.Version { + if sorted[i] != m.version { t.Error("incorrect sorted version") } @@ -36,21 +36,21 @@ func validateMigrationSort(t *testing.T, ms Migrations, sorted []int64) { if i == 0 { prev = -1 - next = ms[i+1].Version + next = ms[i+1].version } else if i == len(ms)-1 { - prev = ms[i-1].Version + prev = ms[i-1].version next = -1 } else { - prev = ms[i-1].Version - next = ms[i+1].Version + prev = ms[i-1].version + next = ms[i+1].version } - if m.Next != next { - t.Errorf("mismatched Next. v: %v, got %v, wanted %v\n", m, m.Next, next) + if m.next != next { + t.Errorf("mismatched next. v: %v, got %v, wanted %v\n", m, m.next, next) } - if m.Previous != prev { - t.Errorf("mismatched Previous v: %v, got %v, wanted %v\n", m, m.Previous, prev) + if m.previous != prev { + t.Errorf("mismatched previous v: %v, got %v, wanted %v\n", m, m.previous, prev) } } diff --git a/migration_test.go b/migration_test.go index 4c22bbd08..e2a60f1f5 100644 --- a/migration_test.go +++ b/migration_test.go @@ -79,7 +79,10 @@ func TestSplitStatements(t *testing.T) { } for _, test := range tests { - stmts := splitSQLStatements(strings.NewReader(test.sql), test.direction) + stmts, err := splitSQLStatements(strings.NewReader(test.sql), test.direction) + if err != nil { + t.Error(err) + } if len(stmts) != test.count { t.Errorf("incorrect number of stmts. got %v, want %v", len(stmts), test.count) }