diff --git a/CHANGELOG.md b/CHANGELOG.md index a8b7d511..a59df898 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added PostgreSQL `MERGE` statement support with full syntax including: + - `MERGE INTO ... USING ... ON ...` with table aliases and `ONLY` modifier + - `WHEN MATCHED`, `WHEN NOT MATCHED`, `WHEN NOT MATCHED BY SOURCE` clauses + - `UPDATE`, `INSERT`, `DELETE`, `DO NOTHING` actions + - Support for `AND condition` in WHEN clauses + - `OVERRIDING SYSTEM VALUE` and `OVERRIDING USER VALUE` for INSERT actions + - `RETURNING` clause support (PostgreSQL 17+) (thanks @atzedus) +- Added `psql.SetVersion`, `psql.GetVersion`, and `psql.VersionAtLeast` functions for context-based PostgreSQL version management (thanks @atzedus) +- Added `Table.Merge()` method for ORM-style MERGE operations with automatic `RETURNING *` for PostgreSQL 17+ (thanks @atzedus) +- Added `mm` package with modifiers for building MERGE queries (`mm.Into`, `mm.Using`, `mm.WhenMatched`, `mm.WhenNotMatched`, `mm.WhenNotMatchedBySource`, etc.) (thanks @atzedus) - Added `PreloadCount` and `ThenLoadCount` to generate code for preloading and then loading counts for relationships. (thanks @jacobmolby) - MySQL support for insert queries executing loaders (e.g., `InsertThenLoad`, `InsertThenLoadCount`). (thanks @jacobmolby) - Added overwritable hooks that are run before the exec or scanning test of generated queries. This allows seeding data before the test runs. diff --git a/dialect/psql/dialect/merge.go b/dialect/psql/dialect/merge.go new file mode 100644 index 00000000..e2c18b2a --- /dev/null +++ b/dialect/psql/dialect/merge.go @@ -0,0 +1,246 @@ +package dialect + +import ( + "context" + "io" + + "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/clause" + "github.com/stephenafamo/bob/internal" +) + +// MergeWhenType represents the type of WHEN clause in MERGE statement +type MergeWhenType string + +// MergeWhenType constants for WHEN clause types +const ( + MergeWhenMatched MergeWhenType = "MATCHED" + MergeWhenNotMatched MergeWhenType = "NOT MATCHED" + MergeWhenNotMatchedByTarget MergeWhenType = "NOT MATCHED BY TARGET" + MergeWhenNotMatchedBySource MergeWhenType = "NOT MATCHED BY SOURCE" +) + +// MergeActionType represents the type of action in WHEN clause +type MergeActionType string + +// MergeActionType constants for action types in WHEN clause +const ( + MergeActionDoNothing MergeActionType = "DO NOTHING" + MergeActionDelete MergeActionType = "DELETE" + MergeActionInsert MergeActionType = "INSERT" + MergeActionUpdate MergeActionType = "UPDATE" +) + +// MergeOverridingType represents the OVERRIDING type in INSERT action +type MergeOverridingType string + +// MergeOverridingType constants for OVERRIDING clause in INSERT +const ( + MergeOverridingSystem MergeOverridingType = "SYSTEM" + MergeOverridingUser MergeOverridingType = "USER" +) + +// MergeQuery Trying to represent the merge query structure as documented in +// https://www.postgresql.org/docs/current/sql-merge.html +type MergeQuery struct { + clause.With + Only bool + Table clause.TableRef + Using MergeUsing + When []MergeWhen + clause.Returning + + bob.Load + bob.EmbeddedHook + bob.ContextualModdable[*MergeQuery] +} + +func (m MergeQuery) WriteSQL(ctx context.Context, w io.StringWriter, d bob.Dialect, start int) ([]any, error) { + var err error + var args []any + + if ctx, err = m.RunContextualMods(ctx, &m); err != nil { + return nil, err + } + + withArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), m.With, + len(m.With.CTEs) > 0, "", "\n") + if err != nil { + return nil, err + } + args = append(args, withArgs...) + + w.WriteString("MERGE INTO ") + + if m.Only { + w.WriteString("ONLY ") + } + + tableArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), m.Table, true, "", "") + if err != nil { + return nil, err + } + args = append(args, tableArgs...) + + usingArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), m.Using, + m.Using.Source != nil, "\n", "") + if err != nil { + return nil, err + } + args = append(args, usingArgs...) + + for _, when := range m.When { + whenArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), when, true, "\n", "") + if err != nil { + return nil, err + } + args = append(args, whenArgs...) + } + + retArgs, err := bob.ExpressIf(ctx, w, d, start+len(args), m.Returning, + len(m.Returning.Expressions) > 0, "\n", "") + if err != nil { + return nil, err + } + args = append(args, retArgs...) + + return args, nil +} + +// MergeUsing represents the USING clause in a MERGE statement +type MergeUsing struct { + Only bool + Source any // table name or subquery + Alias string + Condition bob.Expression +} + +func (u MergeUsing) WriteSQL(ctx context.Context, w io.StringWriter, d bob.Dialect, start int) ([]any, error) { + w.WriteString("USING ") + + if u.Only { + w.WriteString("ONLY ") + } + + // Write source (table or subquery) + var sourceArgs []any + var err error + if _, isQuery := u.Source.(bob.Query); isQuery { + w.WriteString("(") + sourceArgs, err = bob.Express(ctx, w, d, start, u.Source) + if err != nil { + return nil, err + } + w.WriteString(")") + } else { + sourceArgs, err = bob.Express(ctx, w, d, start, u.Source) + if err != nil { + return nil, err + } + } + + if u.Alias != "" { + w.WriteString(" AS ") + d.WriteQuoted(w, u.Alias) + } + + onArgs, err := bob.ExpressIf(ctx, w, d, start+len(sourceArgs), u.Condition, + u.Condition != nil, " ON ", "") + if err != nil { + return nil, err + } + + return append(sourceArgs, onArgs...), nil +} + +// MergeWhen represents a WHEN clause in a MERGE statement +type MergeWhen struct { + Type MergeWhenType + Condition bob.Expression + Action MergeAction +} + +func (w MergeWhen) WriteSQL(ctx context.Context, wr io.StringWriter, d bob.Dialect, start int) ([]any, error) { + wr.WriteString("WHEN ") + wr.WriteString(string(w.Type)) + + args, err := bob.ExpressIf(ctx, wr, d, start, w.Condition, + w.Condition != nil, " AND ", "") + if err != nil { + return nil, err + } + + wr.WriteString(" THEN ") + + actionArgs, err := bob.Express(ctx, wr, d, start+len(args), w.Action) + if err != nil { + return nil, err + } + args = append(args, actionArgs...) + + return args, nil +} + +// MergeAction represents the action in a WHEN clause +type MergeAction struct { + Type MergeActionType + Columns []string + Overriding MergeOverridingType // MergeOverridingType for INSERT + Values []bob.Expression + Set []any +} + +func (a MergeAction) WriteSQL(ctx context.Context, w io.StringWriter, d bob.Dialect, start int) ([]any, error) { + switch a.Type { + case MergeActionDoNothing: + w.WriteString("DO NOTHING") + return nil, nil + + case MergeActionDelete: + w.WriteString("DELETE") + return nil, nil + + case MergeActionInsert: + w.WriteString("INSERT") + + if len(a.Columns) > 0 { + w.WriteString(" (") + for i, col := range a.Columns { + if i > 0 { + w.WriteString(", ") + } + d.WriteQuoted(w, col) + } + w.WriteString(")") + } + + if a.Overriding != "" { + w.WriteString(" OVERRIDING ") + w.WriteString(string(a.Overriding)) + w.WriteString(" VALUE") + } + + if len(a.Values) > 0 { + w.WriteString(" VALUES (") + args, err := bob.ExpressSlice(ctx, w, d, start, a.Values, "", ", ", "") + if err != nil { + return nil, err + } + w.WriteString(")") + return args, nil + } + + w.WriteString(" DEFAULT VALUES") + return nil, nil + + case MergeActionUpdate: + w.WriteString("UPDATE SET ") + args, err := bob.ExpressSlice(ctx, w, d, start, internal.ToAnySlice(a.Set), "", ", ", "") + if err != nil { + return nil, err + } + return args, nil + } + + return nil, nil +} diff --git a/dialect/psql/merge.go b/dialect/psql/merge.go new file mode 100644 index 00000000..c488a9d5 --- /dev/null +++ b/dialect/psql/merge.go @@ -0,0 +1,19 @@ +package psql + +import ( + "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/dialect/psql/dialect" +) + +func Merge(queryMods ...bob.Mod[*dialect.MergeQuery]) bob.BaseQuery[*dialect.MergeQuery] { + q := &dialect.MergeQuery{} + for _, mod := range queryMods { + mod.Apply(q) + } + + return bob.BaseQuery[*dialect.MergeQuery]{ + Expression: q, + Dialect: dialect.Dialect, + QueryType: bob.QueryTypeMerge, + } +} diff --git a/dialect/psql/merge_test.go b/dialect/psql/merge_test.go new file mode 100644 index 00000000..e05b7ac6 --- /dev/null +++ b/dialect/psql/merge_test.go @@ -0,0 +1,507 @@ +package psql_test + +import ( + "testing" + + "github.com/stephenafamo/bob/dialect/psql" + "github.com/stephenafamo/bob/dialect/psql/mm" + "github.com/stephenafamo/bob/dialect/psql/sm" + testutils "github.com/stephenafamo/bob/test/utils" +) + +func TestMerge(t *testing.T) { + examples := testutils.Testcases{ + "simple merge with update and insert": { + Query: psql.Merge( + mm.Into("customer_account"), + mm.Using("recent_transactions").As("t").On( + psql.Quote("t", "customer_id").EQ(psql.Quote("customer_account", "customer_id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("balance").ToExpr( + psql.Raw("balance + ?", psql.Quote("t", "transaction_value")), + ), + ), + ), + mm.WhenNotMatched( + mm.ThenInsert( + mm.Columns("customer_id", "balance"), + mm.Values(psql.Quote("t", "customer_id"), psql.Quote("t", "transaction_value")), + ), + ), + ), + ExpectedSQL: `MERGE INTO customer_account + USING recent_transactions AS "t" ON "t"."customer_id" = "customer_account"."customer_id" + WHEN MATCHED THEN UPDATE SET "balance" = balance + "t"."transaction_value" + WHEN NOT MATCHED THEN INSERT ("customer_id", "balance") VALUES ("t"."customer_id", "t"."transaction_value")`, + }, + "merge with condition": { + Query: psql.Merge( + mm.Into("wines"), + mm.Using("wine_stock_changes").As("s").On( + psql.Quote("s", "winename").EQ(psql.Quote("wines", "winename")), + ), + mm.WhenNotMatched( + mm.And(psql.Quote("s", "stock_delta").GT(psql.Arg(0))), + mm.ThenInsert( + mm.Values(psql.Quote("s", "winename"), psql.Quote("s", "stock_delta")), + ), + ), + mm.WhenMatched( + mm.And(psql.Raw("w.stock + s.stock_delta > 0")), + mm.ThenUpdate( + mm.SetCol("stock").ToExpr(psql.Raw("w.stock + s.stock_delta")), + ), + ), + mm.WhenMatched( + mm.ThenDelete(), + ), + ), + ExpectedSQL: `MERGE INTO wines + USING wine_stock_changes AS "s" ON "s"."winename" = "wines"."winename" + WHEN NOT MATCHED AND "s"."stock_delta" > $1 THEN INSERT VALUES ("s"."winename", "s"."stock_delta") + WHEN MATCHED AND w.stock + s.stock_delta > 0 THEN UPDATE SET "stock" = w.stock + s.stock_delta + WHEN MATCHED THEN DELETE`, + ExpectedArgs: []any{0}, + }, + "merge with do nothing": { + Query: psql.Merge( + mm.Into("target"), + mm.Using("source").As("s").On( + psql.Quote("s", "id").EQ(psql.Quote("target", "id")), + ), + mm.WhenMatched( + mm.ThenDoNothing(), + ), + mm.WhenNotMatched( + mm.ThenDoNothing(), + ), + ), + ExpectedSQL: `MERGE INTO target + USING source AS "s" ON "s"."id" = "target"."id" + WHEN MATCHED THEN DO NOTHING + WHEN NOT MATCHED THEN DO NOTHING`, + }, + "merge with target alias": { + Query: psql.Merge( + mm.IntoAs("wines", "w"), + mm.Using("new_wine_list").As("s").On( + psql.Quote("s", "winename").EQ(psql.Quote("w", "winename")), + ), + mm.WhenNotMatchedByTarget( + mm.ThenInsert( + mm.Values(psql.Quote("s", "winename"), psql.Quote("s", "stock")), + ), + ), + mm.WhenMatched( + mm.And(psql.Quote("w", "stock").NE(psql.Quote("s", "stock"))), + mm.ThenUpdate( + mm.SetCol("stock").ToExpr(psql.Quote("s", "stock")), + ), + ), + mm.WhenNotMatchedBySource( + mm.ThenDelete(), + ), + ), + ExpectedSQL: `MERGE INTO wines AS "w" + USING new_wine_list AS "s" ON "s"."winename" = "w"."winename" + WHEN NOT MATCHED BY TARGET THEN INSERT VALUES ("s"."winename", "s"."stock") + WHEN MATCHED AND "w"."stock" <> "s"."stock" THEN UPDATE SET "stock" = "s"."stock" + WHEN NOT MATCHED BY SOURCE THEN DELETE`, + }, + "merge with returning": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("product_updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("products", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("price").ToExpr(psql.Quote("u", "price")), + ), + ), + mm.WhenNotMatched( + mm.ThenInsert( + mm.Columns("id", "name", "price"), + mm.Values(psql.Quote("u", "id"), psql.Quote("u", "name"), psql.Quote("u", "price")), + ), + ), + mm.Returning("*"), + ), + ExpectedSQL: `MERGE INTO products + USING product_updates AS "u" ON "u"."id" = "products"."id" + WHEN MATCHED THEN UPDATE SET "price" = "u"."price" + WHEN NOT MATCHED THEN INSERT ("id", "name", "price") VALUES ("u"."id", "u"."name", "u"."price") + RETURNING *`, + }, + "merge with subquery as source": { + Query: psql.Merge( + mm.Into("target_table"), + mm.UsingQuery(psql.Select()).As("src").On( + psql.Quote("src", "id").EQ(psql.Quote("target_table", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("value").ToExpr(psql.Quote("src", "value")), + ), + ), + ), + ExpectedSQL: `MERGE INTO target_table + USING (SELECT *) AS "src" ON "src"."id" = "target_table"."id" + WHEN MATCHED THEN UPDATE SET "value" = "src"."value"`, + }, + "merge with multiple SetCol from source": { + Query: psql.Merge( + mm.Into("employees"), + mm.Using("employee_updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("employees", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("name").ToExpr(psql.Quote("u", "name")), + mm.SetCol("salary").ToExpr(psql.Quote("u", "salary")), + mm.SetCol("department").ToExpr(psql.Quote("u", "department")), + ), + ), + ), + ExpectedSQL: `MERGE INTO employees + USING employee_updates AS "u" ON "u"."id" = "employees"."id" + WHEN MATCHED THEN UPDATE SET "name" = "u"."name", "salary" = "u"."salary", "department" = "u"."department"`, + }, + "merge with SetCol ToDefault": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("products", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("updated_at").ToDefault(), + mm.SetCol("name").ToExpr(psql.Quote("u", "name")), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING updates AS "u" ON "u"."id" = "products"."id" + WHEN MATCHED THEN UPDATE SET "updated_at" = DEFAULT, "name" = "u"."name"`, + }, + "merge with SetCols ToRow": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("products", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCols("name", "price").ToRow( + psql.Quote("u", "name"), + psql.Quote("u", "price"), + ), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING updates AS "u" ON "u"."id" = "products"."id" + WHEN MATCHED THEN UPDATE SET ("name", "price") = ROW ("u"."name", "u"."price")`, + }, + "merge with SetCols ToQuery": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("products", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCols("name", "price").ToQuery( + psql.Select( + sm.Columns("name", "price"), + sm.From("default_values"), + sm.Where(psql.Quote("category").EQ(psql.Quote("u", "category"))), + ), + ), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING updates AS "u" ON "u"."id" = "products"."id" + WHEN MATCHED THEN UPDATE SET ("name", "price") = (SELECT "name", "price" FROM default_values WHERE (category = u.category))`, + }, + "merge with CTE": { + Query: psql.Merge( + mm.With("source_data").As(psql.Select( + sm.Columns("id", "value"), + sm.From("temp_table"), + )), + mm.Into("target"), + mm.Using("source_data").As("s").On( + psql.Quote("s", "id").EQ(psql.Quote("target", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("value").ToExpr(psql.Quote("s", "value")), + ), + ), + ), + ExpectedSQL: `WITH source_data AS (SELECT "id", "value" FROM temp_table) + MERGE INTO target + USING source_data AS "s" ON "s"."id" = "target"."id" + WHEN MATCHED THEN UPDATE SET "value" = "s"."value"`, + }, + "merge with INSERT DEFAULT VALUES": { + Query: psql.Merge( + mm.Into("audit_log"), + mm.Using("events").As("e").On( + psql.Quote("e", "id").EQ(psql.Quote("audit_log", "event_id")), + ), + mm.WhenNotMatched( + mm.ThenInsertDefaultValues(), + ), + ), + ExpectedSQL: `MERGE INTO audit_log + USING events AS "e" ON "e"."id" = "audit_log"."event_id" + WHEN NOT MATCHED THEN INSERT DEFAULT VALUES`, + }, + "merge with OVERRIDING SYSTEM VALUE": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("new_products").As("n").On( + psql.Quote("n", "sku").EQ(psql.Quote("products", "sku")), + ), + mm.WhenNotMatched( + mm.ThenInsert( + mm.Columns("id", "sku", "name"), + mm.OverridingSystem(), + mm.Values(psql.Quote("n", "id"), psql.Quote("n", "sku"), psql.Quote("n", "name")), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING new_products AS "n" ON "n"."sku" = "products"."sku" + WHEN NOT MATCHED THEN INSERT ("id", "sku", "name") OVERRIDING SYSTEM VALUE VALUES ("n"."id", "n"."sku", "n"."name")`, + }, + "merge with OVERRIDING USER VALUE": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("new_products").As("n").On( + psql.Quote("n", "sku").EQ(psql.Quote("products", "sku")), + ), + mm.WhenNotMatched( + mm.ThenInsert( + mm.Columns("id", "sku", "name"), + mm.OverridingUser(), + mm.Values(psql.Quote("n", "id"), psql.Quote("n", "sku"), psql.Quote("n", "name")), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING new_products AS "n" ON "n"."sku" = "products"."sku" + WHEN NOT MATCHED THEN INSERT ("id", "sku", "name") OVERRIDING USER VALUE VALUES ("n"."id", "n"."sku", "n"."name")`, + }, + "merge with Recursive CTE": { + Query: psql.Merge( + mm.With("hierarchy").As(psql.Select( + sm.Columns("id", "parent_id", "name"), + sm.From("categories"), + )), + mm.Recursive(true), + mm.Into("target"), + mm.Using("hierarchy").As("h").On( + psql.Quote("h", "id").EQ(psql.Quote("target", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("name").ToExpr(psql.Quote("h", "name")), + ), + ), + ), + ExpectedSQL: `WITH RECURSIVE hierarchy AS (SELECT "id", "parent_id", "name" FROM categories) + MERGE INTO target + USING hierarchy AS "h" ON "h"."id" = "target"."id" + WHEN MATCHED THEN UPDATE SET "name" = "h"."name"`, + }, + "merge with Only target": { + Query: psql.Merge( + mm.Into("parent_table"), + mm.Only(), + mm.Using("source").As("s").On( + psql.Quote("s", "id").EQ(psql.Quote("parent_table", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("value").ToExpr(psql.Quote("s", "value")), + ), + ), + ), + ExpectedSQL: `MERGE INTO ONLY parent_table + USING source AS "s" ON "s"."id" = "parent_table"."id" + WHEN MATCHED THEN UPDATE SET "value" = "s"."value"`, + }, + "merge with Only source": { + Query: psql.Merge( + mm.Into("target"), + mm.Using("parent_source").Only().As("s").On( + psql.Quote("s", "id").EQ(psql.Quote("target", "id")), + ), + mm.WhenMatched( + mm.ThenDelete(), + ), + ), + ExpectedSQL: `MERGE INTO target + USING ONLY parent_source AS "s" ON "s"."id" = "target"."id" + WHEN MATCHED THEN DELETE`, + }, + "merge with OnEQ shortcut": { + Query: psql.Merge( + mm.Into("target"), + mm.Using("source").As("s").OnEQ( + psql.Quote("s", "id"), + psql.Quote("target", "id"), + ), + mm.WhenMatched( + mm.ThenDoNothing(), + ), + ), + ExpectedSQL: `MERGE INTO target + USING source AS "s" ON "s"."id" = "target"."id" + WHEN MATCHED THEN DO NOTHING`, + }, + "merge with SetCol To raw value": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("products", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("status").To(psql.Raw("'active'")), + mm.SetCol("counter").To(psql.Raw("counter + 1")), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING updates AS "u" ON "u"."id" = "products"."id" + WHEN MATCHED THEN UPDATE SET "status" = 'active', "counter" = counter + 1`, + }, + "merge with SetCol ToArg": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("products", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("status").ToArg("active"), + mm.SetCol("quantity").ToArg(100), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING updates AS "u" ON "u"."id" = "products"."id" + WHEN MATCHED THEN UPDATE SET "status" = $1, "quantity" = $2`, + ExpectedArgs: []any{"active", 100}, + }, + "merge with Set raw expressions": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("products", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.Set( + psql.Raw(`"name" = "u"."name"`), + psql.Raw(`"price" = "u"."price" * 1.1`), + ), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING updates AS "u" ON "u"."id" = "products"."id" + WHEN MATCHED THEN UPDATE SET "name" = "u"."name", "price" = "u"."price" * 1.1`, + }, + "merge with SetCols ToExprs without ROW": { + Query: psql.Merge( + mm.Into("products"), + mm.Using("updates").As("u").On( + psql.Quote("u", "id").EQ(psql.Quote("products", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCols("name", "price").ToExprs( + psql.Quote("u", "name"), + psql.Quote("u", "price"), + ), + ), + ), + ), + ExpectedSQL: `MERGE INTO products + USING updates AS "u" ON "u"."id" = "products"."id" + WHEN MATCHED THEN UPDATE SET ("name", "price") = ("u"."name", "u"."price")`, + }, + "merge with multiple WHEN clauses and conditions": { + Query: psql.Merge( + mm.Into("inventory"), + mm.Using("stock_updates").As("s").On( + psql.Quote("s", "product_id").EQ(psql.Quote("inventory", "product_id")), + ), + mm.WhenMatched( + mm.And(psql.Quote("s", "quantity").EQ(psql.Arg(0))), + mm.ThenDelete(), + ), + mm.WhenMatched( + mm.And(psql.Quote("s", "quantity").GT(psql.Arg(0))), + mm.ThenUpdate( + mm.SetCol("quantity").ToExpr(psql.Quote("s", "quantity")), + mm.SetCol("updated_at").ToDefault(), + ), + ), + mm.WhenNotMatchedByTarget( + mm.And(psql.Quote("s", "quantity").GT(psql.Arg(0))), + mm.ThenInsert( + mm.Columns("product_id", "quantity"), + mm.Values(psql.Quote("s", "product_id"), psql.Quote("s", "quantity")), + ), + ), + mm.WhenNotMatchedBySource( + mm.ThenUpdate( + mm.SetCol("quantity").ToArg(0), + ), + ), + mm.Returning("*"), + ), + ExpectedSQL: `MERGE INTO inventory + USING stock_updates AS "s" ON "s"."product_id" = "inventory"."product_id" + WHEN MATCHED AND "s"."quantity" = $1 THEN DELETE + WHEN MATCHED AND "s"."quantity" > $2 THEN UPDATE SET "quantity" = "s"."quantity", "updated_at" = DEFAULT + WHEN NOT MATCHED BY TARGET AND "s"."quantity" > $3 THEN INSERT ("product_id", "quantity") VALUES ("s"."product_id", "s"."quantity") + WHEN NOT MATCHED BY SOURCE THEN UPDATE SET "quantity" = $4 + RETURNING *`, + ExpectedArgs: []any{0, 0, 0, 0}, + }, + "merge with CTE columns": { + Query: psql.Merge( + mm.With("source_data", "id", "name", "value").As(psql.Select( + sm.Columns("product_id", "product_name", "price"), + sm.From("products"), + )), + mm.Into("target"), + mm.Using("source_data").As("s").On( + psql.Quote("s", "id").EQ(psql.Quote("target", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("name").ToExpr(psql.Quote("s", "name")), + mm.SetCol("value").ToExpr(psql.Quote("s", "value")), + ), + ), + ), + ExpectedSQL: `WITH source_data ("id", "name", "value") AS (SELECT "product_id", "product_name", "price" FROM products) + MERGE INTO target + USING source_data AS "s" ON "s"."id" = "target"."id" + WHEN MATCHED THEN UPDATE SET "name" = "s"."name", "value" = "s"."value"`, + }, + } + + testutils.RunTests(t, examples, formatter) +} diff --git a/dialect/psql/mm/qm.go b/dialect/psql/mm/qm.go new file mode 100644 index 00000000..32ae9cab --- /dev/null +++ b/dialect/psql/mm/qm.go @@ -0,0 +1,430 @@ +package mm + +import ( + "context" + "io" + + "github.com/stephenafamo/bob" + "github.com/stephenafamo/bob/clause" + "github.com/stephenafamo/bob/dialect/psql/dialect" + "github.com/stephenafamo/bob/expr" + "github.com/stephenafamo/bob/internal" + "github.com/stephenafamo/bob/mods" +) + +// rowAssignment represents (columns...) = [ROW] (values...) +type rowAssignment struct { + cols []bob.Expression + values []bob.Expression + isRow bool +} + +func (r rowAssignment) WriteSQL(ctx context.Context, w io.StringWriter, d bob.Dialect, start int) ([]any, error) { + // Write (col1, col2, ...) + w.WriteString("(") + colArgs, err := bob.ExpressSlice(ctx, w, d, start, r.cols, "", ", ", "") + if err != nil { + return nil, err + } + + w.WriteString(") = ") + + if r.isRow { + w.WriteString("ROW ") + } + + // Write (val1, val2, ...) + w.WriteString("(") + valArgs, err := bob.ExpressSlice(ctx, w, d, start+len(colArgs), r.values, "", ", ", "") + if err != nil { + return nil, err + } + w.WriteString(")") + + return append(colArgs, valArgs...), nil +} + +// queryAssignment represents (columns...) = (subquery) +type queryAssignment struct { + cols []bob.Expression + query bob.Query +} + +func (q queryAssignment) WriteSQL(ctx context.Context, w io.StringWriter, d bob.Dialect, start int) ([]any, error) { + // Write (col1, col2, ...) + w.WriteString("(") + colArgs, err := bob.ExpressSlice(ctx, w, d, start, q.cols, "", ", ", "") + if err != nil { + return nil, err + } + + w.WriteString(") = (") + + // Write subquery + queryArgs, err := bob.Express(ctx, w, d, start+len(colArgs), q.query) + if err != nil { + return nil, err + } + w.WriteString(")") + + return append(colArgs, queryArgs...), nil +} + +func With(name string, columns ...string) dialect.CTEChain[*dialect.MergeQuery] { + return dialect.With[*dialect.MergeQuery](name, columns...) +} + +func Recursive(r bool) bob.Mod[*dialect.MergeQuery] { + return mods.Recursive[*dialect.MergeQuery](r) +} + +// Into specifies the target table for the MERGE statement +func Into(name any) bob.Mod[*dialect.MergeQuery] { + return bob.ModFunc[*dialect.MergeQuery](func(m *dialect.MergeQuery) { + m.Table = clause.TableRef{ + Expression: name, + } + }) +} + +// IntoAs specifies the target table with an alias for the MERGE statement +func IntoAs(name any, alias string) bob.Mod[*dialect.MergeQuery] { + return bob.ModFunc[*dialect.MergeQuery](func(m *dialect.MergeQuery) { + m.Table = clause.TableRef{ + Expression: name, + Alias: alias, + } + }) +} + +// Only specifies ONLY modifier for the target table +func Only() bob.Mod[*dialect.MergeQuery] { + return bob.ModFunc[*dialect.MergeQuery](func(m *dialect.MergeQuery) { + m.Only = true + }) +} + +// Using specifies the data source for the MERGE statement +func Using(source any) UsingChain { + return UsingChain{source: source} +} + +// UsingQuery specifies a subquery as the data source for the MERGE statement +func UsingQuery(q bob.Query) UsingChain { + return UsingChain{source: q} +} + +// UsingChain is a chain for building the USING clause +type UsingChain struct { + source any + alias string + only bool +} + +func (u UsingChain) As(alias string) UsingChain { + u.alias = alias + return u +} + +func (u UsingChain) Only() UsingChain { + u.only = true + return u +} + +func (u UsingChain) On(condition bob.Expression) bob.Mod[*dialect.MergeQuery] { + return bob.ModFunc[*dialect.MergeQuery](func(m *dialect.MergeQuery) { + m.Using = dialect.MergeUsing{ + Only: u.only, + Source: u.source, + Alias: u.alias, + Condition: condition, + } + }) +} + +func (u UsingChain) OnEQ(left, right bob.Expression) bob.Mod[*dialect.MergeQuery] { + return u.On(expr.X[dialect.Expression, dialect.Expression](left).EQ(right)) +} + +// WhenMatched creates a WHEN MATCHED clause +func WhenMatched(mods ...bob.Mod[*WhenClause]) bob.Mod[*dialect.MergeQuery] { + wc := &WhenClause{Type: dialect.MergeWhenMatched} + for _, mod := range mods { + mod.Apply(wc) + } + return bob.ModFunc[*dialect.MergeQuery](func(m *dialect.MergeQuery) { + m.When = append(m.When, dialect.MergeWhen{ + Type: wc.Type, + Condition: wc.Condition, + Action: wc.Action, + }) + }) +} + +// WhenNotMatched creates a WHEN NOT MATCHED (BY TARGET) clause +func WhenNotMatched(mods ...bob.Mod[*WhenClause]) bob.Mod[*dialect.MergeQuery] { + wc := &WhenClause{Type: dialect.MergeWhenNotMatched} + for _, mod := range mods { + mod.Apply(wc) + } + return bob.ModFunc[*dialect.MergeQuery](func(m *dialect.MergeQuery) { + m.When = append(m.When, dialect.MergeWhen{ + Type: wc.Type, + Condition: wc.Condition, + Action: wc.Action, + }) + }) +} + +// WhenNotMatchedByTarget is an alias for WhenNotMatched with explicit BY TARGET +func WhenNotMatchedByTarget(mods ...bob.Mod[*WhenClause]) bob.Mod[*dialect.MergeQuery] { + wc := &WhenClause{Type: dialect.MergeWhenNotMatchedByTarget} + for _, mod := range mods { + mod.Apply(wc) + } + return bob.ModFunc[*dialect.MergeQuery](func(m *dialect.MergeQuery) { + m.When = append(m.When, dialect.MergeWhen{ + Type: wc.Type, + Condition: wc.Condition, + Action: wc.Action, + }) + }) +} + +// WhenNotMatchedBySource creates a WHEN NOT MATCHED BY SOURCE clause +func WhenNotMatchedBySource(mods ...bob.Mod[*WhenClause]) bob.Mod[*dialect.MergeQuery] { + wc := &WhenClause{Type: dialect.MergeWhenNotMatchedBySource} + for _, mod := range mods { + mod.Apply(wc) + } + return bob.ModFunc[*dialect.MergeQuery](func(m *dialect.MergeQuery) { + m.When = append(m.When, dialect.MergeWhen{ + Type: wc.Type, + Condition: wc.Condition, + Action: wc.Action, + }) + }) +} + +// WhenClause is a builder for WHEN clauses +type WhenClause struct { + Type dialect.MergeWhenType + Condition bob.Expression + Action dialect.MergeAction +} + +// And adds a condition to the WHEN clause +func And(condition bob.Expression) bob.Mod[*WhenClause] { + return bob.ModFunc[*WhenClause](func(w *WhenClause) { + if w.Condition == nil { + w.Condition = condition + } else { + w.Condition = expr.X[dialect.Expression, dialect.Expression](w.Condition).And(condition) + } + }) +} + +// ThenDoNothing sets the action to DO NOTHING +func ThenDoNothing() bob.Mod[*WhenClause] { + return bob.ModFunc[*WhenClause](func(w *WhenClause) { + w.Action = dialect.MergeAction{Type: dialect.MergeActionDoNothing} + }) +} + +// ThenDelete sets the action to DELETE +func ThenDelete() bob.Mod[*WhenClause] { + return bob.ModFunc[*WhenClause](func(w *WhenClause) { + w.Action = dialect.MergeAction{Type: dialect.MergeActionDelete} + }) +} + +// ThenUpdate sets the action to UPDATE with SET clauses +// Supports MERGE UPDATE syntax: +// - column = expression +// - column = DEFAULT +// - (columns...) = ROW (expressions...) +// - (columns...) = (subquery) +func ThenUpdate(sets ...bob.Mod[*UpdateAction]) bob.Mod[*WhenClause] { + ua := &UpdateAction{} + for _, s := range sets { + s.Apply(ua) + } + return bob.ModFunc[*WhenClause](func(w *WhenClause) { + w.Action = dialect.MergeAction{ + Type: dialect.MergeActionUpdate, + Set: ua.Set, + } + }) +} + +// UpdateAction is a builder for UPDATE action in MERGE +type UpdateAction struct { + Set []any +} + +// Set adds raw SET expressions to the UPDATE action +func Set(sets ...bob.Expression) bob.Mod[*UpdateAction] { + return bob.ModFunc[*UpdateAction](func(u *UpdateAction) { + u.Set = append(u.Set, internal.ToAnySlice(sets)...) + }) +} + +// SetCol creates a single column setter: column = expression | DEFAULT +func SetCol(column string) SetChain { + return SetChain{column: column} +} + +// SetChain is a chain for building SET column = value +type SetChain struct { + column string +} + +// To sets column to a raw value: column = value +func (s SetChain) To(value any) bob.Mod[*UpdateAction] { + return bob.ModFunc[*UpdateAction](func(u *UpdateAction) { + u.Set = append(u.Set, expr.OP("=", expr.Quote(s.column), value)) + }) +} + +// ToArg sets column to a parameterized value: column = $N +func (s SetChain) ToArg(value any) bob.Mod[*UpdateAction] { + return bob.ModFunc[*UpdateAction](func(u *UpdateAction) { + u.Set = append(u.Set, expr.OP("=", expr.Quote(s.column), expr.Arg(value))) + }) +} + +// ToExpr sets column to an expression: column = expression +// Use psql.Quote("source_alias", "column") to reference source columns +// Use psql.Quote("target_alias", "column") to reference target columns +func (s SetChain) ToExpr(e bob.Expression) bob.Mod[*UpdateAction] { + return bob.ModFunc[*UpdateAction](func(u *UpdateAction) { + u.Set = append(u.Set, expr.OP("=", expr.Quote(s.column), e)) + }) +} + +// ToDefault sets column to DEFAULT: column = DEFAULT +func (s SetChain) ToDefault() bob.Mod[*UpdateAction] { + return bob.ModFunc[*UpdateAction](func(u *UpdateAction) { + u.Set = append(u.Set, expr.OP("=", expr.Quote(s.column), expr.Raw("DEFAULT"))) + }) +} + +// SetCols creates a multi-column setter: (columns...) = ROW(...) | (subquery) +func SetCols(columns ...string) SetColsChain { + return SetColsChain{columns: columns} +} + +// SetColsChain is a chain for building SET (columns...) = ROW(...) | (subquery) +type SetColsChain struct { + columns []string +} + +// ToRow sets columns to ROW of expressions: (columns...) = ROW (expressions...) +func (s SetColsChain) ToRow(values ...bob.Expression) bob.Mod[*UpdateAction] { + return bob.ModFunc[*UpdateAction](func(u *UpdateAction) { + // Build (col1, col2, ...) = ROW (val1, val2, ...) + cols := make([]bob.Expression, len(s.columns)) + for i, c := range s.columns { + cols[i] = expr.Quote(c) + } + u.Set = append(u.Set, rowAssignment{cols: cols, values: values, isRow: true}) + }) +} + +// ToExprs sets columns to expressions without ROW: (columns...) = (expressions...) +func (s SetColsChain) ToExprs(values ...bob.Expression) bob.Mod[*UpdateAction] { + return bob.ModFunc[*UpdateAction](func(u *UpdateAction) { + cols := make([]bob.Expression, len(s.columns)) + for i, c := range s.columns { + cols[i] = expr.Quote(c) + } + u.Set = append(u.Set, rowAssignment{cols: cols, values: values, isRow: false}) + }) +} + +// ToQuery sets columns from a subquery: (columns...) = (subquery) +func (s SetColsChain) ToQuery(q bob.Query) bob.Mod[*UpdateAction] { + return bob.ModFunc[*UpdateAction](func(u *UpdateAction) { + cols := make([]bob.Expression, len(s.columns)) + for i, c := range s.columns { + cols[i] = expr.Quote(c) + } + u.Set = append(u.Set, queryAssignment{cols: cols, query: q}) + }) +} + +// ThenInsert sets the action to INSERT +// Use with Columns(), Values(), OverridingSystem(), OverridingUser() modifiers +// If no Values() is specified, DEFAULT VALUES will be used +func ThenInsert(mods ...bob.Mod[*InsertAction]) bob.Mod[*WhenClause] { + ia := &InsertAction{} + for _, mod := range mods { + mod.Apply(ia) + } + return bob.ModFunc[*WhenClause](func(w *WhenClause) { + w.Action = dialect.MergeAction{ + Type: dialect.MergeActionInsert, + Columns: ia.Columns, + Values: ia.Values, + Overriding: ia.Overriding, + } + }) +} + +// ThenInsertDefaultValues sets the action to INSERT DEFAULT VALUES (shortcut) +func ThenInsertDefaultValues() bob.Mod[*WhenClause] { + return bob.ModFunc[*WhenClause](func(w *WhenClause) { + w.Action = dialect.MergeAction{ + Type: dialect.MergeActionInsert, + // Empty Values signals DEFAULT VALUES + } + }) +} + +// InsertAction is a builder for INSERT action in MERGE +// Supports: INSERT [(columns...)] [OVERRIDING {SYSTEM|USER} VALUE] {VALUES (...) | DEFAULT VALUES} +type InsertAction struct { + Columns []string + Values []bob.Expression + Overriding dialect.MergeOverridingType +} + +// Columns specifies the target columns for INSERT action +// Column names can include subfield names or array subscripts if needed +func Columns(columns ...string) bob.Mod[*InsertAction] { + return bob.ModFunc[*InsertAction](func(i *InsertAction) { + i.Columns = append(i.Columns, columns...) + }) +} + +// Values specifies the values for INSERT action +// Expressions can reference source data columns (for WHEN NOT MATCHED BY TARGET) +// Use psql.Quote("source_alias", "column") to reference source columns +// Use psql.Arg(value) for literal values +// Use expr.Raw("DEFAULT") for DEFAULT keyword +func Values(values ...bob.Expression) bob.Mod[*InsertAction] { + return bob.ModFunc[*InsertAction](func(i *InsertAction) { + i.Values = append(i.Values, values...) + }) +} + +// OverridingSystem adds OVERRIDING SYSTEM VALUE for INSERT action +// Use when inserting into identity columns defined as GENERATED ALWAYS +func OverridingSystem() bob.Mod[*InsertAction] { + return bob.ModFunc[*InsertAction](func(i *InsertAction) { + i.Overriding = dialect.MergeOverridingSystem + }) +} + +// OverridingUser adds OVERRIDING USER VALUE for INSERT action +// Use when identity columns defined as GENERATED BY DEFAULT should use sequence values +func OverridingUser() bob.Mod[*InsertAction] { + return bob.ModFunc[*InsertAction](func(i *InsertAction) { + i.Overriding = dialect.MergeOverridingUser + }) +} + +// Returning adds a RETURNING clause +func Returning(clauses ...any) bob.Mod[*dialect.MergeQuery] { + return mods.Returning[*dialect.MergeQuery](clauses) +} diff --git a/dialect/psql/table.go b/dialect/psql/table.go index 607cbe81..63cc5f03 100644 --- a/dialect/psql/table.go +++ b/dialect/psql/table.go @@ -8,6 +8,7 @@ import ( "github.com/stephenafamo/bob/dialect/psql/dialect" "github.com/stephenafamo/bob/dialect/psql/dm" "github.com/stephenafamo/bob/dialect/psql/im" + "github.com/stephenafamo/bob/dialect/psql/mm" "github.com/stephenafamo/bob/dialect/psql/um" "github.com/stephenafamo/bob/expr" "github.com/stephenafamo/bob/internal" @@ -20,6 +21,7 @@ type ( ormInsertQuery[T any, Tslice ~[]T] = orm.Query[*dialect.InsertQuery, T, Tslice, bob.SliceTransformer[T, Tslice]] ormUpdateQuery[T any, Tslice ~[]T] = orm.Query[*dialect.UpdateQuery, T, Tslice, bob.SliceTransformer[T, Tslice]] ormDeleteQuery[T any, Tslice ~[]T] = orm.Query[*dialect.DeleteQuery, T, Tslice, bob.SliceTransformer[T, Tslice]] + ormMergeQuery[T any, Tslice ~[]T] = orm.Query[*dialect.MergeQuery, T, Tslice, bob.SliceTransformer[T, Tslice]] ) func NewTable[T any, Tset setter[T], C bob.Expression](schema, tableName string, columns C) *Table[T, []T, Tset, C] { @@ -56,9 +58,13 @@ type Table[T any, Tslice ~[]T, Tset setter[T], C bob.Expression] struct { BeforeDeleteHooks bob.Hooks[Tslice, bob.SkipModelHooksKey] AfterDeleteHooks bob.Hooks[Tslice, bob.SkipModelHooksKey] + BeforeMergeHooks bob.Hooks[Tslice, bob.SkipModelHooksKey] + AfterMergeHooks bob.Hooks[Tslice, bob.SkipModelHooksKey] + InsertQueryHooks bob.Hooks[*dialect.InsertQuery, bob.SkipQueryHooksKey] UpdateQueryHooks bob.Hooks[*dialect.UpdateQuery, bob.SkipQueryHooksKey] DeleteQueryHooks bob.Hooks[*dialect.DeleteQuery, bob.SkipQueryHooksKey] + MergeQueryHooks bob.Hooks[*dialect.MergeQuery, bob.SkipQueryHooksKey] } // Returns the primary key columns for this table. @@ -137,3 +143,32 @@ func (t *Table[T, Tslice, Tset, C]) Delete(queryMods ...bob.Mod[*dialect.DeleteQ return q } + +// Starts a Merge query for this table +// The caller must provide USING and WHEN clauses via queryMods +// RETURNING clause is automatically added if version >= 17 is set in context. +// Use psql.SetVersion(ctx, 17) to enable automatic RETURNING for MERGE. +// For older versions, use mm.Returning() explicitly if needed. +func (t *Table[T, Tslice, Tset, C]) Merge(queryMods ...bob.Mod[*dialect.MergeQuery]) *ormMergeQuery[T, Tslice] { + q := &ormMergeQuery[T, Tslice]{ + ExecQuery: orm.ExecQuery[*dialect.MergeQuery]{ + BaseQuery: Merge(mm.Into(t.NameAs())), + Hooks: &t.MergeQueryHooks, + }, + Scanner: t.scanner, + } + + q.Expression.AppendContextualModFunc( + func(ctx context.Context, q *dialect.MergeQuery) (context.Context, error) { + // RETURNING in MERGE requires version 17+ + if VersionAtLeast(ctx, 17) && !q.HasReturning() { + q.AppendReturning(t.Columns) + } + return ctx, nil + }, + ) + + q.Apply(queryMods...) + + return q +} diff --git a/dialect/psql/table_test.go b/dialect/psql/table_test.go index f38a184d..a078ddce 100644 --- a/dialect/psql/table_test.go +++ b/dialect/psql/table_test.go @@ -7,12 +7,14 @@ import ( "log" "os" "os/signal" + "strings" "syscall" "testing" _ "github.com/lib/pq" "github.com/stephenafamo/bob" "github.com/stephenafamo/bob/dialect/psql/dialect" + "github.com/stephenafamo/bob/dialect/psql/mm" "github.com/stephenafamo/bob/dialect/psql/um" "github.com/stephenafamo/bob/expr" "github.com/stephenafamo/bob/internal" @@ -179,3 +181,211 @@ func TestUpdate(t *testing.T) { t.Fatalf("unexpected retrieved user: %#v: %v", *user, err) } } + +func TestMerge(t *testing.T) { + ctx := t.Context() + + tx, err := testDB.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("could not begin transaction: %v", err) + return + } + defer func() { _ = tx.Rollback(ctx) }() + + // Create users table + _, err = tx.ExecContext(ctx, `CREATE TABLE users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE NOT NULL + )`) + if err != nil { + t.Fatalf("could not create users table: %v", err) + } + + // Create source table for merge + _, err = tx.ExecContext(ctx, `CREATE TABLE user_updates ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT NOT NULL + )`) + if err != nil { + t.Fatalf("could not create user_updates table: %v", err) + } + + // Insert initial user + _, err = tx.ExecContext(ctx, `INSERT INTO users (id, name, email) VALUES (1, 'Alice', 'alice@example.com')`) + if err != nil { + t.Fatalf("could not insert user: %v", err) + } + + // Insert updates (one existing, one new) + _, err = tx.ExecContext(ctx, `INSERT INTO user_updates (id, name, email) VALUES + (1, 'Alice Smith', 'alice.smith@example.com'), + (2, 'Bob', 'bob@example.com')`) + if err != nil { + t.Fatalf("could not insert user_updates: %v", err) + } + + // Execute MERGE using table's Merge method + mergeQuery := userTable.Merge( + mm.Using("user_updates").As("u").On( + Quote("u", "id").EQ(Quote("users", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("name").ToExpr(Quote("u", "name")), + mm.SetCol("email").ToExpr(Quote("u", "email")), + ), + ), + mm.WhenNotMatched( + mm.ThenInsert( + mm.Columns("id", "name", "email"), + mm.Values(Quote("u", "id"), Quote("u", "name"), Quote("u", "email")), + ), + ), + ) + + // Get the SQL for debugging + sql, args, err := bob.Build(ctx, mergeQuery) + if err != nil { + t.Fatalf("could not build merge query: %v", err) + } + t.Logf("MERGE SQL: %s", sql) + t.Logf("MERGE Args: %v", args) + + // Execute the merge + _, err = mergeQuery.Exec(ctx, tx) + if err != nil { + t.Fatalf("could not execute merge: %v", err) + } + + // Verify user 1 was updated + q := "SELECT * FROM users WHERE id = $1" + user, err := scan.One(ctx, tx, scan.StructMapper[*User](), q, 1) + if err != nil { + t.Fatalf("could not get user 1: %v", err) + } + + if *user != (User{ + ID: 1, + Name: "Alice Smith", + Email: "alice.smith@example.com", + }) { + t.Errorf("unexpected user 1 after merge: %#v", *user) + } + + // Verify user 2 was inserted + user, err = scan.One(ctx, tx, scan.StructMapper[*User](), q, 2) + if err != nil { + t.Fatalf("could not get user 2: %v", err) + } + + if *user != (User{ + ID: 2, + Name: "Bob", + Email: "bob@example.com", + }) { + t.Errorf("unexpected user 2 after merge: %#v", *user) + } + + // Verify total count + var count int + err = tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM users").Scan(&count) + if err != nil { + t.Fatalf("could not count users: %v", err) + } + if count != 2 { + t.Errorf("expected 2 users, got %d", count) + } +} + +func TestTableMergeWithVersion(t *testing.T) { + // Use the existing userTable from the test file + + t.Run("version 17+ adds RETURNING automatically", func(t *testing.T) { + ctx := context.Background() + ctx = SetVersion(ctx, 17) + + mergeQuery := userTable.Merge( + mm.Using("source").As("s").On( + Quote("s", "id").EQ(Quote("users", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("name").ToExpr(Quote("s", "name")), + ), + ), + ) + + sql, args, err := bob.Build(ctx, mergeQuery) + if err != nil { + t.Fatalf("error: %v", err) + } + + // Should contain RETURNING because version is 17+ + if !strings.Contains(sql, "RETURNING") { + t.Errorf("expected RETURNING clause for version 17+, got: %s", sql) + } + if len(args) != 0 { + t.Errorf("expected no args, got %v", args) + } + }) + + t.Run("version below 17 does not add RETURNING automatically", func(t *testing.T) { + ctx := context.Background() + ctx = SetVersion(ctx, 16) + + mergeQuery := userTable.Merge( + mm.Using("source").As("s").On( + Quote("s", "id").EQ(Quote("users", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("name").ToExpr(Quote("s", "name")), + ), + ), + ) + + sql, args, err := bob.Build(ctx, mergeQuery) + if err != nil { + t.Fatalf("error: %v", err) + } + + // Should NOT contain RETURNING because version is below 17 + if strings.Contains(sql, "RETURNING") { + t.Errorf("expected no RETURNING clause for version 16, got: %s", sql) + } + if len(args) != 0 { + t.Errorf("expected no args, got %v", args) + } + }) + + t.Run("no version set does not add RETURNING automatically", func(t *testing.T) { + ctx := context.Background() + // No version set + + mergeQuery := userTable.Merge( + mm.Using("source").As("s").On( + Quote("s", "id").EQ(Quote("users", "id")), + ), + mm.WhenMatched( + mm.ThenUpdate( + mm.SetCol("name").ToExpr(Quote("s", "name")), + ), + ), + ) + + sql, args, err := bob.Build(ctx, mergeQuery) + if err != nil { + t.Fatalf("error: %v", err) + } + + // Should NOT contain RETURNING because no version set + if strings.Contains(sql, "RETURNING") { + t.Errorf("expected no RETURNING clause when version not set, got: %s", sql) + } + if len(args) != 0 { + t.Errorf("expected no args, got %v", args) + } + }) +} diff --git a/dialect/psql/version.go b/dialect/psql/version.go new file mode 100644 index 00000000..3bbf2feb --- /dev/null +++ b/dialect/psql/version.go @@ -0,0 +1,32 @@ +package psql + +import "context" + +// VersionKey is a context key for storing the database version. +// Use SetVersion to set the version in context. +type VersionKey struct{} + +// SetVersion sets the major version (e.g., 15, 16, 17) in the context. +// This is used to enable version-specific features like MERGE with RETURNING (version 17+). +// +// Example: +// +// ctx := psql.SetVersion(ctx, 17) +func SetVersion(ctx context.Context, version int) context.Context { + return context.WithValue(ctx, VersionKey{}, version) +} + +// GetVersion returns the major version from the context. +// Returns 0 if the version is not set. +func GetVersion(ctx context.Context) int { + if v, ok := ctx.Value(VersionKey{}).(int); ok { + return v + } + return 0 +} + +// VersionAtLeast checks if the version in context is at least the given version. +// Returns false if version is not set in context. +func VersionAtLeast(ctx context.Context, minVersion int) bool { + return GetVersion(ctx) >= minVersion +} diff --git a/gen/templates/models/table/005_one_methods.go.tpl b/gen/templates/models/table/005_one_methods.go.tpl index a975ccc2..2da24e4c 100644 --- a/gen/templates/models/table/005_one_methods.go.tpl +++ b/gen/templates/models/table/005_one_methods.go.tpl @@ -18,6 +18,8 @@ func (o *{{$tAlias.UpSingular}}) AfterQueryHook(ctx context.Context, exec bob.Ex ctx, err = {{$tAlias.UpPlural}}.AfterUpdateHooks.RunHooks(ctx, exec, {{$tAlias.UpSingular}}Slice{o}) case bob.QueryTypeDelete: ctx, err = {{$tAlias.UpPlural}}.AfterDeleteHooks.RunHooks(ctx, exec, {{$tAlias.UpSingular}}Slice{o}) + case bob.QueryTypeMerge: + ctx, err = {{$tAlias.UpPlural}}.AfterMergeHooks.RunHooks(ctx, exec, {{$tAlias.UpSingular}}Slice{o}) {{- end}} } diff --git a/gen/templates/models/table/007_slice_methods.go.tpl b/gen/templates/models/table/007_slice_methods.go.tpl index 48cf527d..f94f9bd1 100644 --- a/gen/templates/models/table/007_slice_methods.go.tpl +++ b/gen/templates/models/table/007_slice_methods.go.tpl @@ -18,6 +18,8 @@ func (o {{$tAlias.UpSingular}}Slice) AfterQueryHook(ctx context.Context, exec bo ctx, err = {{$tAlias.UpPlural}}.AfterUpdateHooks.RunHooks(ctx, exec, o) case bob.QueryTypeDelete: ctx, err = {{$tAlias.UpPlural}}.AfterDeleteHooks.RunHooks(ctx, exec, o) + case bob.QueryTypeMerge: + ctx, err = {{$tAlias.UpPlural}}.AfterMergeHooks.RunHooks(ctx, exec, o) {{- end}} } @@ -130,6 +132,33 @@ func (o {{$tAlias.UpSingular}}Slice) DeleteMod() bob.Mod[*dialect.DeleteQuery] { }) } +// MergeMod modifies a merge query to run BeforeMergeHooks and AfterMergeHooks +// and updates the slice with the returned rows. +func (o {{$tAlias.UpSingular}}Slice) MergeMod() bob.Mod[*dialect.MergeQuery] { + return bob.ModFunc[*dialect.MergeQuery](func(q *dialect.MergeQuery) { + q.AppendHooks(func(ctx context.Context, exec bob.Executor) (context.Context, error) { + return {{$tAlias.UpPlural}}.BeforeMergeHooks.RunHooks(ctx, exec, o) + }) + + q.AppendLoader(bob.LoaderFunc(func(ctx context.Context, exec bob.Executor, retrieved any) error { + var err error + switch retrieved := retrieved.(type) { + case *{{$tAlias.UpSingular}}: + o.copyMatchingRows(retrieved) + case []*{{$tAlias.UpSingular}}: + o.copyMatchingRows(retrieved...) + case {{$tAlias.UpSingular}}Slice: + o.copyMatchingRows(retrieved...) + default: + // If the retrieved value is not a {{$tAlias.UpSingular}} or a slice of {{$tAlias.UpSingular}} + // then run the AfterMergeHooks on the slice + _, err = {{$tAlias.UpPlural}}.AfterMergeHooks.RunHooks(ctx, exec, o) + } + + return err + })) + }) +} {{block "slice_update" . -}} {{$table := .Table}} diff --git a/query.go b/query.go index 424b909e..eb227f8c 100644 --- a/query.go +++ b/query.go @@ -25,6 +25,7 @@ const ( QueryTypeUpdate QueryTypeDelete QueryTypeValues + QueryTypeMerge ) func (q QueryType) String() string { @@ -39,6 +40,8 @@ func (q QueryType) String() string { return "DELETE" case QueryTypeValues: return "VALUES" + case QueryTypeMerge: + return "MERGE" default: return "UNKNOWN" } diff --git a/website/docs/models/hooks.md b/website/docs/models/hooks.md index 7f98b1b2..8e11187d 100644 --- a/website/docs/models/hooks.md +++ b/website/docs/models/hooks.md @@ -23,6 +23,8 @@ In **addition**, TableModels have: * `AfterUpdateHooks` * `BeforeDeleteHooks` * `AfterDeleteHooks` +* `BeforeMergeHooks` +* `AfterMergeHooks` These hooks run at the point one would expect from their naming.