diff --git a/internal/integration/docker-compose.yaml b/internal/integration/docker-compose.yaml index 27055bf955f..220f89aa0e1 100644 --- a/internal/integration/docker-compose.yaml +++ b/internal/integration/docker-compose.yaml @@ -120,4 +120,17 @@ services: healthcheck: test: mysqladmin ping -ppass ports: - - 4308:3306 \ No newline at end of file + - 4308:3306 + +# Default DB test, No Password + tidb5: + platform: linux/amd64 + image: pingcap/tidb:v5.4.0 + ports: + - 4309:4000 + + tidb_latest: + platform: linux/amd64 + image: pingcap/tidb:latest + ports: + - 4310:4000 diff --git a/internal/integration/mysql_test.go b/internal/integration/mysql_test.go index 9344b38c9a0..bca2796d90b 100644 --- a/internal/integration/mysql_test.go +++ b/internal/integration/mysql_test.go @@ -37,7 +37,11 @@ func myRun(t *testing.T, fn func(*myTest)) { myTests.Do(func() { myTests.drivers = make(map[string]*myTest) for version, port := range map[string]int{"56": 3306, "57": 3307, "8": 3308, "Maria107": 4306, "Maria102": 4307, "Maria103": 4308} { - db, err := sql.Open("mysql", fmt.Sprintf("root:pass@tcp(localhost:%d)/test?parseTime=True", port)) + password := ":pass" + if version == "TiDB" { + password = "" + } + db, err := sql.Open("mysql", fmt.Sprintf("root%s@tcp(localhost:%d)/test?parseTime=True", password, port)) require.NoError(t, err) drv, err := mysql.Open(db) require.NoError(t, err) @@ -1275,6 +1279,9 @@ func (t *myTest) defaultAttrs() []schema.Attr { collation = "latin1_swedish_ci" ) switch { + case t.version == "TiDB": + charset = "utf8mb4" + collation = "utf8mb4_bin" case t.version == "8": charset = "utf8mb4" collation = "utf8mb4_0900_ai_ci" diff --git a/sql/mysql/driver.go b/sql/mysql/driver.go index 90d33006e75..6f3a2ba3112 100644 --- a/sql/mysql/driver.go +++ b/sql/mysql/driver.go @@ -46,6 +46,14 @@ func Open(db schema.ExecQuerier) (*Driver, error) { if err := sqlx.ScanOne(rows, &c.version, &c.collate, &c.charset); err != nil { return nil, fmt.Errorf("mysql: scan system variables: %w", err) } + if c.tidb() { + return &Driver{ + conn: c, + Differ: &sqlx.Diff{DiffDriver: &diff{c}}, + Inspector: &tinspect{inspect{c}}, + PlanApplier: &planApply{c}, + }, nil + } return &Driver{ conn: c, Differ: &sqlx.Diff{DiffDriver: &diff{c}}, @@ -99,6 +107,11 @@ func (d *conn) mariadb() bool { return strings.Index(d.version, "MariaDB") > 0 } +// tidb reports if the Driver is connected to a TiDB database. +func (d *conn) tidb() bool { + return strings.Index(d.version, "TiDB") > 0 +} + // compareV returns an integer comparing two versions according to // semantic version precedence. func (d *conn) compareV(w string) int { diff --git a/sql/mysql/inspect.go b/sql/mysql/inspect.go index f90da736875..d3ff9558b3e 100644 --- a/sql/mysql/inspect.go +++ b/sql/mysql/inspect.go @@ -465,14 +465,14 @@ func (i *inspect) extraAttr(t *schema.Table, c *schema.Column, extra string) err // the 'SHOW CREATE' command. func (i *inspect) showCreate(ctx context.Context, s *schema.Schema) error { for _, t := range s.Tables { - s, ok := popShow(t) + st, ok := popShow(t) if !ok { continue } if err := i.createStmt(ctx, t); err != nil { return err } - if err := i.setAutoInc(s, t); err != nil { + if err := i.setAutoInc(st, t); err != nil { return err } // TODO(a8m): setChecks, setIndexExpr from CREATE statement. @@ -648,7 +648,7 @@ SELECT t1.CREATE_OPTIONS FROM INFORMATION_SCHEMA.TABLES AS t1 - JOIN INFORMATION_SCHEMA.COLLATIONS AS t2 + LEFT JOIN INFORMATION_SCHEMA.COLLATIONS AS t2 ON t1.TABLE_COLLATION = t2.COLLATION_NAME WHERE TABLE_SCHEMA IN (%s) diff --git a/sql/mysql/tidb.go b/sql/mysql/tidb.go new file mode 100644 index 00000000000..bac3e84585b --- /dev/null +++ b/sql/mysql/tidb.go @@ -0,0 +1,129 @@ +// Copyright 2021-present The Atlas Authors. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +package mysql + +import ( + "context" + "fmt" + "regexp" + "strings" + + "ariga.io/atlas/sql/internal/sqlx" + "ariga.io/atlas/sql/schema" +) + +type tinspect struct { + inspect +} + +func (i *tinspect) InspectSchema(ctx context.Context, name string, opts *schema.InspectOptions) (*schema.Schema, error) { + s, err := i.inspect.InspectSchema(ctx, name, opts) + if err != nil { + return nil, err + } + return i.patchSchema(ctx, s) +} + +func (i *tinspect) InspectRealm(ctx context.Context, opts *schema.InspectRealmOption) (*schema.Realm, error) { + r, err := i.inspect.InspectRealm(ctx, opts) + if err != nil { + return nil, err + } + for _, s := range r.Schemas { + if _, err := i.patchSchema(ctx, s); err != nil { + return nil, err + } + } + return r, nil +} + +func (i *tinspect) patchSchema(ctx context.Context, s *schema.Schema) (*schema.Schema, error) { + for _, t := range s.Tables { + var createStmt CreateStmt + if ok := sqlx.Has(t.Attrs, &createStmt); !ok { + if err := i.createStmt(ctx, t); err != nil { + return nil, err + } + } + if err := i.setCollate(t); err != nil { + return nil, err + } + if err := i.setFKs(s, t); err != nil { + return nil, err + } + } + return s, nil +} + +// e.g CONSTRAINT "" FOREIGN KEY ("foo_id") REFERENCES "foo" ("id") +var reFK = regexp.MustCompile("(?i)CONSTRAINT\\s+[\"`]*(\\w+)[\"`]*\\s+FOREIGN\\s+KEY\\s*\\(([,\"` \\w]+)\\)\\s+REFERENCES\\s+[\"`]*(\\w+)[\"`]*\\s*\\(([,\"` \\w]+)\\)") + +func (i *tinspect) setFKs(s *schema.Schema, t *schema.Table) error { + var c CreateStmt + if !sqlx.Has(t.Attrs, &c) { + return fmt.Errorf("missing CREATE TABLE statment in attribuets for %q", t.Name) + } + for _, m := range reFK.FindAllStringSubmatch(c.S, -1) { + if len(m) != 5 { + return fmt.Errorf("unexpected number of matches for a table constraint: %q", m) + } + ctName, clmns, refTableName, refClmns := m[1], m[2], m[3], m[4] + fk := &schema.ForeignKey{ + Symbol: ctName, + Table: t, + // There is no support in TiDB for FKs so inherently there are no actions on update/delete. + OnUpdate: schema.NoAction, + OnDelete: schema.NoAction, + } + refTable, ok := s.Table(refTableName) + if !ok { + return fmt.Errorf("couldn't resolve ref table %s on ", m[3]) + } + fk.RefTable = refTable + for _, c := range columns(s, clmns) { + column, ok := t.Column(c) + if !ok { + return fmt.Errorf("column %q was not found for fk %q", c, ctName) + } + fk.Columns = append(fk.Columns, column) + } + for _, c := range columns(s, refClmns) { + column, ok := t.Column(c) + if !ok { + return fmt.Errorf("ref column %q was not found for fk %q", c, ctName) + } + fk.RefColumns = append(fk.RefColumns, column) + } + t.ForeignKeys = append(t.ForeignKeys, fk) + } + return nil +} + +// columns from the matched regex above. +func columns(schema *schema.Schema, s string) []string { + names := strings.Split(s, ",") + for i := range names { + names[i] = strings.Trim(strings.TrimSpace(names[i]), "`\"") + } + return names +} + +// e.g CHARSET=utf8mb4 COLLATE=utf8mb4_bin +var reColl = regexp.MustCompile(`(?i)CHARSET\s*=\s*(\w+)\s*COLLATE\s*=\s*(\w+)`) + +// setCollate extracts the updated Collation from CREATE TABLE statement. +func (i *tinspect) setCollate(t *schema.Table) error { + var c CreateStmt + if !sqlx.Has(t.Attrs, &c) { + return fmt.Errorf("missing CREATE TABLE statment in attribuets for %q", t.Name) + } + matches := reColl.FindStringSubmatch(c.S) + if len(matches) != 3 { + return fmt.Errorf("missing COLLATE and/or CHARSET information on CREATE TABLE statment for %q", t.Name) + } + t.SetCharset(matches[1]) + t.SetCollation(matches[2]) + return nil +}