From 5998ae44206cc91c84df6213b9de258eb5712c5e Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Thu, 19 Dec 2024 22:06:49 +0200 Subject: [PATCH] cmd/atlas/internal/sqlparse: move parsers out --- .../internal/sqlparse/pgparse/pgparse.go | 215 ------- .../internal/sqlparse/pgparse/pgparse_oss.go | 16 +- .../internal/sqlparse/pgparse/pgparse_test.go | 583 ------------------ .../sqlparse/sqliteparse/sqliteparse.go | 261 -------- .../sqlparse/sqliteparse/sqliteparse_oss.go | 28 + .../sqlparse/sqliteparse/sqliteparse_test.go | 199 ------ 6 files changed, 42 insertions(+), 1260 deletions(-) delete mode 100644 cmd/atlas/internal/sqlparse/pgparse/pgparse.go delete mode 100644 cmd/atlas/internal/sqlparse/pgparse/pgparse_test.go delete mode 100644 cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse.go create mode 100644 cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse_oss.go delete mode 100644 cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse_test.go diff --git a/cmd/atlas/internal/sqlparse/pgparse/pgparse.go b/cmd/atlas/internal/sqlparse/pgparse/pgparse.go deleted file mode 100644 index bc8fefc94da..00000000000 --- a/cmd/atlas/internal/sqlparse/pgparse/pgparse.go +++ /dev/null @@ -1,215 +0,0 @@ -// Copyright 2021-present The Atlas Authors. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package pgparse - -import ( - "errors" - "fmt" - "slices" - - "ariga.io/atlas/cmd/atlas/internal/sqlparse/parseutil" - "ariga.io/atlas/sql/migrate" - "ariga.io/atlas/sql/postgres" - "ariga.io/atlas/sql/schema" - - pgquery "github.com/pganalyze/pg_query_go/v5" -) - -// Parser implements the sqlparse.Parser -type Parser struct{} - -// ColumnFilledBefore checks if the column was filled before the given position. -func (p *Parser) ColumnFilledBefore(stmts []*migrate.Stmt, t *schema.Table, c *schema.Column, pos int) (bool, error) { - return parseutil.MatchStmtBefore(stmts, pos, func(s *migrate.Stmt) (bool, error) { - tr, err := pgquery.Parse(s.Text) - if err != nil { - return false, err - } - idx := slices.IndexFunc(tr.Stmts, func(s *pgquery.RawStmt) bool { - return s.Stmt.GetUpdateStmt() != nil - }) - if idx == -1 { - return false, nil - } - u := tr.Stmts[idx].Stmt.GetUpdateStmt() - if u == nil || !matchTable(u.Relation, t) { - return false, nil - } - // Accept UPDATE that fills all rows or those with NULL values as we cannot - // determine if NULL values were filled in case there is a custom filtering. - affectC := func() bool { - if u.WhereClause == nil { - return true - } - x := u.WhereClause.GetNullTest() - if x == nil || x.GetNulltesttype() != pgquery.NullTestType_IS_NULL { - return false - } - fields := x.GetArg().GetColumnRef().GetFields() - return len(fields) == 1 && fields[0].GetString_().GetSval() == c.Name - }() - idx = slices.IndexFunc(u.TargetList, func(n *pgquery.Node) bool { - r := n.GetResTarget() - return r.GetName() == c.Name && !r.GetVal().GetAConst().GetIsnull() - }) - // Ensure the column was filled. - return affectC && idx != -1, nil - }) -} - -// CreateViewAfter checks if a view was created after the position with the given name to a table. -func (p *Parser) CreateViewAfter(stmts []*migrate.Stmt, old, new string, pos int) (bool, error) { - return parseutil.MatchStmtAfter(stmts, pos, func(s *migrate.Stmt) (bool, error) { - tr, err := pgquery.Parse(s.Text) - if err != nil { - return false, err - } - idx := slices.IndexFunc(tr.Stmts, func(s *pgquery.RawStmt) bool { - return s.Stmt.GetViewStmt() != nil - }) - if idx == -1 { - return false, nil - } - v := tr.Stmts[idx].Stmt.GetViewStmt() - if v.GetView().GetRelname() != old { - return false, nil - } - from := v.Query.GetSelectStmt().GetFromClause() - if len(from) != 1 { - return false, nil - } - return from[0].GetRangeVar().GetRelname() == new, nil - }) -} - -// FixChange fixes the changes according to the given statement. -func (p *Parser) FixChange(_ migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) { - if len(changes) == 0 { - return nil, errors.New("no changes to fix") - } - tr, err := pgquery.Parse(s) - if err != nil { - return nil, err - } - for _, stmt := range tr.Stmts { - switch stmt := stmt.GetStmt(); { - case stmt.GetRenameStmt() != nil && - stmt.GetRenameStmt().GetRenameType() == pgquery.ObjectType_OBJECT_COLUMN: - modify, err := expectHaveModify(changes) - if err != nil { - return nil, err - } - rename := stmt.GetRenameStmt() - parseutil.RenameColumn(modify, &parseutil.Rename{ - From: rename.GetSubname(), - To: rename.GetNewname(), - }) - case stmt.GetRenameStmt() != nil && - stmt.GetRenameStmt().GetRenameType() == pgquery.ObjectType_OBJECT_INDEX: - modify, err := expectOneModify(changes) - if err != nil { - return nil, err - } - rename := stmt.GetRenameStmt() - parseutil.RenameIndex(modify, &parseutil.Rename{ - From: rename.GetRelation().GetRelname(), - To: rename.GetNewname(), - }) - case stmt.GetRenameStmt() != nil && - stmt.GetRenameStmt().GetRenameType() == pgquery.ObjectType_OBJECT_TABLE: - rename := stmt.GetRenameStmt() - changes = parseutil.RenameTable(changes, &parseutil.Rename{ - From: rename.GetRelation().GetRelname(), - To: rename.GetNewname(), - }) - case stmt.GetIndexStmt() != nil && - stmt.GetIndexStmt().GetConcurrent(): - modify, err := expectOneModify(changes) - if err != nil { - return nil, err - } - name := stmt.GetIndexStmt().GetIdxname() - i := schema.Changes(modify.Changes).IndexAddIndex(name) - if i == -1 { - return nil, fmt.Errorf("AddIndex %q command not found", name) - } - add := modify.Changes[i].(*schema.AddIndex) - if !slices.ContainsFunc(add.Extra, func(c schema.Clause) bool { - _, ok := c.(*postgres.Concurrently) - return ok - }) { - add.Extra = append(add.Extra, &postgres.Concurrently{}) - } - case stmt.GetDropStmt() != nil && stmt.GetDropStmt().GetConcurrent(): - modify, err := expectOneModify(changes) - if err != nil { - return nil, err - } - for _, p := range stmt.GetDropStmt().GetObjects() { - items := p.GetList().GetItems() - var name string - switch { - // Match DROP INDEX . - case len(items) == 1 && items[0].GetString_().GetSval() != "": - name = items[0].GetString_().GetSval() - // Match DROP INDEX .. - case len(items) == 2 && modify.T.Schema != nil && - items[0].GetString_().GetSval() == modify.T.Schema.Name && - items[1].GetString_().GetSval() != "": - name = items[1].GetString_().GetSval() - default: - continue - } - i := schema.Changes(modify.Changes).IndexDropIndex(name) - if i == -1 { - return nil, fmt.Errorf("DropIndex %q command not found", name) - } - drop := modify.Changes[i].(*schema.DropIndex) - if !slices.ContainsFunc(drop.Extra, func(c schema.Clause) bool { - _, ok := c.(*postgres.Concurrently) - return ok - }) { - drop.Extra = append(drop.Extra, &postgres.Concurrently{}) - } - } - case stmt.GetAlterTableStmt() != nil: - if fixed, err := FixAlterTable(s, stmt.GetAlterTableStmt(), changes); err == nil { - changes = fixed // Make ALTER fixes optional. - } - } - } - return changes, nil -} - -func expectOneModify(changes schema.Changes) (*schema.ModifyTable, error) { - modify, ok := changes[0].(*schema.ModifyTable) - if !ok { - return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", changes[0]) - } - return modify, nil -} - -func expectHaveModify(changes schema.Changes) (*schema.ModifyTable, error) { - var modify []*schema.ModifyTable - for _, c := range changes { - switch c := c.(type) { - case *schema.ModifyTable: - modify = append(modify, c) - // The column might be used in the view. - case *schema.ModifyView: - default: - return nil, fmt.Errorf("unexpected change for alter-table statement: %#v", c) - } - } - if len(modify) != 1 { - return nil, fmt.Errorf("expected one modify-table change for alter-table statement, but got: %d", len(modify)) - } - return modify[0], nil -} - -// tableUpdated checks if the table was updated in the statement. -func matchTable(n *pgquery.RangeVar, t *schema.Table) bool { - return n.GetRelname() == t.Name && (n.GetSchemaname() == "" || n.GetSchemaname() == t.Schema.Name) -} diff --git a/cmd/atlas/internal/sqlparse/pgparse/pgparse_oss.go b/cmd/atlas/internal/sqlparse/pgparse/pgparse_oss.go index 6a6e2e9a3e2..c748faf4cf6 100644 --- a/cmd/atlas/internal/sqlparse/pgparse/pgparse_oss.go +++ b/cmd/atlas/internal/sqlparse/pgparse/pgparse_oss.go @@ -7,10 +7,22 @@ package pgparse import ( + "errors" + + "ariga.io/atlas/sql/migrate" "ariga.io/atlas/sql/schema" - pgquery "github.com/pganalyze/pg_query_go/v5" ) -func FixAlterTable(_ string, _ *pgquery.AlterTableStmt, changes schema.Changes) (schema.Changes, error) { +type Parser struct{} + +func (*Parser) ColumnFilledBefore([]*migrate.Stmt, *schema.Table, *schema.Column, int) (bool, error) { + return false, errors.New("unimplemented") +} + +func (*Parser) CreateViewAfter([]*migrate.Stmt, string, string, int) (bool, error) { + return false, errors.New("unimplemented") +} + +func (*Parser) FixChange(_ migrate.Driver, _ string, changes schema.Changes) (schema.Changes, error) { return changes, nil // Unimplemented. } diff --git a/cmd/atlas/internal/sqlparse/pgparse/pgparse_test.go b/cmd/atlas/internal/sqlparse/pgparse/pgparse_test.go deleted file mode 100644 index f64d0d0eaea..00000000000 --- a/cmd/atlas/internal/sqlparse/pgparse/pgparse_test.go +++ /dev/null @@ -1,583 +0,0 @@ -// Copyright 2021-present The Atlas Authors. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package pgparse_test - -import ( - "strconv" - "testing" - - "ariga.io/atlas/cmd/atlas/internal/sqlparse/pgparse" - "ariga.io/atlas/sql/migrate" - "ariga.io/atlas/sql/postgres" - "ariga.io/atlas/sql/schema" - - "github.com/stretchr/testify/require" -) - -func TestFixChange_RenameColumns(t *testing.T) { - var p pgparse.Parser - _, err := p.FixChange( - nil, - "ALTER TABLE t RENAME COLUMN c1 TO c2", - nil, - ) - require.Error(t, err) - - _, err = p.FixChange( - nil, - "ALTER TABLE t RENAME COLUMN c1 TO c2", - schema.Changes{&schema.AddTable{}}, - ) - require.Error(t, err) - - changes, err := p.FixChange( - nil, - "ALTER TABLE t RENAME COLUMN c1 TO c2", - schema.Changes{ - &schema.ModifyTable{ - Changes: schema.Changes{ - &schema.DropColumn{C: schema.NewColumn("c1")}, - &schema.AddColumn{C: schema.NewColumn("c2")}, - }, - }, - }, - ) - require.NoError(t, err) - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - Changes: schema.Changes{ - &schema.RenameColumn{From: schema.NewColumn("c1"), To: schema.NewColumn("c2")}, - }, - }, - }, - changes, - ) - - changes, err = p.FixChange( - nil, - "ALTER TABLE t RENAME COLUMN c1 TO c2", - schema.Changes{ - &schema.ModifyTable{ - Changes: schema.Changes{ - &schema.DropColumn{C: schema.NewColumn("c1")}, - &schema.AddColumn{C: schema.NewColumn("c2")}, - }, - }, - &schema.ModifyView{ - From: &schema.View{Name: "t", Def: "select c1 from t"}, - To: &schema.View{Name: "t", Def: "select c2 from t"}, - }, - }, - ) - require.NoError(t, err) - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - Changes: schema.Changes{ - &schema.RenameColumn{From: schema.NewColumn("c1"), To: schema.NewColumn("c2")}, - }, - }, - &schema.ModifyView{ - From: &schema.View{Name: "t", Def: "select c1 from t"}, - To: &schema.View{Name: "t", Def: "select c2 from t"}, - }, - }, - changes, - ) -} - -func TestFixChange_RenameIndexes(t *testing.T) { - var p pgparse.Parser - changes, err := p.FixChange( - nil, - "ALTER INDEX IF EXISTS i1 RENAME TO i2", - schema.Changes{ - &schema.ModifyTable{ - Changes: schema.Changes{ - &schema.DropIndex{I: schema.NewIndex("i1")}, - &schema.AddIndex{I: schema.NewIndex("i2")}, - }, - }, - }, - ) - require.NoError(t, err) - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - Changes: schema.Changes{ - &schema.RenameIndex{From: schema.NewIndex("i1"), To: schema.NewIndex("i2")}, - }, - }, - }, - changes, - ) -} - -func TestFixChange_CreateIndexCon(t *testing.T) { - var p pgparse.Parser - changes, err := p.FixChange( - nil, - "CREATE INDEX i1 ON t1 (c1)", - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.AddIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - // No changes. - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.AddIndex{ - I: schema.NewIndex("i1"), - }, - }, - }, - }, - changes, - ) - - changes, err = p.FixChange( - nil, - "CREATE INDEX CONCURRENTLY i1 ON t1 (c1)", - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.AddIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - // Should add the "Concurrently" clause to the AddIndex command. - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.AddIndex{ - I: schema.NewIndex("i1"), - Extra: []schema.Clause{ - &postgres.Concurrently{}, - }, - }, - }, - }, - }, - changes, - ) - - changes, err = p.FixChange( - nil, - "CREATE INDEX CONCURRENTLY i1 ON t1 (c1)", - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.AddIndex{ - I: schema.NewIndex("i1"), - Extra: []schema.Clause{ - &postgres.Concurrently{}, - }, - }, - }, - }, - }, - ) - require.NoError(t, err) - // The "Concurrently" clause should not be added if it already exists. - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.AddIndex{ - I: schema.NewIndex("i1"), - Extra: []schema.Clause{ - &postgres.Concurrently{}, - }, - }, - }, - }, - }, - changes, - ) - // Support quoted identifiers. - changes, err = p.FixChange( - nil, - `CREATE INDEX CONCURRENTLY "i1" ON t1 (c1)`, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.AddIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - m, ok := changes[0].(*schema.ModifyTable) - require.True(t, ok) - require.Equal(t, &postgres.Concurrently{}, m.Changes[0].(*schema.AddIndex).Extra[0]) - - // Support qualified quoted identifiers. - changes, err = p.FixChange( - nil, - `CREATE INDEX CONCURRENTLY "i1" ON "public".t1 (c1)`, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"). - SetSchema(schema.New("public")), - Changes: schema.Changes{ - &schema.AddIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - m, ok = changes[0].(*schema.ModifyTable) - require.True(t, ok) - require.Equal(t, &postgres.Concurrently{}, m.Changes[0].(*schema.AddIndex).Extra[0]) -} - -func TestFixChange_DropIndexCon(t *testing.T) { - var p pgparse.Parser - changes, err := p.FixChange( - nil, - "DROP INDEX i1", - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.DropIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - // No changes. - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.DropIndex{ - I: schema.NewIndex("i1"), - }, - }, - }, - }, - changes, - ) - - changes, err = p.FixChange( - nil, - "DROP INDEX CONCURRENTLY i1", - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.DropIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - // Should add the "Concurrently" clause to the DropIndex command. - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.DropIndex{ - I: schema.NewIndex("i1"), - Extra: []schema.Clause{ - &postgres.Concurrently{}, - }, - }, - }, - }, - }, - changes, - ) - - changes, err = p.FixChange( - nil, - "DROP INDEX CONCURRENTLY i1", - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.DropIndex{ - I: schema.NewIndex("i1"), - Extra: []schema.Clause{ - &postgres.Concurrently{}, - }, - }, - }, - }, - }, - ) - require.NoError(t, err) - // The "Concurrently" clause should not be added if it already exists. - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.DropIndex{ - I: schema.NewIndex("i1"), - Extra: []schema.Clause{ - &postgres.Concurrently{}, - }, - }, - }, - }, - }, - changes, - ) - // Support quoted identifiers. - changes, err = p.FixChange( - nil, - `DROP INDEX CONCURRENTLY "i1"`, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.DropIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - m, ok := changes[0].(*schema.ModifyTable) - require.True(t, ok) - require.Equal(t, &postgres.Concurrently{}, m.Changes[0].(*schema.DropIndex).Extra[0]) - - // Support qualified identifiers. - changes, err = p.FixChange( - nil, - `DROP INDEX CONCURRENTLY public.i1`, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"). - SetSchema(schema.New("public")), - Changes: schema.Changes{ - &schema.DropIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - m, ok = changes[0].(*schema.ModifyTable) - require.True(t, ok) - require.Equal(t, &postgres.Concurrently{}, m.Changes[0].(*schema.DropIndex).Extra[0]) - - // Support qualified quoted identifiers. - changes, err = p.FixChange( - nil, - `DROP INDEX CONCURRENTLY "public"."i1"`, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"). - SetSchema(schema.New("public")), - Changes: schema.Changes{ - &schema.DropIndex{I: schema.NewIndex("i1")}, - }, - }, - }, - ) - require.NoError(t, err) - m, ok = changes[0].(*schema.ModifyTable) - require.True(t, ok) - require.Equal(t, &postgres.Concurrently{}, m.Changes[0].(*schema.DropIndex).Extra[0]) - - // Multiple indexes. - changes, err = p.FixChange( - nil, - `DROP INDEX CONCURRENTLY i1, i2`, - schema.Changes{ - &schema.ModifyTable{ - T: schema.NewTable("t1"), - Changes: schema.Changes{ - &schema.DropIndex{I: schema.NewIndex("i1")}, - &schema.DropIndex{I: schema.NewIndex("i2")}, - }, - }, - }, - ) - require.NoError(t, err) - m, ok = changes[0].(*schema.ModifyTable) - require.True(t, ok) - require.Equal(t, &postgres.Concurrently{}, m.Changes[0].(*schema.DropIndex).Extra[0]) - require.Equal(t, &postgres.Concurrently{}, m.Changes[1].(*schema.DropIndex).Extra[0]) -} - -func TestFixChange_RenameTable(t *testing.T) { - var p pgparse.Parser - changes, err := p.FixChange( - nil, - "ALTER TABLE t1 RENAME TO t2", - schema.Changes{ - &schema.DropTable{T: schema.NewTable("t1")}, - &schema.AddTable{T: schema.NewTable("t2")}, - &schema.AddTable{T: schema.NewTable("t3")}, - }, - ) - require.NoError(t, err) - require.Equal( - t, - schema.Changes{ - &schema.RenameTable{From: schema.NewTable("t1"), To: schema.NewTable("t2")}, - &schema.AddTable{T: schema.NewTable("t3")}, - }, - changes, - ) -} - -func TestColumnFilledBefore(t *testing.T) { - for i, tt := range []struct { - file string - pos int - wantFilled bool - wantErr bool - }{ - { - file: `UPDATE t SET c = NULL;`, - pos: 100, - }, - { - file: `UPDATE t SET c = 2;`, - }, - { - file: `UPDATE t SET c = 2;`, - }, - { - file: `UPDATE t SET c = 2;`, - pos: 100, - wantFilled: true, - }, - { - file: `UPDATE t SET c = 2 WHERE c IS NULL;`, - pos: 100, - wantFilled: true, - }, - { - file: `UPDATE t SET c = 2 WHERE c IS NOT NULL;`, - pos: 100, - wantFilled: false, - }, - { - file: `UPDATE t SET c = 2 WHERE c <> NULL`, - pos: 100, - wantFilled: false, - }, - { - file: ` - ALTER TABLE t MODIFY COLUMN c INT NOT NULL; - UPDATE t SET c = 2 WHERE c IS NULL; - `, - pos: 2, - wantFilled: false, - }, - { - file: ` - UPDATE t SET c = 2 WHERE c IS NULL; - ALTER TABLE t MODIFY COLUMN c INT NOT NULL; - `, - pos: 30, - wantFilled: true, - }, - } { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var p pgparse.Parser - stmts, err := migrate.Stmts(tt.file) - require.NoError(t, err) - filled, err := p.ColumnFilledBefore(stmts, schema.NewTable("t"), schema.NewColumn("c"), tt.pos) - require.Equal(t, err != nil, tt.wantErr, err) - require.Equal(t, filled, tt.wantFilled) - }) - } -} - -func TestCreateViewAfter(t *testing.T) { - for i, tt := range []struct { - file string - pos int - wantCreated bool - wantErr bool - }{ - { - file: ` -ALTER TABLE old RENAME TO new; -CREATE VIEW old AS SELECT * FROM new; -`, - pos: 1, - wantCreated: true, - }, - { - file: ` - ALTER TABLE old RENAME TO new; - CREATE VIEW old AS SELECT * FROM users; - `, - pos: 1, - }, - { - file: ` - ALTER TABLE old RENAME TO new; - CREATE VIEW old AS (SELECT * FROM "new"); - `, - pos: 1, - wantCreated: true, - }, - { - file: ` - ALTER TABLE old RENAME TO new; - CREATE VIEW old AS (SELECT * FROM "1"); - `, - pos: 1, - }, - { - file: ` - ALTER TABLE old RENAME TO new; - CREATE VIEW old AS SELECT * FROM new; - `, - pos: 100, - }, - { - file: ` - ALTER TABLE old RENAME TO new; - CREATE VIEW old AS SELECT a, b, c FROM new; - `, - wantCreated: true, - }, - } { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var p pgparse.Parser - stmts, err := migrate.Stmts(tt.file) - require.NoError(t, err) - created, err := p.CreateViewAfter(stmts, "old", "new", tt.pos) - require.Equal(t, err != nil, tt.wantErr, err) - require.Equal(t, tt.wantCreated, created) - }) - } -} diff --git a/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse.go b/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse.go deleted file mode 100644 index da16613e467..00000000000 --- a/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright 2021-present The Atlas Authors. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package sqliteparse - -import ( - "errors" - "fmt" - "slices" - "strconv" - "strings" - - "ariga.io/atlas/cmd/atlas/internal/sqlparse/parseutil" - "ariga.io/atlas/sql/migrate" - "ariga.io/atlas/sql/schema" - - "github.com/antlr4-go/antlr/v4" -) - -type ( - // Stmt provides extended functionality - // to ANTLR parsed statements. - Stmt struct { - stmt antlr.ParseTree - input string - err error - } - - // listenError catches parse errors. - listenError struct { - antlr.DefaultErrorListener - err error - text string - } -) - -// SyntaxError implements ErrorListener.SyntaxError. -func (l *listenError) SyntaxError(_ antlr.Recognizer, _ any, line, column int, msg string, _ antlr.RecognitionException) { - if idx := strings.Index(msg, " expecting "); idx != -1 { - msg = msg[:idx] - } - l.err = fmt.Errorf("line %d:%d: %s", line, column+1, msg) -} - -// ParseStmt parses a statement. -func ParseStmt(text string) (stmt *Stmt, err error) { - l := &listenError{text: text} - defer func() { - if l.err != nil { - err = l.err - stmt = nil - } else if perr := recover(); perr != nil { - m := fmt.Sprint(perr) - if v, ok := perr.(antlr.RecognitionException); ok { - m = v.GetMessage() - } - err = errors.New(m) - stmt = nil - } - }() - lex := NewLexer(antlr.NewInputStream(text)) - lex.RemoveErrorListeners() - lex.AddErrorListener(l) - p := NewParser( - antlr.NewCommonTokenStream(lex, 0), - ) - p.RemoveErrorListeners() - p.AddErrorListener(l) - p.BuildParseTrees = true - stmt = &Stmt{ - stmt: p.Sql_stmt(), - } - return -} - -// IsAlterTable reports if the statement is type ALTER TABLE. -func (s *Stmt) IsAlterTable() bool { - if s.stmt.GetChildCount() != 1 { - return false - } - _, ok := s.stmt.GetChild(0).(*Alter_table_stmtContext) - return ok -} - -// RenameColumn returns the renamed column information from the statement, if exists. -func (s *Stmt) RenameColumn() (*parseutil.Rename, bool) { - if !s.IsAlterTable() { - return nil, false - } - alter := s.stmt.GetChild(0).(*Alter_table_stmtContext) - if alter.old_column_name == nil || alter.new_column_name == nil { - return nil, false - } - return &parseutil.Rename{ - From: unquote(alter.old_column_name.GetText()), - To: unquote(alter.new_column_name.GetText()), - }, true -} - -// RenameTable returns the renamed table information from the statement, if exists. -func (s *Stmt) RenameTable() (*parseutil.Rename, bool) { - if !s.IsAlterTable() { - return nil, false - } - alter := s.stmt.GetChild(0).(*Alter_table_stmtContext) - if alter.new_table_name == nil { - return nil, false - } - return &parseutil.Rename{ - From: unquote(alter.Table_name(0).GetText()), - To: unquote(alter.new_table_name.GetText()), - }, true -} - -// TableUpdate reports if the statement is an UPDATE command for the given table. -func (s *Stmt) TableUpdate(t *schema.Table) (*Update_stmtContext, bool) { - if s.stmt.GetChildCount() != 1 { - return nil, false - } - u, ok := s.stmt.GetChild(0).(*Update_stmtContext) - if !ok { - return nil, false - } - name, ok := u.Qualified_table_name().(*Qualified_table_nameContext) - if !ok || unquote(name.Table_name().GetText()) != t.Name { - return nil, false - } - return u, true -} - -// CreateView reports if the statement is a CREATE VIEW command with the given name. -func (s *Stmt) CreateView(name string) (*Create_view_stmtContext, bool) { - if s.stmt.GetChildCount() != 1 { - return nil, false - } - v, ok := s.stmt.GetChild(0).(*Create_view_stmtContext) - if !ok || unquote(v.View_name().GetText()) != name { - return nil, false - } - return v, true -} - -// FileParser implements the sqlparse.Parser -type FileParser struct{} - -// ColumnFilledBefore checks if the column was filled before the given position. -func (p *FileParser) ColumnFilledBefore(stmts []*migrate.Stmt, t *schema.Table, c *schema.Column, pos int) (bool, error) { - return parseutil.MatchStmtBefore(stmts, pos, func(s *migrate.Stmt) (bool, error) { - stmt, err := ParseStmt(s.Text) - if err != nil { - return false, err - } - u, ok := stmt.TableUpdate(t) - if !ok { - return false, nil - } - // Accept UPDATE that fills all rows or those with NULL values as we cannot - // determine if NULL values were filled in case there is a custom filtering. - affectC := func() bool { - x := u.GetWhere() - if x == nil { - return true - } - if x.GetChildCount() != 3 { - return false - } - x1, ok := x.GetChild(0).(*ExprContext) - if !ok || unquote(x1.GetText()) != c.Name { - return false - } - x2, ok := x.GetChild(1).(*antlr.TerminalNodeImpl) - if !ok || x2.GetSymbol().GetTokenType() != ParserIS_ { - return false - } - return isnull(x.GetChild(2)) - }() - list, ok := u.Assignment_list().(*Assignment_listContext) - if !ok { - return false, nil - } - idx := slices.IndexFunc(list.AllAssignment(), func(a IAssignmentContext) bool { - as, ok := a.(*AssignmentContext) - return ok && unquote(as.Column_name().GetText()) == c.Name && !isnull(as.Expr()) - }) - // Ensure the column was filled. - return affectC && idx != -1, nil - }) -} - -// CreateViewAfter checks if a view was created after the position with the given name to a table. -func (p *FileParser) CreateViewAfter(stmts []*migrate.Stmt, old, new string, pos int) (bool, error) { - return parseutil.MatchStmtAfter(stmts, pos, func(s *migrate.Stmt) (bool, error) { - stmt, err := ParseStmt(s.Text) - if err != nil { - return false, err - } - v, ok := stmt.CreateView(old) - if !ok { - return false, nil - } - sc, ok := v.Select_stmt().(*Select_stmtContext) - if !ok { - return false, nil - } - idx := slices.IndexFunc(sc.Select_core(0).GetChildren(), func(t antlr.Tree) bool { - ts, ok := t.(*Table_or_subqueryContext) - return ok && unquote(ts.GetText()) == new - }) - return idx != -1, nil - }) -} - -// FixChange fixes the changes according to the given statement. -func (p *FileParser) FixChange(_ migrate.Driver, s string, changes schema.Changes) (schema.Changes, error) { - stmt, err := ParseStmt(s) - if err != nil { - return nil, err - } - if !stmt.IsAlterTable() { - return changes, nil - } - if r, ok := stmt.RenameColumn(); ok { - if len(changes) != 1 { - return nil, fmt.Errorf("unexpected number fo changes: %d", len(changes)) - } - modify, ok := changes[0].(*schema.ModifyTable) - if !ok { - return nil, fmt.Errorf("expected modify-table change for alter-table statement, but got: %T", changes[0]) - } - // ALTER COLUMN cannot be combined with additional commands. - if len(changes) > 2 { - return nil, fmt.Errorf("unexpected number of changes found: %d", len(changes)) - } - parseutil.RenameColumn(modify, r) - } - if r, ok := stmt.RenameTable(); ok { - changes = parseutil.RenameTable(changes, r) - } - return changes, nil -} - -func isnull(t antlr.Tree) bool { - x, ok := t.(*ExprContext) - if !ok || x.GetChildCount() != 1 { - return false - } - l, ok := x.GetChild(0).(*Literal_valueContext) - return ok && l.GetChildCount() == 1 && len(l.GetTokens(ParserNULL_)) > 0 -} - -func unquote(s string) string { - switch { - case len(s) < 2: - case s[0] == '`' && s[len(s)-1] == '`', s[0] == '"' && s[len(s)-1] == '"': - if u, err := strconv.Unquote(s); err == nil { - return u - } - } - return s -} diff --git a/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse_oss.go b/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse_oss.go new file mode 100644 index 00000000000..68e89e73331 --- /dev/null +++ b/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse_oss.go @@ -0,0 +1,28 @@ +// Copyright 2021-present The Atlas Authors. All rights reserved. +// This source code is licensed under the Apache 2.0 license found +// in the LICENSE file in the root directory of this source tree. + +//go:build !ent + +package sqliteparse + +import ( + "errors" + + "ariga.io/atlas/sql/migrate" + "ariga.io/atlas/sql/schema" +) + +type FileParser struct{} + +func (*FileParser) ColumnFilledBefore([]*migrate.Stmt, *schema.Table, *schema.Column, int) (bool, error) { + return false, errors.New("unimplemented") +} + +func (*FileParser) CreateViewAfter([]*migrate.Stmt, string, string, int) (bool, error) { + return false, errors.New("unimplemented") +} + +func (*FileParser) FixChange(_ migrate.Driver, _ string, changes schema.Changes) (schema.Changes, error) { + return changes, nil // Unimplemented. +} diff --git a/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse_test.go b/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse_test.go deleted file mode 100644 index e62d03624b5..00000000000 --- a/cmd/atlas/internal/sqlparse/sqliteparse/sqliteparse_test.go +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright 2021-present The Atlas Authors. All rights reserved. -// This source code is licensed under the Apache 2.0 license found -// in the LICENSE file in the root directory of this source tree. - -package sqliteparse_test - -import ( - "strconv" - "testing" - - "ariga.io/atlas/cmd/atlas/internal/sqlparse/sqliteparse" - "ariga.io/atlas/sql/migrate" - "ariga.io/atlas/sql/schema" - - "github.com/stretchr/testify/require" -) - -func TestFixChange_RenameColumns(t *testing.T) { - var p sqliteparse.FileParser - _, err := p.FixChange( - nil, - "ALTER TABLE t RENAME COLUMN c1 TO c2", - nil, - ) - require.Error(t, err) - - _, err = p.FixChange( - nil, - "ALTER TABLE t RENAME COLUMN c1 TO c2", - schema.Changes{&schema.AddTable{}}, - ) - require.Error(t, err) - - changes, err := p.FixChange( - nil, - "ALTER TABLE t RENAME COLUMN c1 TO c2", - schema.Changes{ - &schema.ModifyTable{ - Changes: schema.Changes{ - &schema.DropColumn{C: schema.NewColumn("c1")}, - &schema.AddColumn{C: schema.NewColumn("c2")}, - }, - }, - }, - ) - require.NoError(t, err) - require.Equal( - t, - schema.Changes{ - &schema.ModifyTable{ - Changes: schema.Changes{ - &schema.RenameColumn{From: schema.NewColumn("c1"), To: schema.NewColumn("c2")}, - }, - }, - }, - changes, - ) -} - -func TestFixChange_RenameTable(t *testing.T) { - var p sqliteparse.FileParser - changes, err := p.FixChange( - nil, - "ALTER TABLE t1 RENAME TO t2", - schema.Changes{ - &schema.DropTable{T: schema.NewTable("t1")}, - &schema.AddTable{T: schema.NewTable("t2")}, - &schema.AddTable{T: schema.NewTable("t3")}, - }, - ) - require.NoError(t, err) - require.Equal( - t, - schema.Changes{ - &schema.RenameTable{From: schema.NewTable("t1"), To: schema.NewTable("t2")}, - &schema.AddTable{T: schema.NewTable("t3")}, - }, - changes, - ) -} - -func TestColumnFilledBefore(t *testing.T) { - for i, tt := range []struct { - file string - pos int - wantFilled bool - wantErr bool - }{ - { - file: `UPDATE t SET c = NULL;`, - pos: 100, - }, - { - file: "UPDATE `t` SET c = 2;", - pos: 100, - wantFilled: true, - }, - { - file: `UPDATE t SET c = 2 WHERE c IS NULL;`, - pos: 100, - wantFilled: true, - }, - { - file: "UPDATE `t` SET `c` = 2 WHERE `c` IS NULL;", - pos: 100, - wantFilled: true, - }, - { - file: `UPDATE t SET c = 2 WHERE c IS NOT NULL;`, - pos: 100, - wantFilled: false, - }, - { - file: `UPDATE t SET c = 2 WHERE c <> NULL`, - pos: 100, - wantFilled: false, - }, - { - file: ` -UPDATE t1 SET c = 2 WHERE c IS NULL; -UPDATE t SET c = 2 WHERE c IS NULL; -`, - pos: 2, - wantFilled: false, - }, - { - file: ` -UPDATE t SET c = 2 WHERE c IS NULL; -UPDATE t1 SET c = 2 WHERE c IS NULL; -`, - pos: 30, - wantFilled: true, - }, - } { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var p sqliteparse.FileParser - stmts, err := migrate.Stmts(tt.file) - require.NoError(t, err) - filled, err := p.ColumnFilledBefore(stmts, schema.NewTable("t"), schema.NewColumn("c"), tt.pos) - require.Equal(t, err != nil, tt.wantErr, err) - require.Equal(t, filled, tt.wantFilled) - }) - } -} - -func TestCreateViewAfter(t *testing.T) { - for i, tt := range []struct { - file string - pos int - wantCreated bool - wantErr bool - }{ - { - file: ` -ALTER TABLE old RENAME TO new; -CREATE VIEW old AS SELECT * FROM new; -`, - pos: 1, - wantCreated: true, - }, - { - file: ` -ALTER TABLE old RENAME TO new; -CREATE VIEW old AS SELECT * FROM users; -`, - pos: 1, - }, - { - file: ` -ALTER TABLE old RENAME TO new; -CREATE VIEW old AS SELECT * FROM new JOIN new; -`, - pos: 1, - }, - { - file: ` -ALTER TABLE old RENAME TO new; -CREATE VIEW old AS SELECT * FROM new; -`, - pos: 100, - }, - { - file: ` -ALTER TABLE old RENAME TO new; -CREATE VIEW old AS SELECT a, b, c FROM new; -`, - wantCreated: true, - }, - } { - t.Run(strconv.Itoa(i), func(t *testing.T) { - var p sqliteparse.FileParser - stmts, err := migrate.Stmts(tt.file) - require.NoError(t, err) - created, err := p.CreateViewAfter(stmts, "old", "new", tt.pos) - require.Equal(t, err != nil, tt.wantErr, err) - require.Equal(t, created, tt.wantCreated) - }) - } -}