Skip to content

Commit

Permalink
cmd/atlascmd: support project file in inspect (#771)
Browse files Browse the repository at this point in the history
  • Loading branch information
rotemtam authored May 12, 2022
1 parent 4c0e28e commit 7910885
Show file tree
Hide file tree
Showing 7 changed files with 329 additions and 50 deletions.
75 changes: 75 additions & 0 deletions cmd/atlascmd/project.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package atlascmd

import (
"fmt"
"os"

"ariga.io/atlas/schema/schemaspec"
"ariga.io/atlas/schema/schemaspec/schemahcl"
)

const projectFileName = "atlas.hcl"

// projectFile represents an atlas.hcl file.
type projectFile struct {
Envs []*Env `spec:"env"`
}

// Env represents an Atlas environment.
type Env struct {
// Name for this environment.
Name string `spec:"name,name"`

// URL of the database.
URL string `spec:"url"`

// URL of the dev-database for this environment.
// See: https://atlasgo.io/dev-database
DevURL string `spec:"dev"`

// Path to the file containing the desired schema of the environment.
Source string `spec:"src"`

// List of schemas in this database that are managed by Atlas.
Schemas []string `spec:"schemas"`
schemaspec.DefaultExtension
}

// LoadEnv reads the project file in path, and loads the environment
// with the provided name into env.
func LoadEnv(path string, name string) (*Env, error) {
b, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var project projectFile
if err := schemahcl.New().Eval(b, &project, nil); err != nil {
return nil, fmt.Errorf("error reading project file: %w", err)
}
projEnvs := make(map[string]*Env)
for _, e := range project.Envs {
if _, ok := projEnvs[e.Name]; ok {
return nil, fmt.Errorf("duplicate environment name %q", e.Name)
}
if e.Name == "" {
return nil, fmt.Errorf("all envs must have names on file %q", path)
}
if e.URL == "" {
return nil, fmt.Errorf("no url set for e %q", e.Name)
}
projEnvs[e.Name] = e
}
selected, ok := projEnvs[name]
if !ok {
return nil, fmt.Errorf("env %q not defined in project file", name)
}
return selected, nil
}

func init() {
schemaspec.Register("env", &Env{})
}
58 changes: 58 additions & 0 deletions cmd/atlascmd/project_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2021-present The Atlas Authors. All rights reserved.
// This source code is licensed under the Apache 2.0 license found
// in the LICENSE file in the root directory of this source tree.

package atlascmd

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/require"
)

func TestLoadEnv(t *testing.T) {
d := t.TempDir()
h := `
env "local" {
url = "mysql://root:pass@localhost:3306/"
dev = "docker://mysql/8"
src = "./app.hcl"
schemas = ["hello", "world"]
}
`
err := os.WriteFile(filepath.Join(d, projectFileName), []byte(h), 0600)
require.NoError(t, err)
path := filepath.Join(d, projectFileName)
t.Run("ok", func(t *testing.T) {
env := &Env{}
env, err = LoadEnv(path, "local")
require.NoError(t, err)
require.EqualValues(t, &Env{
Name: "local",
URL: "mysql://root:pass@localhost:3306/",
DevURL: "docker://mysql/8",
Source: "./app.hcl",
Schemas: []string{"hello", "world"},
}, env)
})
t.Run("wrong env", func(t *testing.T) {
_, err = LoadEnv(path, "home")
require.EqualError(t, err, `env "home" not defined in project file`)
})
t.Run("wrong dir", func(t *testing.T) {
wd, err := os.Getwd()
require.NoError(t, err)
_, err = LoadEnv(filepath.Join(wd, projectFileName), "home")
require.ErrorContains(t, err, `no such file or directory`)
})
t.Run("duplicate env", func(t *testing.T) {
dup := h + "\n" + h
path := filepath.Join(d, "dup.hcl")
err = os.WriteFile(path, []byte(dup), 0600)
require.NoError(t, err)
_, err = LoadEnv(path, "local")
require.EqualError(t, err, `duplicate environment name "local"`)
})
}
112 changes: 87 additions & 25 deletions cmd/atlascmd/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"context"
"errors"
"fmt"
"io/fs"
"io/ioutil"
"os"
Expand Down Expand Up @@ -54,7 +55,7 @@ migration, Atlas will print the migration plan and prompt the user for approval.
If run with the "--dry-run" flag, atlas will exit after printing out the planned
migration.`,
Run: CmdApplyRun,
RunE: CmdApplyRun,
Example: ` atlas schema apply -u "mysql://user:pass@localhost/dbname" -f atlas.hcl
atlas schema apply -u "mysql://localhost" -f atlas.hcl --schema prod --schema staging
atlas schema apply -u "mysql://user:pass@localhost:3306/dbname" -f atlas.hcl --dry-run
Expand Down Expand Up @@ -114,6 +115,14 @@ const (
answerAbort = "Abort"
)

// selectEnv selects the environment config from the current directory project file.
func selectEnv(args []string) (*Env, error) {
if len(args) == 0 {
return nil, nil
}
return LoadEnv(projectFileName, args[0])
}

func init() {
// Schema apply flags.
schemaCmd.AddCommand(SchemaApply)
Expand All @@ -129,8 +138,7 @@ func init() {
SchemaApply.Flags().BoolVarP(&ApplyFlags.Verbose, migrateDiffFlagVerbose, "", false, "enable verbose logging")
SchemaApply.Flags().StringToStringVarP(&ApplyFlags.Vars, "var", "", nil, "input variables")
cobra.CheckErr(SchemaApply.MarkFlagRequired("url"))
cobra.CheckErr(SchemaApply.MarkFlagRequired("file"))
dsn2url(SchemaApply, &ApplyFlags.URL)
fixURLFlag(SchemaApply, &ApplyFlags.URL)

// Schema inspect flags.
schemaCmd.AddCommand(SchemaInspect)
Expand All @@ -139,14 +147,14 @@ func init() {
SchemaInspect.Flags().StringVarP(&InspectFlags.Addr, "addr", "", ":5800", "Used with -w, local address to bind the server to")
SchemaInspect.Flags().StringSliceVarP(&InspectFlags.Schema, "schema", "s", nil, "Set schema name")
cobra.CheckErr(SchemaInspect.MarkFlagRequired("url"))
dsn2url(SchemaInspect, &InspectFlags.URL)
fixURLFlag(SchemaInspect, &InspectFlags.URL)

// Schema fmt.
schemaCmd.AddCommand(SchemaFmt)
}

// CmdInspectRun is the command used when running CLI.
func CmdInspectRun(cmd *cobra.Command, _ []string) {
func CmdInspectRun(cmd *cobra.Command, args []string) {
if InspectFlags.Web {
schemaCmd.PrintErrln("The Alas UI is not available in this release.")
return
Expand All @@ -155,6 +163,11 @@ func CmdInspectRun(cmd *cobra.Command, _ []string) {
cobra.CheckErr(err)
defer client.Close()
schemas := InspectFlags.Schema
activeEnv, err := selectEnv(args)
cobra.CheckErr(err)
if activeEnv != nil && len(activeEnv.Schemas) > 0 {
schemas = activeEnv.Schemas
}
if client.URL.Schema != "" {
schemas = append(schemas, client.URL.Schema)
}
Expand All @@ -168,15 +181,37 @@ func CmdInspectRun(cmd *cobra.Command, _ []string) {
}

// CmdApplyRun is the command used when running CLI.
func CmdApplyRun(cmd *cobra.Command, _ []string) {
func CmdApplyRun(cmd *cobra.Command, args []string) error {
if ApplyFlags.Web {
schemaCmd.PrintErrln("The Atlas UI is not available in this release.")
return
cmd.Println("The Atlas UI is not available in this release.")
return errors.New("unavailable")
}
c, err := sqlclient.Open(cmd.Context(), ApplyFlags.URL)
cobra.CheckErr(err)
if err != nil {
return err
}
defer c.Close()
applyRun(cmd.Context(), c, ApplyFlags.File, ApplyFlags.DryRun, ApplyFlags.AutoApprove, ApplyFlags.Vars)
devURL := ApplyFlags.DevURL
activeEnv, err := selectEnv(args)
if err != nil {
return err
}
if activeEnv != nil && activeEnv.DevURL != "" {
devURL = activeEnv.DevURL
}
var file string
switch {
case activeEnv != nil && activeEnv.Source != "":
file = activeEnv.Source
case ApplyFlags.File != "":
file = ApplyFlags.File
default:
return fmt.Errorf("source file must be set via -f or project file")
}
if activeEnv != nil && activeEnv.Source != "" {
file = activeEnv.Source
}
return applyRun(cmd.Context(), c, devURL, file, ApplyFlags.DryRun, ApplyFlags.AutoApprove, ApplyFlags.Vars)
}

// CmdFmtRun formats all HCL files in a given directory using canonical HCL formatting
Expand All @@ -190,19 +225,25 @@ func CmdFmtRun(cmd *cobra.Command, args []string) {
}
}

func applyRun(ctx context.Context, client *sqlclient.Client, file string, dryRun, autoApprove bool, input map[string]string) {
func applyRun(ctx context.Context, client *sqlclient.Client, devURL string, file string, dryRun, autoApprove bool, input map[string]string) error {
schemas := ApplyFlags.Schema
if client.URL.Schema != "" {
schemas = append(schemas, client.URL.Schema)
}
realm, err := client.InspectRealm(ctx, &schema.InspectRealmOption{
Schemas: schemas,
})
cobra.CheckErr(err)
if err != nil {
return err
}
f, err := ioutil.ReadFile(file)
cobra.CheckErr(err)
if err != nil {
return err
}
desired := &schema.Realm{}
cobra.CheckErr(client.Eval(f, desired, input))
if err := client.Eval(f, desired, input); err != nil {
return err
}
if len(schemas) > 0 {
// Validate all schemas in file were selected by user.
sm := make(map[string]bool, len(schemas))
Expand All @@ -211,26 +252,33 @@ func applyRun(ctx context.Context, client *sqlclient.Client, file string, dryRun
}
for _, s := range desired.Schemas {
if !sm[s.Name] {
schemaCmd.Printf("schema %q from file %q was not selected %q, all schemas defined in file must be selected\n", s.Name, file, schemas)
return
return fmt.Errorf("schema %q from file %q was not selected %q, all schemas defined in file must be selected", s.Name, file, schemas)
}
}
}
if _, ok := client.Driver.(schema.Normalizer); ok && ApplyFlags.DevURL != "" {
if _, ok := client.Driver.(schema.Normalizer); ok && devURL != "" {
dev, err := sqlclient.Open(ctx, ApplyFlags.DevURL)
cobra.CheckErr(err)
if err != nil {
return err
}
defer dev.Close()
desired, err = dev.Driver.(schema.Normalizer).NormalizeRealm(ctx, desired)
cobra.CheckErr(err)
if err != nil {
return err
}
}
changes, err := client.RealmDiff(realm, desired)
cobra.CheckErr(err)
if err != nil {
return err
}
if len(changes) == 0 {
schemaCmd.Println("Schema is synced, no changes to be made")
return
return nil
}
p, err := client.PlanChanges(ctx, "plan", changes)
cobra.CheckErr(err)
if err != nil {
return err
}
schemaCmd.Println("-- Planned Changes:")
for _, c := range p.Changes {
if c.Comment != "" {
Expand All @@ -239,11 +287,14 @@ func applyRun(ctx context.Context, client *sqlclient.Client, file string, dryRun
schemaCmd.Println(c.Cmd)
}
if dryRun {
return
return nil
}
if autoApprove || promptUser() {
cobra.CheckErr(client.ApplyChanges(ctx, changes))
if err := client.ApplyChanges(ctx, changes); err != nil {
return err
}
}
return nil
}

func promptUser() bool {
Expand All @@ -256,12 +307,23 @@ func promptUser() bool {
return result == answerApply
}

func dsn2url(cmd *cobra.Command, p *string) {
// fixURLFlag fixes the url flag by pulling its value either from the flag itself,
// the (deprecated) dsn flag, or from the active environment.
func fixURLFlag(cmd *cobra.Command, p *string) {
cmd.Flags().StringVarP(p, "dsn", "d", "", "")
cobra.CheckErr(cmd.Flags().MarkHidden("dsn"))
cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
activeEnv, err := selectEnv(args)
if err != nil {
return err
}
dsnF, urlF := cmd.Flag("dsn"), cmd.Flag("url")
switch {
case activeEnv != nil && activeEnv.URL != "":
urlF.Changed = true
if err := urlF.Value.Set(activeEnv.URL); err != nil {
return err
}
case !dsnF.Changed && !urlF.Changed:
return errors.New(`required flag "url" was not set`)
case dsnF.Changed && urlF.Changed:
Expand Down
Loading

0 comments on commit 7910885

Please sign in to comment.