diff --git a/cmd/atlas/internal/migrate/migrate.go b/cmd/atlas/internal/migrate/migrate.go index 4f864fb8a47..b4567104ce5 100644 --- a/cmd/atlas/internal/migrate/migrate.go +++ b/cmd/atlas/internal/migrate/migrate.go @@ -18,6 +18,7 @@ import ( "ariga.io/atlas/cmd/atlas/internal/migrate/ent" "ariga.io/atlas/cmd/atlas/internal/migrate/ent/revision" "ariga.io/atlas/sql/migrate" + "ariga.io/atlas/sql/mysql" "ariga.io/atlas/sql/schema" "ariga.io/atlas/sql/sqlclient" "ariga.io/atlas/sql/sqltool" @@ -50,6 +51,14 @@ type ( Option func(*EntRevisions) error ) +// Dialect returns the "ent dialect" of the Ent client. +func (r *EntRevisions) Dialect() string { + if r.ac.Name == mysql.DriverMaria { + return mysql.DriverName // Ent does not support "mariadb" as dialect. + } + return r.ac.Name +} + // RevisionsForClient creates a new RevisionReadWriter for the given sqlclient.Client. func RevisionsForClient(ctx context.Context, ac *sqlclient.Client, schema string) (RevisionReadWriter, error) { // If the driver supports the RevisionReadWriter interface, use it. @@ -77,9 +86,9 @@ func NewEntRevisions(ctx context.Context, ac *sqlclient.Client, opts ...Option) } } // Create the connection with the underlying migrate.Driver to have it inside a possible transaction. - entopts := []ent.Option{ent.Driver(sql.NewDriver(r.ac.Name, sql.Conn{ExecQuerier: r.ac.Driver}))} + entopts := []ent.Option{ent.Driver(sql.NewDriver(r.Dialect(), sql.Conn{ExecQuerier: r.ac.Driver}))} // SQLite does not support multiple schema, therefore schema-config is only needed for other dialects. - if r.ac.Name != dialect.SQLite { + if r.Dialect() != dialect.SQLite { // Make sure the schema to store the revisions table in does exist. _, err := r.ac.InspectSchema(ctx, r.schema, &schema.InspectOptions{Mode: schema.InspectSchemas}) if err != nil && !schema.IsNotExistError(err) { @@ -189,17 +198,17 @@ func (r *EntRevisions) DeleteRevision(ctx context.Context, v string) error { // execution in a transaction and assumes the underlying connection is of type *sql.DB, which is not true for actually // reading and writing revisions. func (r *EntRevisions) Migrate(ctx context.Context) (err error) { - c := ent.NewClient(ent.Driver(sql.OpenDB(r.ac.Name, r.ac.DB))) + c := ent.NewClient(ent.Driver(sql.OpenDB(r.Dialect(), r.ac.DB))) // Ensure the ent client is bound to the requested revision schema. Open a new connection, if not. - if r.ac.Name != dialect.SQLite && r.ac.URL.Schema != r.schema { + if r.Dialect() != dialect.SQLite && r.ac.URL.Schema != r.schema { sc, err := sqlclient.OpenURL(ctx, r.ac.URL.URL, sqlclient.OpenSchema(r.schema)) if err != nil { return err } defer sc.Close() - c = ent.NewClient(ent.Driver(sql.OpenDB(sc.Name, sc.DB))) + c = ent.NewClient(ent.Driver(sql.OpenDB(r.Dialect(), sc.DB))) } - if r.ac.Name == dialect.SQLite { + if r.Dialect() == dialect.SQLite { var on sql.NullBool if err := r.ac.DB.QueryRowContext(ctx, "PRAGMA foreign_keys").Scan(&on); err != nil { return err diff --git a/sql/mysql/driver.go b/sql/mysql/driver.go index 428ea91dddc..d9e89167631 100644 --- a/sql/mysql/driver.go +++ b/sql/mysql/driver.go @@ -49,21 +49,24 @@ var _ interface { schema.TypeParseFormatter } = (*Driver)(nil) -// DriverName holds the name used for registration. -const DriverName = "mysql" +// DriverName and DriverMaria holds the names used for registration. +const ( + DriverName = "mysql" + DriverMaria = "mariadb" +) func init() { sqlclient.Register( DriverName, - sqlclient.OpenerFunc(opener), + opener(DriverName), sqlclient.RegisterDriverOpener(Open), sqlclient.RegisterCodec(codec, codec), sqlclient.RegisterFlavours("mysql+unix"), sqlclient.RegisterURLParser(parser{}), ) sqlclient.Register( - "mariadb", - sqlclient.OpenerFunc(opener), + DriverMaria, + opener(DriverMaria), sqlclient.RegisterDriverOpener(Open), sqlclient.RegisterCodec(mariaCodec, mariaCodec), sqlclient.RegisterFlavours("mariadb+unix", "maria", "maria+unix"), @@ -97,26 +100,29 @@ func Open(db schema.ExecQuerier) (migrate.Driver, error) { }, nil } -func opener(_ context.Context, u *url.URL) (*sqlclient.Client, error) { - ur := parser{}.ParseURL(u) - db, err := sql.Open(DriverName, ur.DSN) - if err != nil { - return nil, err - } - drv, err := Open(db) - if err != nil { - if cerr := db.Close(); cerr != nil { - err = fmt.Errorf("%w: %v", err, cerr) +// opener for the given driver name. +func opener(name string) sqlclient.OpenerFunc { + return func(_ context.Context, u *url.URL) (*sqlclient.Client, error) { + ur := parser{}.ParseURL(u) + db, err := sql.Open(DriverName, ur.DSN) + if err != nil { + return nil, err } - return nil, err + drv, err := Open(db) + if err != nil { + if cerr := db.Close(); cerr != nil { + err = fmt.Errorf("%w: %v", err, cerr) + } + return nil, err + } + drv.(*Driver).schema = ur.Schema + return &sqlclient.Client{ + Name: name, + DB: db, + URL: ur, + Driver: drv, + }, nil } - drv.(*Driver).schema = ur.Schema - return &sqlclient.Client{ - Name: DriverName, - DB: db, - URL: ur, - Driver: drv, - }, nil } // NormalizeRealm returns the normal representation of the given database.