diff --git a/database/schema/grammars/sqlserver.go b/database/schema/grammars/sqlserver.go new file mode 100644 index 000000000..aeb812715 --- /dev/null +++ b/database/schema/grammars/sqlserver.go @@ -0,0 +1,216 @@ +package grammars + +import ( + "fmt" + "slices" + "strings" + + "github.com/goravel/framework/contracts/database" + "github.com/goravel/framework/contracts/database/schema" + "github.com/goravel/framework/database/schema/constants" +) + +type Sqlserver struct { + attributeCommands []string + modifiers []func(schema.Blueprint, schema.ColumnDefinition) string + serials []string + wrap *Wrap +} + +func NewSqlserver(tablePrefix string) *Sqlserver { + sqlserver := &Sqlserver{ + attributeCommands: []string{constants.CommandComment}, + serials: []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"}, + wrap: NewWrap(database.DriverSqlserver, tablePrefix), + } + sqlserver.modifiers = []func(schema.Blueprint, schema.ColumnDefinition) string{ + sqlserver.ModifyDefault, + sqlserver.ModifyIncrement, + sqlserver.ModifyNullable, + } + + return sqlserver +} + +func (r *Sqlserver) CompileAdd(blueprint schema.Blueprint, command *schema.Command) string { + return fmt.Sprintf("alter table %s add %s", r.wrap.Table(blueprint.GetTableName()), r.getColumn(blueprint, command.Column)) +} + +func (r *Sqlserver) CompileCreate(blueprint schema.Blueprint) string { + return fmt.Sprintf("create table %s (%s)", r.wrap.Table(blueprint.GetTableName()), strings.Join(r.getColumns(blueprint), ", ")) +} + +func (r *Sqlserver) CompileDropAllDomains(domains []string) string { + return "" +} + +func (r *Sqlserver) CompileDropAllForeignKeys() string { + return `DECLARE @sql NVARCHAR(MAX) = N''; + SELECT @sql += 'ALTER TABLE ' + + QUOTENAME(OBJECT_SCHEMA_NAME(parent_object_id)) + '.' + + QUOTENAME(OBJECT_NAME(parent_object_id)) + + ' DROP CONSTRAINT ' + QUOTENAME(name) + ';' + FROM sys.foreign_keys; + + EXEC sp_executesql @sql;` +} + +func (r *Sqlserver) CompileDropAllTables(tables []string) string { + return "EXEC sp_msforeachtable 'DROP TABLE ?'" +} + +func (r *Sqlserver) CompileDropAllTypes(types []string) string { + return "" +} + +func (r *Sqlserver) CompileDropAllViews(views []string) string { + return `DECLARE @sql NVARCHAR(MAX) = N''; + SELECT @sql += 'DROP VIEW ' + QUOTENAME(OBJECT_SCHEMA_NAME(object_id)) + '.' + QUOTENAME(name) + ';' + FROM sys.views; + + EXEC sp_executesql @sql;` +} + +func (r *Sqlserver) CompileDropIfExists(blueprint schema.Blueprint) string { + table := r.wrap.Table(blueprint.GetTableName()) + + return fmt.Sprintf("if object_id(%s, 'U') is not null drop table %s", r.wrap.Quote(table), table) +} + +func (r *Sqlserver) CompileForeign(blueprint schema.Blueprint, command *schema.Command) string { + sql := fmt.Sprintf("alter table %s add constraint %s foreign key (%s) references %s (%s)", + r.wrap.Table(blueprint.GetTableName()), + r.wrap.Column(command.Index), + r.wrap.Columnize(command.Columns), + r.wrap.Table(command.On), + r.wrap.Columnize(command.References)) + if command.OnDelete != "" { + sql += " on delete " + command.OnDelete + } + if command.OnUpdate != "" { + sql += " on update " + command.OnUpdate + } + + return sql +} + +func (r *Sqlserver) CompileIndex(blueprint schema.Blueprint, command *schema.Command) string { + return fmt.Sprintf("create index %s on %s (%s)", + r.wrap.Column(command.Index), + r.wrap.Table(blueprint.GetTableName()), + r.wrap.Columnize(command.Columns), + ) +} + +func (r *Sqlserver) CompileIndexes(schema, table string) string { + newSchema := "schema_name()" + if schema != "" { + newSchema = r.wrap.Quote(schema) + } + + return fmt.Sprintf( + "select idx.name as name, string_agg(col.name, ',') within group (order by idxcol.key_ordinal) as columns, "+ + "idx.type_desc as [type], idx.is_unique as [unique], idx.is_primary_key as [primary] "+ + "from sys.indexes as idx "+ + "join sys.tables as tbl on idx.object_id = tbl.object_id "+ + "join sys.schemas as scm on tbl.schema_id = scm.schema_id "+ + "join sys.index_columns as idxcol on idx.object_id = idxcol.object_id and idx.index_id = idxcol.index_id "+ + "join sys.columns as col on idxcol.object_id = col.object_id and idxcol.column_id = col.column_id "+ + "where tbl.name = %s and scm.name = %s "+ + "group by idx.name, idx.type_desc, idx.is_unique, idx.is_primary_key", + r.wrap.Quote(table), + newSchema, + ) +} + +func (r *Sqlserver) CompilePrimary(blueprint schema.Blueprint, command *schema.Command) string { + return fmt.Sprintf("alter table %s add constraint %s primary key (%s)", + r.wrap.Table(blueprint.GetTableName()), + r.wrap.Column(command.Index), + r.wrap.Columnize(command.Columns)) +} + +func (r *Sqlserver) CompileTables(database string) string { + return "select t.name as name, schema_name(t.schema_id) as [schema], sum(u.total_pages) * 8 * 1024 as size " + + "from sys.tables as t " + + "join sys.partitions as p on p.object_id = t.object_id " + + "join sys.allocation_units as u on u.container_id = p.hobt_id " + + "group by t.name, t.schema_id " + + "order by t.name" +} + +func (r *Sqlserver) CompileTypes() string { + return "" +} + +func (r *Sqlserver) CompileViews(database string) string { + return "select name, schema_name(v.schema_id) as [schema], definition from sys.views as v " + + "inner join sys.sql_modules as m on v.object_id = m.object_id " + + "order by name" +} + +func (r *Sqlserver) GetAttributeCommands() []string { + return r.attributeCommands +} + +func (r *Sqlserver) ModifyDefault(blueprint schema.Blueprint, column schema.ColumnDefinition) string { + if column.GetDefault() != nil { + return fmt.Sprintf(" default %s", getDefaultValue(column.GetDefault())) + } + + return "" +} + +func (r *Sqlserver) ModifyNullable(blueprint schema.Blueprint, column schema.ColumnDefinition) string { + if column.GetNullable() { + return " null" + } else { + return " not null" + } +} + +func (r *Sqlserver) ModifyIncrement(blueprint schema.Blueprint, column schema.ColumnDefinition) string { + if slices.Contains(r.serials, column.GetType()) && column.GetAutoIncrement() { + if blueprint.HasCommand("primary") { + return " identity" + } + return " identity primary key" + } + + return "" +} + +func (r *Sqlserver) TypeBigInteger(column schema.ColumnDefinition) string { + return "bigint" +} + +func (r *Sqlserver) TypeInteger(column schema.ColumnDefinition) string { + return "int" +} + +func (r *Sqlserver) TypeString(column schema.ColumnDefinition) string { + length := column.GetLength() + if length > 0 { + return fmt.Sprintf("nvarchar(%d)", length) + } + + return "nvarchar(255)" +} + +func (r *Sqlserver) getColumns(blueprint schema.Blueprint) []string { + var columns []string + for _, column := range blueprint.GetAddedColumns() { + columns = append(columns, r.getColumn(blueprint, column)) + } + + return columns +} + +func (r *Sqlserver) getColumn(blueprint schema.Blueprint, column schema.ColumnDefinition) string { + sql := fmt.Sprintf("%s %s", r.wrap.Column(column.GetName()), getType(r, column)) + + for _, modifier := range r.modifiers { + sql += modifier(blueprint, column) + } + + return sql +} diff --git a/database/schema/grammars/sqlserver_test.go b/database/schema/grammars/sqlserver_test.go new file mode 100644 index 000000000..298af4a9a --- /dev/null +++ b/database/schema/grammars/sqlserver_test.go @@ -0,0 +1,289 @@ +package grammars + +import ( + "testing" + + "github.com/stretchr/testify/suite" + + contractsschema "github.com/goravel/framework/contracts/database/schema" + mocksschema "github.com/goravel/framework/mocks/database/schema" +) + +type SqlserverSuite struct { + suite.Suite + grammar *Sqlserver +} + +func TestSqlserverSuite(t *testing.T) { + suite.Run(t, &SqlserverSuite{}) +} + +func (s *SqlserverSuite) SetupTest() { + s.grammar = NewSqlserver("goravel_") +} + +func (s *SqlserverSuite) TestCompileAdd() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + mockColumn := mocksschema.NewColumnDefinition(s.T()) + + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + mockColumn.EXPECT().GetName().Return("name").Once() + mockColumn.EXPECT().GetType().Return("string").Twice() + mockColumn.EXPECT().GetDefault().Return("goravel").Twice() + mockColumn.EXPECT().GetNullable().Return(false).Once() + mockColumn.EXPECT().GetLength().Return(1).Once() + + sql := s.grammar.CompileAdd(mockBlueprint, &contractsschema.Command{ + Column: mockColumn, + }) + + s.Equal(`alter table "goravel_users" add "name" nvarchar(1) default 'goravel' not null`, sql) +} + +func (s *SqlserverSuite) TestCompileCreate() { + mockColumn1 := mocksschema.NewColumnDefinition(s.T()) + mockColumn2 := mocksschema.NewColumnDefinition(s.T()) + mockBlueprint := mocksschema.NewBlueprint(s.T()) + + // postgres.go::CompileCreate + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + // utils.go::getColumns + mockBlueprint.EXPECT().GetAddedColumns().Return([]contractsschema.ColumnDefinition{ + mockColumn1, mockColumn2, + }).Once() + // utils.go::getColumns + mockColumn1.EXPECT().GetName().Return("id").Once() + // utils.go::getType + mockColumn1.EXPECT().GetType().Return("integer").Once() + // postgres.go::TypeInteger + mockColumn1.EXPECT().GetAutoIncrement().Return(true).Once() + // postgres.go::ModifyDefault + mockColumn1.EXPECT().GetDefault().Return(nil).Once() + // postgres.go::ModifyIncrement + mockBlueprint.EXPECT().HasCommand("primary").Return(false).Once() + mockColumn1.EXPECT().GetType().Return("integer").Once() + // postgres.go::ModifyNullable + mockColumn1.EXPECT().GetNullable().Return(false).Once() + + // utils.go::getColumns + mockColumn2.EXPECT().GetName().Return("name").Once() + // utils.go::getType + mockColumn2.EXPECT().GetType().Return("string").Once() + // postgres.go::TypeString + mockColumn2.EXPECT().GetLength().Return(100).Once() + // postgres.go::ModifyDefault + mockColumn2.EXPECT().GetDefault().Return(nil).Once() + // postgres.go::ModifyIncrement + mockColumn2.EXPECT().GetType().Return("string").Once() + // postgres.go::ModifyNullable + mockColumn2.EXPECT().GetNullable().Return(true).Once() + + s.Equal(`create table "goravel_users" ("id" int identity primary key not null, "name" nvarchar(100) null)`, + s.grammar.CompileCreate(mockBlueprint)) +} + +func (s *SqlserverSuite) TestCompileDropIfExists() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + + s.Equal(`if object_id('"goravel_users"', 'U') is not null drop table "goravel_users"`, s.grammar.CompileDropIfExists(mockBlueprint)) +} + +func (s *SqlserverSuite) TestCompileForeign() { + var mockBlueprint *mocksschema.Blueprint + + beforeEach := func() { + mockBlueprint = mocksschema.NewBlueprint(s.T()) + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + } + + tests := []struct { + name string + command *contractsschema.Command + expectSql string + }{ + { + name: "with on delete and on update", + command: &contractsschema.Command{ + Index: "fk_users_role_id", + Columns: []string{"role_id", "user_id"}, + On: "roles", + References: []string{"id", "user_id"}, + OnDelete: "cascade", + OnUpdate: "restrict", + }, + expectSql: `alter table "goravel_users" add constraint "fk_users_role_id" foreign key ("role_id", "user_id") references "goravel_roles" ("id", "user_id") on delete cascade on update restrict`, + }, + { + name: "without on delete and on update", + command: &contractsschema.Command{ + Index: "fk_users_role_id", + Columns: []string{"role_id", "user_id"}, + On: "roles", + References: []string{"id", "user_id"}, + }, + expectSql: `alter table "goravel_users" add constraint "fk_users_role_id" foreign key ("role_id", "user_id") references "goravel_roles" ("id", "user_id")`, + }, + } + + for _, test := range tests { + s.Run(test.name, func() { + beforeEach() + + sql := s.grammar.CompileForeign(mockBlueprint, test.command) + s.Equal(test.expectSql, sql) + }) + } +} + +func (s *SqlserverSuite) TestCompileIndex() { + var mockBlueprint *mocksschema.Blueprint + + beforeEach := func() { + mockBlueprint = mocksschema.NewBlueprint(s.T()) + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + } + + tests := []struct { + name string + command *contractsschema.Command + expectSql string + }{ + { + name: "with Algorithm", + command: &contractsschema.Command{ + Index: "fk_users_role_id", + Columns: []string{"role_id", "user_id"}, + Algorithm: "btree", + }, + expectSql: `create index "fk_users_role_id" on "goravel_users" ("role_id", "user_id")`, + }, + { + name: "without Algorithm", + command: &contractsschema.Command{ + Index: "fk_users_role_id", + Columns: []string{"role_id", "user_id"}, + }, + expectSql: `create index "fk_users_role_id" on "goravel_users" ("role_id", "user_id")`, + }, + } + + for _, test := range tests { + s.Run(test.name, func() { + beforeEach() + + sql := s.grammar.CompileIndex(mockBlueprint, test.command) + s.Equal(test.expectSql, sql) + }) + } +} + +func (s *SqlserverSuite) TestCompilePrimary() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + mockBlueprint.EXPECT().GetTableName().Return("users").Once() + + s.Equal(`alter table "goravel_users" add constraint "role" primary key ("role_id", "user_id")`, s.grammar.CompilePrimary(mockBlueprint, &contractsschema.Command{ + Columns: []string{"role_id", "user_id"}, + Index: "role", + })) +} + +func (s *SqlserverSuite) TestGetColumns() { + mockColumn1 := mocksschema.NewColumnDefinition(s.T()) + mockColumn2 := mocksschema.NewColumnDefinition(s.T()) + mockBlueprint := mocksschema.NewBlueprint(s.T()) + + mockBlueprint.EXPECT().GetAddedColumns().Return([]contractsschema.ColumnDefinition{ + mockColumn1, mockColumn2, + }).Once() + mockBlueprint.EXPECT().HasCommand("primary").Return(false).Once() + + mockColumn1.EXPECT().GetName().Return("id").Once() + mockColumn1.EXPECT().GetType().Return("integer").Twice() + mockColumn1.EXPECT().GetDefault().Return(nil).Once() + mockColumn1.EXPECT().GetNullable().Return(false).Once() + mockColumn1.EXPECT().GetAutoIncrement().Return(true).Once() + + mockColumn2.EXPECT().GetName().Return("name").Once() + mockColumn2.EXPECT().GetType().Return("string").Twice() + mockColumn2.EXPECT().GetDefault().Return("goravel").Twice() + mockColumn2.EXPECT().GetNullable().Return(true).Once() + mockColumn2.EXPECT().GetLength().Return(10).Once() + + s.Equal([]string{`"id" int identity primary key not null`, `"name" nvarchar(10) default 'goravel' null`}, s.grammar.getColumns(mockBlueprint)) +} + +func (s *SqlserverSuite) TestModifyDefault() { + var ( + mockBlueprint *mocksschema.Blueprint + mockColumn *mocksschema.ColumnDefinition + ) + + tests := []struct { + name string + setup func() + expectSql string + }{ + { + name: "without change and default is nil", + setup: func() { + mockColumn.EXPECT().GetDefault().Return(nil).Once() + }, + }, + { + name: "without change and default is not nil", + setup: func() { + mockColumn.EXPECT().GetDefault().Return("goravel").Twice() + }, + expectSql: " default 'goravel'", + }, + } + + for _, test := range tests { + s.Run(test.name, func() { + mockBlueprint = mocksschema.NewBlueprint(s.T()) + mockColumn = mocksschema.NewColumnDefinition(s.T()) + + test.setup() + + sql := s.grammar.ModifyDefault(mockBlueprint, mockColumn) + + s.Equal(test.expectSql, sql) + }) + } +} + +func (s *SqlserverSuite) TestModifyNullable() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + mockColumn := mocksschema.NewColumnDefinition(s.T()) + mockColumn.EXPECT().GetNullable().Return(true).Once() + + s.Equal(" null", s.grammar.ModifyNullable(mockBlueprint, mockColumn)) + + mockColumn.EXPECT().GetNullable().Return(false).Once() + + s.Equal(" not null", s.grammar.ModifyNullable(mockBlueprint, mockColumn)) +} + +func (s *SqlserverSuite) TestModifyIncrement() { + mockBlueprint := mocksschema.NewBlueprint(s.T()) + + mockColumn := mocksschema.NewColumnDefinition(s.T()) + mockBlueprint.EXPECT().HasCommand("primary").Return(false).Once() + mockColumn.EXPECT().GetType().Return("bigInteger").Once() + mockColumn.EXPECT().GetAutoIncrement().Return(true).Once() + + s.Equal(" identity primary key", s.grammar.ModifyIncrement(mockBlueprint, mockColumn)) +} + +func (s *SqlserverSuite) TestTypeString() { + mockColumn1 := mocksschema.NewColumnDefinition(s.T()) + mockColumn1.EXPECT().GetLength().Return(100).Once() + + s.Equal("nvarchar(100)", s.grammar.TypeString(mockColumn1)) + + mockColumn2 := mocksschema.NewColumnDefinition(s.T()) + mockColumn2.EXPECT().GetLength().Return(0).Once() + + s.Equal("nvarchar(255)", s.grammar.TypeString(mockColumn2)) +} diff --git a/database/schema/mysql_schema.go b/database/schema/mysql_schema.go index 6259d7bd1..1ed4e23ee 100644 --- a/database/schema/mysql_schema.go +++ b/database/schema/mysql_schema.go @@ -36,23 +36,25 @@ func (r *MysqlSchema) DropAllTables() error { return nil } - if _, err = r.orm.Query().Exec(r.grammar.CompileDisableForeignKeyConstraints()); err != nil { - return err - } + return r.orm.Transaction(func(tx orm.Query) error { + if _, err = tx.Exec(r.grammar.CompileDisableForeignKeyConstraints()); err != nil { + return err + } + + var dropTables []string + for _, table := range tables { + dropTables = append(dropTables, table.Name) + } + if _, err = tx.Exec(r.grammar.CompileDropAllTables(dropTables)); err != nil { + return err + } + + if _, err = tx.Exec(r.grammar.CompileEnableForeignKeyConstraints()); err != nil { + return err + } - var dropTables []string - for _, table := range tables { - dropTables = append(dropTables, table.Name) - } - if _, err = r.orm.Query().Exec(r.grammar.CompileDropAllTables(dropTables)); err != nil { return err - } - - if _, err = r.orm.Query().Exec(r.grammar.CompileEnableForeignKeyConstraints()); err != nil { - return err - } - - return err + }) } func (r *MysqlSchema) DropAllTypes() error { diff --git a/database/schema/postgres_schema.go b/database/schema/postgres_schema.go index d2b2f4b53..763fecb40 100644 --- a/database/schema/postgres_schema.go +++ b/database/schema/postgres_schema.go @@ -2,13 +2,11 @@ package schema import ( "fmt" - "slices" - "strings" - "github.com/goravel/framework/contracts/database/orm" contractsschema "github.com/goravel/framework/contracts/database/schema" "github.com/goravel/framework/database/schema/grammars" "github.com/goravel/framework/database/schema/processors" + "slices" ) type PostgresSchema struct { @@ -81,19 +79,21 @@ func (r *PostgresSchema) DropAllTypes() error { } } - if len(dropTypes) > 0 { - if _, err := r.orm.Query().Exec(r.grammar.CompileDropAllTypes(dropTypes)); err != nil { - return err + return r.orm.Transaction(func(tx orm.Query) error { + if len(dropTypes) > 0 { + if _, err := tx.Exec(r.grammar.CompileDropAllTypes(dropTypes)); err != nil { + return err + } } - } - if len(dropDomains) > 0 { - if _, err := r.orm.Query().Exec(r.grammar.CompileDropAllDomains(dropDomains)); err != nil { - return err + if len(dropDomains) > 0 { + if _, err := tx.Exec(r.grammar.CompileDropAllDomains(dropDomains)); err != nil { + return err + } } - } - return nil + return nil + }) } func (r *PostgresSchema) DropAllViews() error { @@ -121,7 +121,11 @@ func (r *PostgresSchema) DropAllViews() error { } func (r *PostgresSchema) GetIndexes(table string) ([]contractsschema.Index, error) { - schema, table := r.parseSchemaAndTable(table) + schema, table, err := parseSchemaAndTable(table, r.schema) + if err != nil { + return nil, err + } + table = r.prefix + table var dbIndexes []processors.DBIndex @@ -140,16 +144,3 @@ func (r *PostgresSchema) GetTypes() ([]contractsschema.Type, error) { return r.processor.ProcessTypes(types), nil } - -func (r *PostgresSchema) parseSchemaAndTable(reference string) (schema, table string) { - parts := strings.Split(reference, ".") - schema = r.schema - if len(parts) == 2 { - schema = parts[0] - parts = parts[1:] - } - - table = parts[0] - - return -} diff --git a/database/schema/postgres_schema_test.go b/database/schema/postgres_schema_test.go deleted file mode 100644 index 83090a4db..000000000 --- a/database/schema/postgres_schema_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package schema - -import ( - "testing" - - "github.com/stretchr/testify/suite" - - "github.com/goravel/framework/database/gorm" - "github.com/goravel/framework/database/schema/grammars" - mocksorm "github.com/goravel/framework/mocks/database/orm" - "github.com/goravel/framework/support/docker" - "github.com/goravel/framework/support/env" -) - -type PostgresSchemaSuite struct { - suite.Suite - mockOrm *mocksorm.Orm - postgresSchema *PostgresSchema - testQuery *gorm.TestQuery -} - -func TestPostgresSchemaSuite(t *testing.T) { - if env.IsWindows() { - t.Skip("Skip test that using Docker") - } - - suite.Run(t, &PostgresSchemaSuite{}) -} - -func (s *PostgresSchemaSuite) SetupTest() { - postgresDocker := docker.Postgres() - s.Require().NoError(postgresDocker.Ready()) - - s.testQuery = gorm.NewTestQuery(postgresDocker, true) - s.mockOrm = mocksorm.NewOrm(s.T()) - s.postgresSchema = NewPostgresSchema(grammars.NewPostgres("goravel_"), s.mockOrm, "goravel", "goravel_") -} - -// TODO Implement this after implementing create type -func (s *PostgresSchemaSuite) TestGetTypes() { - -} - -func (s *PostgresSchemaSuite) TestParseSchemaAndTable() { - tests := []struct { - reference string - expectedSchema string - expectedTable string - }{ - {"public.users", "public", "users"}, - {"users", "goravel", "users"}, - } - - for _, test := range tests { - schema, table := s.postgresSchema.parseSchemaAndTable(test.reference) - s.Equal(test.expectedSchema, schema) - s.Equal(test.expectedTable, table) - } -} diff --git a/database/schema/processors/mysql.go b/database/schema/processors/mysql.go index 22ee8ed0d..8b67d4a5c 100644 --- a/database/schema/processors/mysql.go +++ b/database/schema/processors/mysql.go @@ -1,8 +1,6 @@ package processors import ( - "strings" - "github.com/goravel/framework/contracts/database/schema" ) @@ -14,16 +12,5 @@ func NewMysql() Mysql { } func (r Mysql) ProcessIndexes(dbIndexes []DBIndex) []schema.Index { - var indexes []schema.Index - for _, dbIndex := range dbIndexes { - indexes = append(indexes, schema.Index{ - Columns: strings.Split(dbIndex.Columns, ","), - Name: strings.ToLower(dbIndex.Name), - Type: strings.ToLower(dbIndex.Type), - Primary: dbIndex.Primary, - Unique: dbIndex.Unique, - }) - } - - return indexes + return processIndexes(dbIndexes) } diff --git a/database/schema/processors/postgres.go b/database/schema/processors/postgres.go index 4cf344763..dbd26cf59 100644 --- a/database/schema/processors/postgres.go +++ b/database/schema/processors/postgres.go @@ -1,8 +1,6 @@ package processors import ( - "strings" - "github.com/goravel/framework/contracts/database/schema" ) @@ -14,18 +12,7 @@ func NewPostgres() Postgres { } func (r Postgres) ProcessIndexes(dbIndexes []DBIndex) []schema.Index { - var indexes []schema.Index - for _, dbIndex := range dbIndexes { - indexes = append(indexes, schema.Index{ - Columns: strings.Split(dbIndex.Columns, ","), - Name: strings.ToLower(dbIndex.Name), - Type: strings.ToLower(dbIndex.Type), - Primary: dbIndex.Primary, - Unique: dbIndex.Unique, - }) - } - - return indexes + return processIndexes(dbIndexes) } func (r Postgres) ProcessTypes(types []schema.Type) []schema.Type { diff --git a/database/schema/processors/postgres_test.go b/database/schema/processors/postgres_test.go index 97380c76c..9dfb21539 100644 --- a/database/schema/processors/postgres_test.go +++ b/database/schema/processors/postgres_test.go @@ -8,30 +8,6 @@ import ( "github.com/goravel/framework/contracts/database/schema" ) -func TestPostgresProcessIndexes(t *testing.T) { - // Test with valid indexes - input := []DBIndex{ - {Name: "INDEX_A", Type: "BTREE", Columns: "a,b"}, - {Name: "INDEX_B", Type: "HASH", Columns: "c,d"}, - } - expected := []schema.Index{ - {Name: "index_a", Type: "btree", Columns: []string{"a", "b"}}, - {Name: "index_b", Type: "hash", Columns: []string{"c", "d"}}, - } - - postgres := NewPostgres() - result := postgres.ProcessIndexes(input) - - assert.Equal(t, expected, result) - - // Test with empty input - input = []DBIndex{} - - result = postgres.ProcessIndexes(input) - - assert.Nil(t, result) -} - func TestPostgresProcessTypes(t *testing.T) { // ValidTypes_ReturnsProcessedTypes input := []schema.Type{ diff --git a/database/schema/processors/sqlserver.go b/database/schema/processors/sqlserver.go new file mode 100644 index 000000000..51996c273 --- /dev/null +++ b/database/schema/processors/sqlserver.go @@ -0,0 +1,16 @@ +package processors + +import ( + "github.com/goravel/framework/contracts/database/schema" +) + +type Sqlserver struct { +} + +func NewSqlserver() Sqlserver { + return Sqlserver{} +} + +func (r Sqlserver) ProcessIndexes(dbIndexes []DBIndex) []schema.Index { + return processIndexes(dbIndexes) +} diff --git a/database/schema/processors/utils.go b/database/schema/processors/utils.go new file mode 100644 index 000000000..e300535f7 --- /dev/null +++ b/database/schema/processors/utils.go @@ -0,0 +1,22 @@ +package processors + +import ( + "strings" + + "github.com/goravel/framework/contracts/database/schema" +) + +func processIndexes(dbIndexes []DBIndex) []schema.Index { + var indexes []schema.Index + for _, dbIndex := range dbIndexes { + indexes = append(indexes, schema.Index{ + Columns: strings.Split(dbIndex.Columns, ","), + Name: strings.ToLower(dbIndex.Name), + Type: strings.ToLower(dbIndex.Type), + Primary: dbIndex.Primary, + Unique: dbIndex.Unique, + }) + } + + return indexes +} diff --git a/database/schema/processors/mysql_test.go b/database/schema/processors/utils_test.go similarity index 79% rename from database/schema/processors/mysql_test.go rename to database/schema/processors/utils_test.go index 97dab50bb..a55fedbfa 100644 --- a/database/schema/processors/mysql_test.go +++ b/database/schema/processors/utils_test.go @@ -8,7 +8,7 @@ import ( "github.com/goravel/framework/contracts/database/schema" ) -func TestMysqlProcessIndexes(t *testing.T) { +func TestProcessIndexes(t *testing.T) { // Test with valid indexes input := []DBIndex{ {Name: "INDEX_A", Type: "BTREE", Columns: "a,b"}, @@ -19,15 +19,14 @@ func TestMysqlProcessIndexes(t *testing.T) { {Name: "index_b", Type: "hash", Columns: []string{"c", "d"}}, } - mysql := NewMysql() - result := mysql.ProcessIndexes(input) + result := processIndexes(input) assert.Equal(t, expected, result) // Test with empty input input = []DBIndex{} - result = mysql.ProcessIndexes(input) + result = processIndexes(input) assert.Nil(t, result) } diff --git a/database/schema/schema.go b/database/schema/schema.go index a3a92d5e0..0837ab5aa 100644 --- a/database/schema/schema.go +++ b/database/schema/schema.go @@ -49,7 +49,9 @@ func NewSchema(config config.Config, log log.Log, orm contractsorm.Orm, migratio driverSchema = NewMysqlSchema(mysqlGrammar, orm, prefix) grammar = mysqlGrammar case contractsdatabase.DriverSqlserver: - // TODO Optimize here when implementing Sqlserver driver + sqlserverGrammar := grammars.NewSqlserver(prefix) + driverSchema = NewSqlserverSchema(sqlserverGrammar, orm, prefix) + grammar = sqlserverGrammar case contractsdatabase.DriverSqlite: sqliteGrammar := grammars.NewSqlite(prefix) driverSchema = NewSqliteSchema(sqliteGrammar, orm, prefix) diff --git a/database/schema/schema_test.go b/database/schema/schema_test.go index 4f59e0372..6929b2b44 100644 --- a/database/schema/schema_test.go +++ b/database/schema/schema_test.go @@ -26,7 +26,6 @@ func TestSchemaSuite(t *testing.T) { } func (s *SchemaSuite) SetupTest() { - // TODO Add other drivers postgresDocker := docker.Postgres() s.Require().NoError(postgresDocker.Ready()) @@ -40,10 +39,16 @@ func (s *SchemaSuite) SetupTest() { mysqlQuery := gorm.NewTestQuery(mysqlDocker, true) + sqlserverDocker := docker.Sqlserver() + s.Require().NoError(sqlserverDocker.Ready()) + + sqlserverQuery := gorm.NewTestQuery(sqlserverDocker, true) + s.driverToTestQuery = map[database.Driver]*gorm.TestQuery{ - database.DriverPostgres: postgresQuery, - database.DriverSqlite: sqliteQuery, - database.DriverMysql: mysqlQuery, + database.DriverPostgres: postgresQuery, + database.DriverSqlite: sqliteQuery, + database.DriverMysql: mysqlQuery, + database.DriverSqlserver: sqlserverQuery, } } @@ -155,6 +160,9 @@ func (s *SchemaSuite) TestPrimary() { if driver == database.DriverMysql { s.Require().True(schema.HasIndex(table, "primary")) } + if driver == database.DriverSqlserver { + s.Require().True(schema.HasIndex(table, "goravel_primaries_name_age_primary")) + } }) } } @@ -187,6 +195,8 @@ func (s *SchemaSuite) TestIndexMethods() { s.False(index.Primary) if driver == database.DriverSqlite { s.Empty(index.Type) + } else if driver == database.DriverSqlserver { + s.Equal("nonclustered", index.Type) } else { s.Equal("btree", index.Type) } diff --git a/database/schema/sqlite_schema.go b/database/schema/sqlite_schema.go index d02536c18..f08b16394 100644 --- a/database/schema/sqlite_schema.go +++ b/database/schema/sqlite_schema.go @@ -48,20 +48,22 @@ func (r *SqliteSchema) DropAllTypes() error { } func (r *SqliteSchema) DropAllViews() error { - if _, err := r.orm.Query().Exec(r.grammar.CompileEnableWriteableSchema()); err != nil { - return err - } - if _, err := r.orm.Query().Exec(r.grammar.CompileDropAllViews(nil)); err != nil { - return err - } - if _, err := r.orm.Query().Exec(r.grammar.CompileDisableWriteableSchema()); err != nil { - return err - } - if _, err := r.orm.Query().Exec(r.grammar.CompileRebuild()); err != nil { - return err - } + return r.orm.Transaction(func(tx orm.Query) error { + if _, err := tx.Exec(r.grammar.CompileEnableWriteableSchema()); err != nil { + return err + } + if _, err := tx.Exec(r.grammar.CompileDropAllViews(nil)); err != nil { + return err + } + if _, err := tx.Exec(r.grammar.CompileDisableWriteableSchema()); err != nil { + return err + } + if _, err := tx.Exec(r.grammar.CompileRebuild()); err != nil { + return err + } - return nil + return nil + }) } func (r *SqliteSchema) GetIndexes(table string) ([]schema.Index, error) { diff --git a/database/schema/sqlserver_schema.go b/database/schema/sqlserver_schema.go new file mode 100644 index 000000000..bfd8c1ee4 --- /dev/null +++ b/database/schema/sqlserver_schema.go @@ -0,0 +1,69 @@ +package schema + +import ( + "github.com/goravel/framework/contracts/database/orm" + contractsschema "github.com/goravel/framework/contracts/database/schema" + "github.com/goravel/framework/database/schema/grammars" + "github.com/goravel/framework/database/schema/processors" +) + +type SqlserverSchema struct { + contractsschema.CommonSchema + + grammar *grammars.Sqlserver + orm orm.Orm + prefix string + processor processors.Sqlserver +} + +func NewSqlserverSchema(grammar *grammars.Sqlserver, orm orm.Orm, prefix string) *SqlserverSchema { + return &SqlserverSchema{ + CommonSchema: NewCommonSchema(grammar, orm), + grammar: grammar, + orm: orm, + prefix: prefix, + processor: processors.NewSqlserver(), + } +} + +func (r *SqlserverSchema) DropAllTables() error { + if _, err := r.orm.Query().Exec(r.grammar.CompileDropAllForeignKeys()); err != nil { + return err + } + + if _, err := r.orm.Query().Exec(r.grammar.CompileDropAllTables(nil)); err != nil { + return err + } + + return nil +} + +func (r *SqlserverSchema) DropAllTypes() error { + return nil +} + +func (r *SqlserverSchema) DropAllViews() error { + _, err := r.orm.Query().Exec(r.grammar.CompileDropAllViews(nil)) + + return err +} + +func (r *SqlserverSchema) GetIndexes(table string) ([]contractsschema.Index, error) { + schema, table, err := parseSchemaAndTable(table, "") + if err != nil { + return nil, err + } + + table = r.prefix + table + + var dbIndexes []processors.DBIndex + if err := r.orm.Query().Raw(r.grammar.CompileIndexes(schema, table)).Scan(&dbIndexes); err != nil { + return nil, err + } + + return r.processor.ProcessIndexes(dbIndexes), nil +} + +func (r *SqlserverSchema) GetTypes() ([]contractsschema.Type, error) { + return nil, nil +} diff --git a/database/schema/utils.go b/database/schema/utils.go new file mode 100644 index 000000000..a2911d6eb --- /dev/null +++ b/database/schema/utils.go @@ -0,0 +1,28 @@ +package schema + +import ( + "strings" + + "github.com/goravel/framework/errors" +) + +func parseSchemaAndTable(reference, defaultSchema string) (string, string, error) { + if reference == "" { + return "", "", errors.SchemaEmptyReferenceString + } + + parts := strings.Split(reference, ".") + if len(parts) > 2 { + return "", "", errors.SchemaErrorReferenceFormat + } + + schema := defaultSchema + if len(parts) == 2 { + schema = parts[0] + parts = parts[1:] + } + + table := parts[0] + + return schema, table, nil +} diff --git a/database/schema/utils_test.go b/database/schema/utils_test.go new file mode 100644 index 000000000..2db4c4ef5 --- /dev/null +++ b/database/schema/utils_test.go @@ -0,0 +1,30 @@ +package schema + +import ( + "testing" + + "github.com/goravel/framework/errors" + "github.com/stretchr/testify/assert" +) + +func TestParseSchemaAndTable(t *testing.T) { + tests := []struct { + reference string + defaultSchema string + expectedSchema string + expectedTable string + expectedError error + }{ + {"public.users", "public", "public", "users", nil}, + {"users", "goravel", "goravel", "users", nil}, + {"", "", "", "", errors.SchemaEmptyReferenceString}, + {"public.users.extra", "", "", "", errors.SchemaErrorReferenceFormat}, + } + + for _, test := range tests { + schema, table, err := parseSchemaAndTable(test.reference, test.defaultSchema) + assert.Equal(t, test.expectedSchema, schema) + assert.Equal(t, test.expectedTable, table) + assert.Equal(t, test.expectedError, err) + } +} diff --git a/errors/list.go b/errors/list.go index f1b7065db..bddd2d38b 100644 --- a/errors/list.go +++ b/errors/list.go @@ -112,10 +112,12 @@ var ( RouteDefaultDriverNotSet = New("please set default driver") RouteInvalidDriver = New("init %s route driver fail: route must be implement route.Route or func() (route.Route, error)") - SchemaDriverNotSupported = New("driver %s is not supported") - SchemaFailedToCreateTable = New("failed to create %s table: %v") - SchemaFailedToDropTable = New("failed to drop %s table: %v") - SchemaFailedToGetTables = New("failed to get %s tables: %v") + SchemaDriverNotSupported = New("driver %s is not supported") + SchemaFailedToCreateTable = New("failed to create %s table: %v") + SchemaFailedToDropTable = New("failed to drop %s table: %v") + SchemaFailedToGetTables = New("failed to get %s tables: %v") + SchemaEmptyReferenceString = New("reference string can't be empty") + SchemaErrorReferenceFormat = New("invalid format: too many dots in reference") SessionDriverAlreadyExists = New("session driver [%s] already exists") SessionDriverExtensionFailed = New("session failed to extend session [%s] driver [%v]")