Skip to content
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added PostgreSQL `MERGE` statement support with full syntax including:
- `MERGE INTO ... USING ... ON ...` with table aliases and `ONLY` modifier
- `WHEN MATCHED`, `WHEN NOT MATCHED`, `WHEN NOT MATCHED BY SOURCE` clauses
- `UPDATE`, `INSERT`, `DELETE`, `DO NOTHING` actions
- Support for `AND condition` in WHEN clauses
- `OVERRIDING SYSTEM VALUE` and `OVERRIDING USER VALUE` for INSERT actions
- `RETURNING` clause support (PostgreSQL 17+) (thanks @atzedus)
- Added `psql.SetVersion`, `psql.GetVersion`, and `psql.VersionAtLeast` functions for context-based PostgreSQL version management (thanks @atzedus)
- Added `Table.Merge()` method for ORM-style MERGE operations with automatic `RETURNING *` for PostgreSQL 17+ (thanks @atzedus)
- Added `mm` package with modifiers for building MERGE queries (`mm.Into`, `mm.Using`, `mm.WhenMatched`, `mm.WhenNotMatched`, `mm.WhenNotMatchedBySource`, etc.) (thanks @atzedus)
- Added `PreloadCount` and `ThenLoadCount` to generate code for preloading and then loading counts for relationships. (thanks @jacobmolby)
- MySQL support for insert queries executing loaders (e.g., `InsertThenLoad`, `InsertThenLoadCount`). (thanks @jacobmolby)
- Added overwritable hooks that are run before the exec or scanning test of generated queries. This allows seeding data before the test runs.
Expand Down
246 changes: 246 additions & 0 deletions dialect/psql/dialect/merge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
package dialect

import (
"context"
"io"

"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/clause"
"github.com/stephenafamo/bob/internal"
)

// MergeWhenType represents the type of WHEN clause in MERGE statement
type MergeWhenType string

// MergeWhenType constants for WHEN clause types
const (
MergeWhenMatched MergeWhenType = "MATCHED"
MergeWhenNotMatched MergeWhenType = "NOT MATCHED"
MergeWhenNotMatchedByTarget MergeWhenType = "NOT MATCHED BY TARGET"
MergeWhenNotMatchedBySource MergeWhenType = "NOT MATCHED BY SOURCE"
)

// MergeActionType represents the type of action in WHEN clause
type MergeActionType string

// MergeActionType constants for action types in WHEN clause
const (
MergeActionDoNothing MergeActionType = "DO NOTHING"
MergeActionDelete MergeActionType = "DELETE"
MergeActionInsert MergeActionType = "INSERT"
MergeActionUpdate MergeActionType = "UPDATE"
)

// MergeOverridingType represents the OVERRIDING type in INSERT action
type MergeOverridingType string

// MergeOverridingType constants for OVERRIDING clause in INSERT
const (
MergeOverridingSystem MergeOverridingType = "SYSTEM"
MergeOverridingUser MergeOverridingType = "USER"
)

// MergeQuery Trying to represent the merge query structure as documented in
// https://www.postgresql.org/docs/current/sql-merge.html
type MergeQuery struct {
clause.With
Only bool
Table clause.TableRef
Using MergeUsing
When []MergeWhen
clause.Returning

bob.Load
bob.EmbeddedHook
bob.ContextualModdable[*MergeQuery]
}

func (m MergeQuery) WriteSQL(ctx context.Context, w io.StringWriter, d bob.Dialect, start int) ([]any, error) {
var err error
var args []any

if ctx, err = m.RunContextualMods(ctx, &m); err != nil {
return nil, err
}

withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), m.With,
len(m.With.CTEs) > 0, "", "\n")
if err != nil {
return nil, err
}
args = append(args, withArgs...)

w.WriteString("MERGE INTO ")

if m.Only {
w.WriteString("ONLY ")
}

tableArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), m.Table, true, "", "")
if err != nil {
return nil, err
}
args = append(args, tableArgs...)

usingArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), m.Using,
m.Using.Source != nil, "\n", "")
if err != nil {
return nil, err
}
args = append(args, usingArgs...)

for _, when := range m.When {
whenArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), when, true, "\n", "")
if err != nil {
return nil, err
}
args = append(args, whenArgs...)
}

retArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), m.Returning,
len(m.Returning.Expressions) > 0, "\n", "")
if err != nil {
return nil, err
}
args = append(args, retArgs...)

return args, nil
}

// MergeUsing represents the USING clause in a MERGE statement
type MergeUsing struct {
Only bool
Source any // table name or subquery
Alias string
Condition bob.Expression
}

func (u MergeUsing) WriteSQL(ctx context.Context, w io.StringWriter, d bob.Dialect, start int) ([]any, error) {
w.WriteString("USING ")

if u.Only {
w.WriteString("ONLY ")
}

// Write source (table or subquery)
var sourceArgs []any
var err error
if _, isQuery := u.Source.(bob.Query); isQuery {
w.WriteString("(")
sourceArgs, err = bob.Express(ctx, w, d, start, u.Source)
if err != nil {
return nil, err
}
w.WriteString(")")
} else {
sourceArgs, err = bob.Express(ctx, w, d, start, u.Source)
if err != nil {
return nil, err
}
}

if u.Alias != "" {
w.WriteString(" AS ")
d.WriteQuoted(w, u.Alias)
}

onArgs, err := bob.ExpressIf(ctx, w, d, start+len(sourceArgs), u.Condition,
u.Condition != nil, " ON ", "")
if err != nil {
return nil, err
}

return append(sourceArgs, onArgs...), nil
}

// MergeWhen represents a WHEN clause in a MERGE statement
type MergeWhen struct {
Type MergeWhenType
Condition bob.Expression
Action MergeAction
}

func (w MergeWhen) WriteSQL(ctx context.Context, wr io.StringWriter, d bob.Dialect, start int) ([]any, error) {
wr.WriteString("WHEN ")
wr.WriteString(string(w.Type))

args, err := bob.ExpressIf(ctx, wr, d, start, w.Condition,
w.Condition != nil, " AND ", "")
if err != nil {
return nil, err
}

wr.WriteString(" THEN ")

actionArgs, err := bob.Express(ctx, wr, d, start+len(args), w.Action)
if err != nil {
return nil, err
}
args = append(args, actionArgs...)

return args, nil
}

// MergeAction represents the action in a WHEN clause
type MergeAction struct {
Type MergeActionType
Columns []string
Overriding MergeOverridingType // MergeOverridingType for INSERT
Values []bob.Expression
Set []any
}

func (a MergeAction) WriteSQL(ctx context.Context, w io.StringWriter, d bob.Dialect, start int) ([]any, error) {
switch a.Type {
case MergeActionDoNothing:
w.WriteString("DO NOTHING")
return nil, nil

case MergeActionDelete:
w.WriteString("DELETE")
return nil, nil

case MergeActionInsert:
w.WriteString("INSERT")

if len(a.Columns) > 0 {
w.WriteString(" (")
for i, col := range a.Columns {
if i > 0 {
w.WriteString(", ")
}
d.WriteQuoted(w, col)
}
w.WriteString(")")
}

if a.Overriding != "" {
w.WriteString(" OVERRIDING ")
w.WriteString(string(a.Overriding))
w.WriteString(" VALUE")
}

if len(a.Values) > 0 {
w.WriteString(" VALUES (")
args, err := bob.ExpressSlice(ctx, w, d, start, a.Values, "", ", ", "")
if err != nil {
return nil, err
}
w.WriteString(")")
return args, nil
}

w.WriteString(" DEFAULT VALUES")
return nil, nil

case MergeActionUpdate:
w.WriteString("UPDATE SET ")
args, err := bob.ExpressSlice(ctx, w, d, start, internal.ToAnySlice(a.Set), "", ", ", "")
if err != nil {
return nil, err
}
return args, nil
}

return nil, nil
}
19 changes: 19 additions & 0 deletions dialect/psql/merge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package psql

import (
"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/dialect/psql/dialect"
)

func Merge(queryMods ...bob.Mod[*dialect.MergeQuery]) bob.BaseQuery[*dialect.MergeQuery] {
q := &dialect.MergeQuery{}
for _, mod := range queryMods {
mod.Apply(q)
}

return bob.BaseQuery[*dialect.MergeQuery]{
Expression: q,
Dialect: dialect.Dialect,
QueryType: bob.QueryTypeMerge,
}
}
Loading
Loading