Skip to content

Commit ad4c76f

Browse files
authored
atlasaction: refactor logic exec and change default merge message (#366)
1 parent 6229fe0 commit ad4c76f

File tree

4 files changed

+175
-90
lines changed

4 files changed

+175
-90
lines changed

atlasaction/action.go

+102-77
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import (
1515
"errors"
1616
"fmt"
1717
"io"
18+
"iter"
19+
"maps"
1820
"net/url"
1921
"os"
2022
"os/exec"
@@ -568,125 +570,120 @@ func (a *Actions) MigrateTest(ctx context.Context) error {
568570

569571
// MigrateAutoRebase runs the Action for "ariga/atlas-action/migrate/autorebase"
570572
func (a *Actions) MigrateAutoRebase(ctx context.Context) error {
571-
gitVer, err := a.CmdExecutor(ctx, "git", "--version").Output()
572-
switch err := err.(type) {
573-
case nil:
574-
a.Infof("running with git version: %s", string(gitVer))
575-
case *exec.ExitError:
576-
return fmt.Errorf("failed to get git version: stderr %s", string(err.Stderr))
577-
default:
578-
return fmt.Errorf("failed to get git version: %w", err)
579-
}
580-
dirpath := strings.TrimPrefix(a.GetInput("dir"), "file://")
581-
if dirpath == "" {
582-
dirpath = "migrations"
583-
}
584-
sumpath := filepath.Join(a.WorkingDir(), dirpath, migrate.HashFileName)
585573
tc, err := a.GetTriggerContext(ctx)
586574
if err != nil {
587575
return err
588576
}
589577
var (
578+
remote = a.GetInputDefault("remote", "origin")
579+
baseBranch = a.GetInputDefault("base-branch", tc.DefaultBranch)
590580
currBranch = tc.Branch
591-
baseBranch = a.GetInput("base-branch")
592-
remote = a.GetInput("remote")
593581
)
594-
if baseBranch == "" {
595-
baseBranch = tc.DefaultBranch
596-
}
597-
if remote == "" {
598-
remote = "origin"
582+
if v, err := a.exec(ctx, "git", "--version"); err != nil {
583+
return fmt.Errorf("failed to get git version: %w", err)
584+
} else {
585+
a.Infof("auto-rebase with %s", v)
599586
}
600-
if out, err := a.CmdExecutor(ctx, "git", "fetch", remote, baseBranch).Output(); err != nil {
601-
a.Errorf(string(out))
587+
if _, err := a.exec(ctx, "git", "fetch", remote, baseBranch); err != nil {
602588
return fmt.Errorf("failed to fetch the branch %s: %w", baseBranch, err)
603589
}
604590
// Since running in detached HEAD, we need to switch to the branch.
605-
if out, err := a.CmdExecutor(ctx, "git", "checkout", currBranch).Output(); err != nil {
606-
a.Errorf(string(out))
591+
if _, err := a.exec(ctx, "git", "checkout", currBranch); err != nil {
607592
return fmt.Errorf("failed to checkout to the branch: %w", err)
608593
}
609-
incoming, err := a.CmdExecutor(ctx, "git", "show", fmt.Sprintf("%s/%s:%s", remote, baseBranch, sumpath)).Output()
594+
dirURL := a.GetInputDefault("dir", "file://migrations")
595+
u, err := url.Parse(dirURL)
610596
if err != nil {
611-
a.Errorf(string(incoming))
612-
return fmt.Errorf("failed to get the atlas.sum file from the rebase branch: %w", err)
597+
return fmt.Errorf("failed to parse dir URL: %w", err)
613598
}
614-
base, err := a.CmdExecutor(ctx, "git", "show", fmt.Sprintf("%s/%s:%s", remote, currBranch, sumpath)).Output()
599+
dirPath := filepath.Join(u.Host, u.Path)
600+
sumPath := filepath.Join(a.WorkingDir(), dirPath, migrate.HashFileName)
601+
baseHash, err := a.hashFileFrom(ctx, remote, baseBranch, sumPath)
615602
if err != nil {
616-
a.Errorf(string(base))
617-
return fmt.Errorf("failed to get the atlas.sum file from current branch: %w", err)
618-
}
619-
var incomingHash, baseHash migrate.HashFile
620-
if err := incomingHash.UnmarshalText(incoming); err != nil {
621-
return fmt.Errorf("failed to unmarshal incoming atlas.sum: %w", err)
622-
}
623-
if err := baseHash.UnmarshalText(base); err != nil {
624-
return fmt.Errorf("failed to unmarshal base atlas.sum: %w", err)
625-
}
626-
incomingFilesSet := make(map[string]struct{})
627-
for _, v := range incomingHash {
628-
incomingFilesSet[v.N] = struct{}{}
603+
return fmt.Errorf("failed to get the atlas.sum file from the base branch: %w", err)
629604
}
630-
baseNames := make([]string, len(baseHash))
631-
for i, v := range baseHash {
632-
baseNames[i] = v.N
633-
}
634-
// Get all the file names the exists only in the base branch atlas.sum file.
635-
var onlyInBase []string
636-
for _, file := range baseNames {
637-
if _, ok := incomingFilesSet[file]; !ok {
638-
onlyInBase = append(onlyInBase, file)
639-
}
605+
currHash, err := a.hashFileFrom(ctx, remote, currBranch, sumPath)
606+
if err != nil {
607+
return fmt.Errorf("failed to get the atlas.sum file from the current branch: %w", err)
640608
}
641-
if len(onlyInBase) == 0 {
642-
a.Infof("No files to rebase")
609+
files := newFiles(baseHash, currHash)
610+
if len(files) == 0 {
611+
a.Infof("No new migration files to rebase")
643612
return nil
644613
}
645614
// Try to merge the base branch into the current branch.
646-
out, err := a.CmdExecutor(ctx, "git", "merge", "--no-ff", fmt.Sprintf("%s/%s", remote, baseBranch)).Output()
647-
switch err := err.(type) {
648-
case nil:
615+
if _, err := a.exec(ctx, "git", "merge", "--no-ff",
616+
fmt.Sprintf("%s/%s", remote, baseBranch)); err == nil {
649617
a.Infof("No conflict found when merging %s into %s", baseBranch, currBranch)
650618
return nil
651-
case *exec.ExitError:
652-
a.Infof("Running `git merge` got following error: %s", string(err.Stderr))
653-
a.Infof("git merge output: %s", string(out))
654-
default:
655-
return fmt.Errorf("receive unexpected error %w", err)
656619
}
657620
// If merge failed due to conflict, check that the conflict is only in atlas.sum file.
658-
diff, err := a.CmdExecutor(ctx, "git", "diff", "--name-only", "--diff-filter=U").Output()
659-
if err != nil {
660-
a.Errorf(string(diff))
621+
switch out, err := a.exec(ctx, "git", "diff", "--name-only", "--diff-filter=U"); {
622+
case err != nil:
661623
return fmt.Errorf("failed to get conflicting files: %w", err)
662-
}
663-
conflictFiles := strings.Split(strings.TrimSpace(string(diff)), "\n")
664-
if len(conflictFiles) != 1 || conflictFiles[0] != sumpath {
665-
return fmt.Errorf("conflict found in files other than %s, conflict files: %v", sumpath, conflictFiles)
624+
case len(out) == 0:
625+
return errors.New("conflict found but no conflicting files found")
626+
case strings.TrimSpace(string(out)) != sumPath:
627+
a.Infof("Conflict files are:\n%s", out)
628+
return fmt.Errorf("conflict found in files other than %s", sumPath)
666629
}
667630
// Re-hash the migrations and rebase the migrations.
668-
if err = a.Atlas.MigrateHash(ctx, &atlasexec.MigrateHashParams{DirURL: a.GetInput("dir")}); err != nil {
631+
if err = a.Atlas.MigrateHash(ctx, &atlasexec.MigrateHashParams{
632+
DirURL: dirURL,
633+
}); err != nil {
669634
return fmt.Errorf("failed to run `atlas migrate hash`: %w", err)
670635
}
671-
if err = a.Atlas.MigrateRebase(ctx, &atlasexec.MigrateRebaseParams{DirURL: a.GetInput("dir"), Files: onlyInBase}); err != nil {
636+
if err = a.Atlas.MigrateRebase(ctx, &atlasexec.MigrateRebaseParams{
637+
DirURL: dirURL,
638+
Files: files,
639+
}); err != nil {
672640
return fmt.Errorf("failed to rebase migrations: %w", err)
673641
}
674-
if out, err = a.CmdExecutor(ctx, "git", "add", dirpath).CombinedOutput(); err != nil {
675-
a.Errorf(string(out))
642+
if _, err = a.exec(ctx, "git", "add", dirPath); err != nil {
676643
return fmt.Errorf("failed to stage changes: %w", err)
677644
}
678-
if out, err = a.CmdExecutor(ctx, "git", "commit", "-m", fmt.Sprintf("Rebase migrations in %s", dirpath)).CombinedOutput(); err != nil {
679-
a.Errorf(string(out))
645+
if _, err = a.exec(ctx, "git", "commit", "--message",
646+
fmt.Sprintf("%s: rebase migration files", dirPath)); err != nil {
680647
return fmt.Errorf("failed to commit changes: %w", err)
681648
}
682-
if out, err = a.CmdExecutor(ctx, "git", "push", remote, currBranch).CombinedOutput(); err != nil {
683-
a.Errorf(string(out))
649+
if _, err = a.exec(ctx, "git", "push", remote, currBranch); err != nil {
684650
return fmt.Errorf("failed to push changes: %w", err)
685651
}
686652
a.Infof("Migrations rebased successfully")
687653
return nil
688654
}
689655

656+
// hashFileFrom returns the hash file from the remote branch.
657+
func (a *Actions) hashFileFrom(ctx context.Context, remote, branch, path string) (migrate.HashFile, error) {
658+
data, err := a.exec(ctx, "git", "show",
659+
fmt.Sprintf("%s/%s:%s", remote, branch, path))
660+
if err != nil {
661+
return nil, err
662+
}
663+
var hf migrate.HashFile
664+
if err := hf.UnmarshalText(data); err != nil {
665+
return nil, fmt.Errorf("failed to unmarshal atlas.sum: %w", err)
666+
}
667+
return hf, nil
668+
}
669+
670+
// exec runs the command and returns the output.
671+
func (a *Actions) exec(ctx context.Context, name string, args ...string) ([]byte, error) {
672+
cmd := a.CmdExecutor(ctx, name, args...)
673+
out, err := cmd.Output()
674+
switch err := err.(type) {
675+
case nil:
676+
return out, nil
677+
case *exec.ExitError:
678+
if err.Stderr != nil {
679+
a.Infof("Running %q got following error: %s", cmd.String(), string(err.Stderr))
680+
}
681+
return nil, fmt.Errorf("failed to run %s: %w", name, err)
682+
default:
683+
return nil, fmt.Errorf("failed to run %s: %w", name, err)
684+
}
685+
}
686+
690687
// SchemaPush runs the GitHub Action for "ariga/atlas-action/schema/push"
691688
func (a *Actions) SchemaPush(ctx context.Context) error {
692689
tc, err := a.GetTriggerContext(ctx)
@@ -1315,6 +1312,15 @@ func (a *Actions) GetArrayInput(name string) []string {
13151312
})
13161313
}
13171314

1315+
// GetInputDefault returns the input with the given name.
1316+
// If the input is empty, it returns the default value.
1317+
func (a *Actions) GetInputDefault(name, def string) string {
1318+
if v := a.GetInput(name); v != "" {
1319+
return v
1320+
}
1321+
return def
1322+
}
1323+
13181324
// DeployRunContext returns the run context for the `migrate/apply`, and `migrate/down` actions.
13191325
func (a *Actions) DeployRunContext() *atlasexec.DeployRunContext {
13201326
return &atlasexec.DeployRunContext{
@@ -1447,6 +1453,25 @@ func (tc *TriggerContext) GetRunContext() *atlasexec.RunContext {
14471453
return rc
14481454
}
14491455

1456+
// newFiles returns the files that only exists in the current hash.
1457+
func newFiles(base, current migrate.HashFile) []string {
1458+
m := maps.Collect(hashIter(current))
1459+
for k := range hashIter(base) {
1460+
delete(m, k)
1461+
}
1462+
return slices.Collect(maps.Keys(m))
1463+
}
1464+
1465+
func hashIter(hf migrate.HashFile) iter.Seq2[string, string] {
1466+
return func(yield func(string, string) bool) {
1467+
for _, v := range hf {
1468+
if !yield(v.N, v.H) {
1469+
return
1470+
}
1471+
}
1472+
}
1473+
}
1474+
14501475
func execTime(start, end time.Time) string {
14511476
return end.Sub(start).String()
14521477
}

atlasaction/action_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,7 @@ func (m *MockCmdExecutor) ExecCmd(ctx context.Context, name string, args ...stri
878878
return m.onCommand(ctx, name, args...)
879879
}
880880

881-
func TestMigrateAutorebase(t *testing.T) {
881+
func TestMigrateAutoRebase(t *testing.T) {
882882
t.Run("no conflict", func(t *testing.T) {
883883
c, err := atlasexec.NewClient("", "atlas")
884884
require.NoError(t, err)
@@ -915,7 +915,7 @@ func TestMigrateAutorebase(t *testing.T) {
915915
require.NoError(t, err)
916916

917917
require.NoError(t, acts.MigrateAutoRebase(context.Background()))
918-
require.Contains(t, out.String(), "No files to rebase")
918+
require.Contains(t, out.String(), "No new migration files to rebase")
919919
// Check that the correct git commands were executed
920920
require.Len(t, mockExec.ran, 5)
921921
require.Equal(t, []string{"--version"}, mockExec.ran[0].args)
@@ -998,7 +998,7 @@ func TestMigrateAutorebase(t *testing.T) {
998998
require.Equal(t, []string{"merge", "--no-ff", "origin/rebase-branch"}, mockExec.ran[5].args)
999999
require.Equal(t, []string{"diff", "--name-only", "--diff-filter=U"}, mockExec.ran[6].args)
10001000
require.Equal(t, []string{"add", "testdata/need_rebase"}, mockExec.ran[7].args)
1001-
require.Equal(t, []string{"commit", "-m", "Rebase migrations in testdata/need_rebase"}, mockExec.ran[8].args)
1001+
require.Equal(t, []string{"commit", "--message", "testdata/need_rebase: rebase migration files"}, mockExec.ran[8].args)
10021002
require.Equal(t, []string{"push", "origin", "my-branch"}, mockExec.ran[9].args)
10031003
})
10041004
t.Run("conflict, but not only in atlas.sum", func(t *testing.T) {
@@ -1049,7 +1049,7 @@ func TestMigrateAutorebase(t *testing.T) {
10491049
require.NoError(t, err)
10501050

10511051
err = acts.MigrateAutoRebase(context.Background())
1052-
require.EqualError(t, err, "conflict found in files other than testdata/need_rebase/atlas.sum, conflict files: [testdata/need_rebase/atlas.sum not_atlas.sum]")
1052+
require.EqualError(t, err, "conflict found in files other than testdata/need_rebase/atlas.sum")
10531053
// Check that the correct git commands were executed
10541054
require.Len(t, mockExec.ran, 7)
10551055
require.Equal(t, []string{"--version"}, mockExec.ran[0].args)

atlasaction/circleci_action_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"testing"
1616

1717
"ariga.io/atlas-action/atlasaction"
18+
"ariga.io/atlas/sql/migrate"
1819
"github.com/rogpeppe/go-internal/testscript"
1920
"github.com/stretchr/testify/require"
2021
)
@@ -125,6 +126,29 @@ func TestCircleCI(t *testing.T) {
125126
}
126127
cmpFiles(ts, neg, args[0], output)
127128
},
129+
"hashFile": func(ts *testscript.TestScript, neg bool, args []string) {
130+
if len(args) != 1 {
131+
ts.Fatalf("usage: hashFile <file>")
132+
}
133+
var hf migrate.HashFile
134+
if err := hf.UnmarshalText([]byte(ts.ReadFile(args[0]))); err != nil {
135+
ts.Fatalf("failed to unmarshal hash file: %v", err)
136+
return
137+
}
138+
var files []string
139+
for _, f := range hf {
140+
files = append(files, f.N)
141+
}
142+
fmt.Fprintf(ts.Stdout(), "%v", files)
143+
},
144+
"writeFile": func(ts *testscript.TestScript, neg bool, args []string) {
145+
if len(args) != 2 {
146+
ts.Fatalf("usage: writeFile <file> <content>")
147+
}
148+
if err := os.WriteFile(args[0], []byte(args[1]), 0600); err != nil {
149+
ts.Fatalf("failed to write file: %v", err)
150+
}
151+
},
128152
},
129153
})
130154
}

0 commit comments

Comments
 (0)