Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ func New(a *analyzer.Analyzer, cfg *Config) *Engine {
}
ret.ReadOnly.Store(cfg.IsReadOnly)
a.Runner = ret
a.ExecBuilder.Runner = ret
return ret
}

Expand Down
1 change: 1 addition & 0 deletions enginetest/engine_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ func TestTableFunctions(t *testing.T) {
harness = harness.WithProvider(engine.Analyzer.Catalog.DbProvider)

engine.EngineAnalyzer().ExecBuilder = rowexec.NewBuilder(nil, sql.EngineOverrides{})
engine.EngineAnalyzer().ExecBuilder.Runner = engine

engine, err := enginetest.RunSetupScripts(harness.NewContext(), engine, setup.MydbData, true)
require.NoError(t, err)
Expand Down
1 change: 1 addition & 0 deletions enginetest/initialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func NewEngineWithProvider(_ *testing.T, harness Harness, provider sql.DatabaseP
idh.InitializeIndexDriver(engine.Analyzer.Catalog.AllDatabases(NewContext(harness)))
}
analyzer.Runner = engine
analyzer.ExecBuilder.Runner = engine

return engine
}
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ type Analyzer struct {
// Parser is the parser used to parse SQL statements.
Parser sql.Parser
// ExecBuilder converts a sql.Node tree into an executable iterator.
ExecBuilder sql.NodeExecBuilder
ExecBuilder *rowexec.BaseBuilder
// Runner represents the engine, which is represented as a separate interface to work around circular dependencies
Runner sql.StatementRunner
// SchemaFormatter is used to format the schema of a node to a string.
Expand Down
28 changes: 14 additions & 14 deletions sql/overrides.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,56 +75,56 @@ type ExecutionHooks struct {
// CreateTable contains hooks related to CREATE TABLE statements. These will take a *plan.CreateTable.
type CreateTable struct {
// PreSQLExecution is called before the final step of statement execution, after analysis.
PreSQLExecution func(*Context, Node) (Node, error)
PreSQLExecution func(*Context, StatementRunner, Node) (Node, error)
// PostSQLExecution is called after the final step of statement execution, after analysis.
PostSQLExecution func(*Context, Node) error
PostSQLExecution func(*Context, StatementRunner, Node) error
}

// RenameTable contains hooks related to RENAME TABLE statements. These will take a *plan.RenameTable.
type RenameTable struct {
// PreSQLExecution is called before the final step of statement execution, after analysis.
PreSQLExecution func(*Context, Node) (Node, error)
PreSQLExecution func(*Context, StatementRunner, Node) (Node, error)
// PostSQLExecution is called after the final step of statement execution, after analysis.
PostSQLExecution func(*Context, Node) error
PostSQLExecution func(*Context, StatementRunner, Node) error
}

// DropTable contains hooks related to DROP TABLE statements. These will take a *plan.DropTable.
type DropTable struct {
// PreSQLExecution is called before the final step of statement execution, after analysis.
PreSQLExecution func(*Context, Node) (Node, error)
PreSQLExecution func(*Context, StatementRunner, Node) (Node, error)
// PostSQLExecution is called after the final step of statement execution, after analysis.
PostSQLExecution func(*Context, Node) error
PostSQLExecution func(*Context, StatementRunner, Node) error
}

// TableAddColumn contains hooks related to ALTER TABLE ... ADD COLUMN statements. These will take a *plan.AddColumn.
type TableAddColumn struct {
// PreSQLExecution is called before the final step of statement execution, after analysis.
PreSQLExecution func(*Context, Node) (Node, error)
PreSQLExecution func(*Context, StatementRunner, Node) (Node, error)
// PostSQLExecution is called after the final step of statement execution, after analysis.
PostSQLExecution func(*Context, Node) error
PostSQLExecution func(*Context, StatementRunner, Node) error
}

// TableRenameColumn contains hooks related to ALTER TABLE ... RENAME COLUMN statements. These will take a *plan.RenameColumn.
type TableRenameColumn struct {
// PreSQLExecution is called before the final step of statement execution, after analysis.
PreSQLExecution func(*Context, Node) (Node, error)
PreSQLExecution func(*Context, StatementRunner, Node) (Node, error)
// PostSQLExecution is called after the final step of statement execution, after analysis.
PostSQLExecution func(*Context, Node) error
PostSQLExecution func(*Context, StatementRunner, Node) error
}

// TableModifyColumn contains hooks related to ALTER TABLE ... MODIFY COLUMN statements. These will take a
// *plan.ModifyColumn.
type TableModifyColumn struct {
// PreSQLExecution is called before the final step of statement execution, after analysis.
PreSQLExecution func(*Context, Node) (Node, error)
PreSQLExecution func(*Context, StatementRunner, Node) (Node, error)
// PostSQLExecution is called after the final step of statement execution, after analysis.
PostSQLExecution func(*Context, Node) error
PostSQLExecution func(*Context, StatementRunner, Node) error
}

// TableDropColumn contains hooks related to ALTER TABLE ... DROP COLUMN statements. These will take a *plan.DropColumn.
type TableDropColumn struct {
// PreSQLExecution is called before the final step of statement execution, after analysis.
PreSQLExecution func(*Context, Node) (Node, error)
PreSQLExecution func(*Context, StatementRunner, Node) (Node, error)
// PostSQLExecution is called after the final step of statement execution, after analysis.
PostSQLExecution func(*Context, Node) error
PostSQLExecution func(*Context, StatementRunner, Node) error
}
4 changes: 3 additions & 1 deletion sql/rowexec/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,19 @@ import (
type BaseBuilder struct {
PriorityBuilder sql.NodeExecBuilder
EngineOverrides sql.EngineOverrides
Runner sql.StatementRunner
schemaFormatter sql.SchemaFormatter
}

var _ sql.NodeExecBuilder = (*BaseBuilder)(nil)

// NewBuilder creates a new builder. If a priority builder is given, then it is tried first, and only uses the internal
// builder logic if the given one does not return a result (and does not error).
func NewBuilder(priority sql.NodeExecBuilder, overrides sql.EngineOverrides) sql.NodeExecBuilder {
func NewBuilder(priority sql.NodeExecBuilder, overrides sql.EngineOverrides) *BaseBuilder {
return &BaseBuilder{
PriorityBuilder: priority,
EngineOverrides: overrides,
Runner: nil, // This is often set later (directly on the variable), as it's not yet available during creation
schemaFormatter: sql.GetSchemaFormatter(overrides),
}
}
Expand Down
2 changes: 1 addition & 1 deletion sql/rowexec/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
"github.com/dolthub/go-mysql-server/sql/types"
)

var DefaultBuilder = NewBuilder(nil, sql.EngineOverrides{}).(*BaseBuilder)
var DefaultBuilder = NewBuilder(nil, sql.EngineOverrides{})

func newContext(provider *memory.DbProvider) *sql.Context {
return sql.NewContext(context.Background(), sql.WithSession(memory.NewSession(sql.NewBaseSession(), provider)))
Expand Down
20 changes: 11 additions & 9 deletions sql/rowexec/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func (b *BaseBuilder) buildDropCheck(ctx *sql.Context, n *plan.DropCheck, row sq

func (b *BaseBuilder) buildRenameTable(ctx *sql.Context, n *plan.RenameTable, row sql.Row) (sql.RowIter, error) {
if b.EngineOverrides.Hooks.RenameTable.PreSQLExecution != nil {
nn, err := b.EngineOverrides.Hooks.RenameTable.PreSQLExecution(ctx, n)
nn, err := b.EngineOverrides.Hooks.RenameTable.PreSQLExecution(ctx, b.Runner, n)
if err != nil {
return nil, err
}
Expand All @@ -268,7 +268,7 @@ func (b *BaseBuilder) buildRenameTable(ctx *sql.Context, n *plan.RenameTable, ro
}
}
if b.EngineOverrides.Hooks.RenameTable.PostSQLExecution != nil {
if err := b.EngineOverrides.Hooks.RenameTable.PostSQLExecution(ctx, n); err != nil {
if err := b.EngineOverrides.Hooks.RenameTable.PostSQLExecution(ctx, b.Runner, n); err != nil {
return nil, err
}
}
Expand All @@ -278,7 +278,7 @@ func (b *BaseBuilder) buildRenameTable(ctx *sql.Context, n *plan.RenameTable, ro

func (b *BaseBuilder) buildModifyColumn(ctx *sql.Context, n *plan.ModifyColumn, row sql.Row) (sql.RowIter, error) {
if b.EngineOverrides.Hooks.TableModifyColumn.PreSQLExecution != nil {
nn, err := b.EngineOverrides.Hooks.TableModifyColumn.PreSQLExecution(ctx, n)
nn, err := b.EngineOverrides.Hooks.TableModifyColumn.PreSQLExecution(ctx, b.Runner, n)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -321,6 +321,7 @@ func (b *BaseBuilder) buildModifyColumn(ctx *sql.Context, n *plan.ModifyColumn,
m: n,
alterable: alterable,
overrides: b.EngineOverrides,
runner: b.Runner,
}, nil
}

Expand Down Expand Up @@ -951,7 +952,7 @@ func (b *BaseBuilder) buildDropSchema(ctx *sql.Context, n *plan.DropSchema, row

func (b *BaseBuilder) buildRenameColumn(ctx *sql.Context, n *plan.RenameColumn, row sql.Row) (sql.RowIter, error) {
if b.EngineOverrides.Hooks.TableRenameColumn.PreSQLExecution != nil {
nn, err := b.EngineOverrides.Hooks.TableRenameColumn.PreSQLExecution(ctx, n)
nn, err := b.EngineOverrides.Hooks.TableRenameColumn.PreSQLExecution(ctx, b.Runner, n)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1002,7 +1003,7 @@ func (b *BaseBuilder) buildRenameColumn(ctx *sql.Context, n *plan.RenameColumn,
return nil, err
}
if b.EngineOverrides.Hooks.TableRenameColumn.PostSQLExecution != nil {
if err = b.EngineOverrides.Hooks.TableRenameColumn.PostSQLExecution(ctx, n); err != nil {
if err = b.EngineOverrides.Hooks.TableRenameColumn.PostSQLExecution(ctx, b.Runner, n); err != nil {
return nil, err
}
}
Expand All @@ -1012,7 +1013,7 @@ func (b *BaseBuilder) buildRenameColumn(ctx *sql.Context, n *plan.RenameColumn,

func (b *BaseBuilder) buildAddColumn(ctx *sql.Context, n *plan.AddColumn, row sql.Row) (sql.RowIter, error) {
if b.EngineOverrides.Hooks.TableAddColumn.PreSQLExecution != nil {
nn, err := b.EngineOverrides.Hooks.TableAddColumn.PreSQLExecution(ctx, n)
nn, err := b.EngineOverrides.Hooks.TableAddColumn.PreSQLExecution(ctx, b.Runner, n)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1096,7 +1097,7 @@ func (b *BaseBuilder) buildAlterDB(ctx *sql.Context, n *plan.AlterDB, row sql.Ro
func (b *BaseBuilder) buildCreateTable(ctx *sql.Context, n *plan.CreateTable, row sql.Row) (sql.RowIter, error) {
var err error
if b.EngineOverrides.Hooks.CreateTable.PreSQLExecution != nil {
nn, err := b.EngineOverrides.Hooks.CreateTable.PreSQLExecution(ctx, n)
nn, err := b.EngineOverrides.Hooks.CreateTable.PreSQLExecution(ctx, b.Runner, n)
if err != nil {
return sql.RowsToRowIter(), err
}
Expand Down Expand Up @@ -1262,7 +1263,7 @@ func (b *BaseBuilder) buildCreateTable(ctx *sql.Context, n *plan.CreateTable, ro
}

if b.EngineOverrides.Hooks.CreateTable.PostSQLExecution != nil {
if err = b.EngineOverrides.Hooks.CreateTable.PostSQLExecution(ctx, n); err != nil {
if err = b.EngineOverrides.Hooks.CreateTable.PostSQLExecution(ctx, b.Runner, n); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1345,7 +1346,7 @@ func (b *BaseBuilder) buildCreateTrigger(ctx *sql.Context, n *plan.CreateTrigger

func (b *BaseBuilder) buildDropColumn(ctx *sql.Context, n *plan.DropColumn, row sql.Row) (sql.RowIter, error) {
if b.EngineOverrides.Hooks.TableDropColumn.PreSQLExecution != nil {
nn, err := b.EngineOverrides.Hooks.TableDropColumn.PreSQLExecution(ctx, n)
nn, err := b.EngineOverrides.Hooks.TableDropColumn.PreSQLExecution(ctx, b.Runner, n)
if err != nil {
return nil, err
}
Expand All @@ -1370,6 +1371,7 @@ func (b *BaseBuilder) buildDropColumn(ctx *sql.Context, n *plan.DropColumn, row
d: n,
alterable: alterable,
overrides: b.EngineOverrides,
runner: b.Runner,
}, nil
}

Expand Down
16 changes: 9 additions & 7 deletions sql/rowexec/ddl_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ type modifyColumnIter struct {
m *plan.ModifyColumn
alterable sql.AlterableTable
overrides sql.EngineOverrides
runner sql.StatementRunner
runOnce bool
}

Expand Down Expand Up @@ -459,7 +460,7 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
}
if rewritten {
if i.overrides.Hooks.TableModifyColumn.PostSQLExecution != nil {
if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.m); err != nil {
if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.runner, i.m); err != nil {
return nil, err
}
}
Expand All @@ -483,7 +484,7 @@ func (i *modifyColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
}
}
if i.overrides.Hooks.TableModifyColumn.PostSQLExecution != nil {
if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.m); err != nil {
if err = i.overrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.runner, i.m); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1380,7 +1381,7 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
}
if rewritten {
if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil {
if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.a); err != nil {
if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.b.Runner, i.a); err != nil {
return nil, err
}
}
Expand All @@ -1402,7 +1403,7 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
// We only need to update all table rows if the new column is non-nil
if i.a.Column().Nullable && i.a.Column().Default == nil {
if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil {
if err = i.b.EngineOverrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.a); err != nil {
if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.b.Runner, i.a); err != nil {
return nil, err
}
}
Expand All @@ -1415,7 +1416,7 @@ func (i *addColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
}

if i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution != nil {
if err = i.b.EngineOverrides.Hooks.TableModifyColumn.PostSQLExecution(ctx, i.a); err != nil {
if err = i.b.EngineOverrides.Hooks.TableAddColumn.PostSQLExecution(ctx, i.b.Runner, i.a); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1772,6 +1773,7 @@ type dropColumnIter struct {
d *plan.DropColumn
alterable sql.AlterableTable
overrides sql.EngineOverrides
runner sql.StatementRunner
runOnce bool
}

Expand Down Expand Up @@ -1799,7 +1801,7 @@ func (i *dropColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
}
if rewritten {
if i.overrides.Hooks.TableDropColumn.PostSQLExecution != nil {
if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.d); err != nil {
if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.runner, i.d); err != nil {
return nil, err
}
}
Expand All @@ -1826,7 +1828,7 @@ func (i *dropColumnIter) Next(ctx *sql.Context) (sql.Row, error) {
}
}
if i.overrides.Hooks.TableDropColumn.PostSQLExecution != nil {
if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.d); err != nil {
if err = i.overrides.Hooks.TableDropColumn.PostSQLExecution(ctx, i.runner, i.d); err != nil {
return nil, err
}
}
Expand Down
4 changes: 2 additions & 2 deletions sql/rowexec/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, _ sql.
var curdb sql.Database

if b.EngineOverrides.Hooks.DropTable.PreSQLExecution != nil {
nn, err := b.EngineOverrides.Hooks.DropTable.PreSQLExecution(ctx, n)
nn, err := b.EngineOverrides.Hooks.DropTable.PreSQLExecution(ctx, b.Runner, n)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -274,7 +274,7 @@ func (b *BaseBuilder) buildDropTable(ctx *sql.Context, n *plan.DropTable, _ sql.
}

if b.EngineOverrides.Hooks.DropTable.PostSQLExecution != nil {
if err = b.EngineOverrides.Hooks.DropTable.PostSQLExecution(ctx, n); err != nil {
if err = b.EngineOverrides.Hooks.DropTable.PostSQLExecution(ctx, b.Runner, n); err != nil {
return nil, err
}
}
Expand Down
Loading