Skip to content

Commit

Permalink
feat: [#280] Implement Mysql driver (#725)
Browse files Browse the repository at this point in the history
* feat: [#280] Implement Mysql driver

* Add unit tests

* fix test

* fix AI comments

* add ready method
  • Loading branch information
hwbrzzl authored Nov 17, 2024
1 parent 197dd26 commit 9267b55
Show file tree
Hide file tree
Showing 22 changed files with 1,000 additions and 70 deletions.
2 changes: 2 additions & 0 deletions contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ type Orm interface {
DB() (*sql.DB, error)
// Factory gets a new factory instance for the given model name.
Factory() Factory
// DatabaseName gets the current database name.
DatabaseName() string
// Name gets the current connection name.
Name() string
// Observe registers an observer with the Orm.
Expand Down
10 changes: 8 additions & 2 deletions contracts/database/schema/blueprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ import (
)

type Blueprint interface {
// BigIncrements Create a new auto-incrementing big integer (8-byte) column on the table.
BigIncrements(column string) ColumnDefinition
// BigInteger Create a new big integer (8-byte) column on the table.
BigInteger(column string) ColumnDefinition
// Build Execute the blueprint to build / modify the table.
Build(query orm.Query, grammar Grammar) error
// Create Indicate that the table needs to be created.
Expand All @@ -21,20 +25,22 @@ type Blueprint interface {
GetTableName() string
// HasCommand Determine if the blueprint has a specific command.
HasCommand(command string) bool
// Primary Specify the primary key(s) for the table.
Primary(column ...string)
// ID Create a new auto-incrementing big integer (8-byte) column on the table.
ID(column ...string) ColumnDefinition
// Index Specify an index for the table.
Index(column ...string) IndexDefinition
// Integer Create a new integer (4-byte) column on the table.
Integer(column string) ColumnDefinition
// Primary Specify the primary key(s) for the table.
Primary(column ...string)
// SetTable Set the table that the blueprint operates on.
SetTable(name string)
// String Create a new string column on the table.
String(column string, length ...int) ColumnDefinition
// ToSql Get the raw SQL statements for the blueprint.
ToSql(grammar Grammar) []string
// UnsignedBigInteger Create a new unsigned big integer (8-byte) column on the table.
UnsignedBigInteger(column string) ColumnDefinition
}

type IndexConfig struct {
Expand Down
4 changes: 2 additions & 2 deletions contracts/database/schema/grammar.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ type Grammar interface {
// CompilePrimary Compile a primary key command.
CompilePrimary(blueprint Blueprint, command *Command) string
// CompileTables Compile the query to determine the tables.
CompileTables() string
CompileTables(database string) string
// CompileTypes Compile the query to determine the types.
CompileTypes() string
// CompileViews Compile the query to determine the views.
CompileViews() string
CompileViews(database string) string
// GetAttributeCommands Get the commands for the schema build.
GetAttributeCommands() []string
// TypeBigInteger Create the column definition for a big integer type.
Expand Down
25 changes: 13 additions & 12 deletions contracts/database/schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,19 @@ type Connection interface {
}

type Command struct {
Algorithm string
Column ColumnDefinition
Columns []string
From string
Index string
On string
OnDelete string
OnUpdate string
Name string
To string
References []string
Value string
Algorithm string
Column ColumnDefinition
Columns []string
From string
Index string
On string
OnDelete string
OnUpdate string
Name string
To string
References []string
ShouldBeSkipped bool
Value string
}

type Index struct {
Expand Down
5 changes: 5 additions & 0 deletions database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package orm
import (
"context"
"database/sql"
"fmt"
"sync"

"github.com/goravel/framework/contracts/config"
Expand Down Expand Up @@ -89,6 +90,10 @@ func (r *Orm) Factory() contractsorm.Factory {
return factory.NewFactoryImpl(r.Query())
}

func (r *Orm) DatabaseName() string {
return r.config.GetString(fmt.Sprintf("database.connections.%s.database", r.connection))
}

func (r *Orm) Name() string {
return r.connection
}
Expand Down
4 changes: 4 additions & 0 deletions database/schema/blueprint.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ func (r *Blueprint) ToSql(grammar schema.Grammar) []string {

var statements []string
for _, command := range r.commands {
if command.ShouldBeSkipped {
continue
}

switch command.Name {
case constants.CommandAdd:
statements = append(statements, grammar.CompileAdd(r, command))
Expand Down
4 changes: 2 additions & 2 deletions database/schema/common_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func NewCommonSchema(grammar schema.Grammar, orm orm.Orm) *CommonSchema {

func (r *CommonSchema) GetTables() ([]schema.Table, error) {
var tables []schema.Table
if err := r.orm.Query().Raw(r.grammar.CompileTables()).Scan(&tables); err != nil {
if err := r.orm.Query().Raw(r.grammar.CompileTables(r.orm.DatabaseName())).Scan(&tables); err != nil {
return nil, err
}

Expand All @@ -28,7 +28,7 @@ func (r *CommonSchema) GetTables() ([]schema.Table, error) {

func (r *CommonSchema) GetViews() ([]schema.View, error) {
var views []schema.View
if err := r.orm.Query().Raw(r.grammar.CompileViews()).Scan(&views); err != nil {
if err := r.orm.Query().Raw(r.grammar.CompileViews(r.orm.DatabaseName())).Scan(&views); err != nil {
return nil, err
}

Expand Down
217 changes: 217 additions & 0 deletions database/schema/grammars/mysql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package grammars

import (
"fmt"
"slices"
"strings"

contractsdatabase "github.com/goravel/framework/contracts/database"
"github.com/goravel/framework/contracts/database/schema"
"github.com/goravel/framework/database/schema/constants"
)

type Mysql struct {
attributeCommands []string
modifiers []func(schema.Blueprint, schema.ColumnDefinition) string
serials []string
wrap *Wrap
}

func NewMysql(tablePrefix string) *Mysql {
mysql := &Mysql{
attributeCommands: []string{constants.CommandComment},
serials: []string{"bigInteger", "integer", "mediumInteger", "smallInteger", "tinyInteger"},
wrap: NewWrap(contractsdatabase.DriverMysql, tablePrefix),
}
mysql.modifiers = []func(schema.Blueprint, schema.ColumnDefinition) string{
mysql.ModifyDefault,
mysql.ModifyIncrement,
mysql.ModifyNullable,
}

return mysql
}

func (r *Mysql) CompileAdd(blueprint schema.Blueprint, command *schema.Command) string {
return fmt.Sprintf("alter table %s add %s", r.wrap.Table(blueprint.GetTableName()), r.getColumn(blueprint, command.Column))
}

func (r *Mysql) CompileCreate(blueprint schema.Blueprint) string {
columns := r.getColumns(blueprint)
primaryCommand := getCommandByName(blueprint.GetCommands(), "primary")
if primaryCommand != nil {
var algorithm string
if primaryCommand.Algorithm != "" {
algorithm = "using " + primaryCommand.Algorithm
}
columns = append(columns, fmt.Sprintf("primary key %s(%s)", algorithm, r.wrap.Columnize(primaryCommand.Columns)))

primaryCommand.ShouldBeSkipped = true
}

return fmt.Sprintf("create table %s (%s)", r.wrap.Table(blueprint.GetTableName()), strings.Join(columns, ", "))
}

func (r *Mysql) CompileDisableForeignKeyConstraints() string {
return "SET FOREIGN_KEY_CHECKS=0;"
}

func (r *Mysql) CompileDropAllDomains(domains []string) string {
return ""
}

func (r *Mysql) CompileDropAllTables(tables []string) string {
return fmt.Sprintf("drop table %s", r.wrap.Columnize(tables))
}

func (r *Mysql) CompileDropAllTypes(types []string) string {
return ""
}

func (r *Mysql) CompileDropAllViews(views []string) string {
return fmt.Sprintf("drop view %s", r.wrap.Columnize(views))
}

func (r *Mysql) CompileDropIfExists(blueprint schema.Blueprint) string {
return fmt.Sprintf("drop table if exists %s", r.wrap.Table(blueprint.GetTableName()))
}

func (r *Mysql) CompileEnableForeignKeyConstraints() string {
return "SET FOREIGN_KEY_CHECKS=1;"
}

func (r *Mysql) CompileForeign(blueprint schema.Blueprint, command *schema.Command) string {
sql := fmt.Sprintf("alter table %s add constraint %s foreign key (%s) references %s (%s)",
r.wrap.Table(blueprint.GetTableName()),
r.wrap.Column(command.Index),
r.wrap.Columnize(command.Columns),
r.wrap.Table(command.On),
r.wrap.Columnize(command.References))
if command.OnDelete != "" {
sql += " on delete " + command.OnDelete
}
if command.OnUpdate != "" {
sql += " on update " + command.OnUpdate
}

return sql
}

func (r *Mysql) CompileIndex(blueprint schema.Blueprint, command *schema.Command) string {
var algorithm string
if command.Algorithm != "" {
algorithm = " using " + command.Algorithm
}

return fmt.Sprintf("alter table %s add %s %s%s(%s)",
r.wrap.Table(blueprint.GetTableName()),
"index",
r.wrap.Column(command.Index),
algorithm,
r.wrap.Columnize(command.Columns),
)
}

func (r *Mysql) CompileIndexes(schema, table string) string {
return fmt.Sprintf(
"select index_name as `name`, group_concat(column_name order by seq_in_index) as `columns`, "+
"index_type as `type`, not non_unique as `unique` "+
"from information_schema.statistics where table_schema = %s and table_name = %s "+
"group by index_name, index_type, non_unique",
r.wrap.Quote(schema),
r.wrap.Quote(table),
)
}

func (r *Mysql) CompilePrimary(blueprint schema.Blueprint, command *schema.Command) string {
var algorithm string
if command.Algorithm != "" {
algorithm = "using " + command.Algorithm
}

return fmt.Sprintf("alter table %s add primary key %s(%s)", r.wrap.Table(blueprint.GetTableName()), algorithm, r.wrap.Columnize(command.Columns))
}

func (r *Mysql) CompileTables(database string) string {
return fmt.Sprintf("select table_name as `name`, (data_length + index_length) as `size`, "+
"table_comment as `comment`, engine as `engine`, table_collation as `collation` "+
"from information_schema.tables where table_schema = %s and table_type in ('BASE TABLE', 'SYSTEM VERSIONED') "+
"order by table_name", r.wrap.Quote(database))
}

func (r *Mysql) CompileTypes() string {
return ""
}

func (r *Mysql) CompileViews(database string) string {
return fmt.Sprintf("select table_name as `name`, view_definition as `definition` "+
"from information_schema.views where table_schema = %s "+
"order by table_name", r.wrap.Quote(database))
}

func (r *Mysql) GetAttributeCommands() []string {
return r.attributeCommands
}

func (r *Mysql) ModifyDefault(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
if column.GetDefault() != nil {
return fmt.Sprintf(" default %s", getDefaultValue(column.GetDefault()))
}

return ""
}

func (r *Mysql) ModifyNullable(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
if column.GetNullable() {
return " null"
} else {
return " not null"
}
}

func (r *Mysql) ModifyIncrement(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
if slices.Contains(r.serials, column.GetType()) && column.GetAutoIncrement() {
if blueprint.HasCommand("primary") {
return "auto_increment"
}
return " auto_increment primary key"
}

return ""
}

func (r *Mysql) TypeBigInteger(column schema.ColumnDefinition) string {
return "bigint"
}

func (r *Mysql) TypeInteger(column schema.ColumnDefinition) string {
return "int"
}

func (r *Mysql) TypeString(column schema.ColumnDefinition) string {
length := column.GetLength()
if length > 0 {
return fmt.Sprintf("varchar(%d)", length)
}

return "varchar(255)"
}

func (r *Mysql) getColumns(blueprint schema.Blueprint) []string {
var columns []string
for _, column := range blueprint.GetAddedColumns() {
columns = append(columns, r.getColumn(blueprint, column))
}

return columns
}

func (r *Mysql) getColumn(blueprint schema.Blueprint, column schema.ColumnDefinition) string {
sql := fmt.Sprintf("%s %s", r.wrap.Column(column.GetName()), getType(r, column))

for _, modifier := range r.modifiers {
sql += modifier(blueprint, column)
}

return sql
}
Loading

0 comments on commit 9267b55

Please sign in to comment.