Skip to content

Commit 46eb21a

Browse files
author
Roman A. Grigorovich
committed
feat(psql version): add version management for MERGE statement with automatic RETURNING support
1 parent 7228554 commit 46eb21a

File tree

3 files changed

+193
-2
lines changed

3 files changed

+193
-2
lines changed

dialect/psql/merge_test.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package psql_test
22

33
import (
4+
"context"
45
"testing"
56

7+
"github.com/stephenafamo/bob"
68
"github.com/stephenafamo/bob/dialect/psql"
79
"github.com/stephenafamo/bob/dialect/psql/mm"
810
"github.com/stephenafamo/bob/dialect/psql/sm"
@@ -505,3 +507,149 @@ func TestMerge(t *testing.T) {
505507

506508
testutils.RunTests(t, examples, formatter)
507509
}
510+
511+
func TestMergeWithVersion(t *testing.T) {
512+
t.Run("version 17+ adds RETURNING automatically with mm.Returning", func(t *testing.T) {
513+
ctx := context.Background()
514+
ctx = psql.SetVersion(ctx, 17)
515+
516+
q := psql.Merge(
517+
mm.Into("target"),
518+
mm.Using("source").As("s").On(
519+
psql.Quote("s", "id").EQ(psql.Quote("target", "id")),
520+
),
521+
mm.WhenMatched(
522+
mm.ThenUpdate(
523+
mm.SetCol("name").ToExpr(psql.Quote("s", "name")),
524+
),
525+
),
526+
mm.Returning("*"),
527+
)
528+
529+
sql, args, err := bob.Build(ctx, q)
530+
if err != nil {
531+
t.Fatalf("error: %v", err)
532+
}
533+
534+
expectedSQL := `MERGE INTO target USING source AS "s" ON "s"."id" = "target"."id" WHEN MATCHED THEN UPDATE SET "name" = "s"."name" RETURNING *`
535+
diff, err := testutils.QueryDiff(expectedSQL, sql, formatter)
536+
if err != nil {
537+
t.Fatalf("error: %v", err)
538+
}
539+
if diff != "" {
540+
t.Errorf("SQL mismatch:\n%s\nGot: %s", diff, sql)
541+
}
542+
if len(args) != 0 {
543+
t.Errorf("expected no args, got %v", args)
544+
}
545+
})
546+
547+
t.Run("version 16 does not affect MERGE with explicit RETURNING", func(t *testing.T) {
548+
ctx := context.Background()
549+
ctx = psql.SetVersion(ctx, 16)
550+
551+
q := psql.Merge(
552+
mm.Into("target"),
553+
mm.Using("source").As("s").On(
554+
psql.Quote("s", "id").EQ(psql.Quote("target", "id")),
555+
),
556+
mm.WhenMatched(
557+
mm.ThenUpdate(
558+
mm.SetCol("name").ToExpr(psql.Quote("s", "name")),
559+
),
560+
),
561+
mm.Returning("*"),
562+
)
563+
564+
sql, args, err := bob.Build(ctx, q)
565+
if err != nil {
566+
t.Fatalf("error: %v", err)
567+
}
568+
569+
// RETURNING should still be present because it was explicitly added
570+
expectedSQL := `MERGE INTO target USING source AS "s" ON "s"."id" = "target"."id" WHEN MATCHED THEN UPDATE SET "name" = "s"."name" RETURNING *`
571+
diff, err := testutils.QueryDiff(expectedSQL, sql, formatter)
572+
if err != nil {
573+
t.Fatalf("error: %v", err)
574+
}
575+
if diff != "" {
576+
t.Errorf("SQL mismatch:\n%s\nGot: %s", diff, sql)
577+
}
578+
if len(args) != 0 {
579+
t.Errorf("expected no args, got %v", args)
580+
}
581+
})
582+
583+
t.Run("no version set - MERGE without RETURNING", func(t *testing.T) {
584+
ctx := context.Background()
585+
// No version set
586+
587+
q := psql.Merge(
588+
mm.Into("target"),
589+
mm.Using("source").As("s").On(
590+
psql.Quote("s", "id").EQ(psql.Quote("target", "id")),
591+
),
592+
mm.WhenMatched(
593+
mm.ThenUpdate(
594+
mm.SetCol("name").ToExpr(psql.Quote("s", "name")),
595+
),
596+
),
597+
)
598+
599+
sql, args, err := bob.Build(ctx, q)
600+
if err != nil {
601+
t.Fatalf("error: %v", err)
602+
}
603+
604+
// No RETURNING because no version set
605+
expectedSQL := `MERGE INTO target USING source AS "s" ON "s"."id" = "target"."id" WHEN MATCHED THEN UPDATE SET "name" = "s"."name"`
606+
diff, err := testutils.QueryDiff(expectedSQL, sql, formatter)
607+
if err != nil {
608+
t.Fatalf("error: %v", err)
609+
}
610+
if diff != "" {
611+
t.Errorf("SQL mismatch:\n%s\nGot: %s", diff, sql)
612+
}
613+
if len(args) != 0 {
614+
t.Errorf("expected no args, got %v", args)
615+
}
616+
})
617+
618+
t.Run("version 17+ with WhenNotMatchedBySource (PG17 feature)", func(t *testing.T) {
619+
ctx := context.Background()
620+
ctx = psql.SetVersion(ctx, 17)
621+
622+
q := psql.Merge(
623+
mm.Into("target"),
624+
mm.Using("source").As("s").On(
625+
psql.Quote("s", "id").EQ(psql.Quote("target", "id")),
626+
),
627+
mm.WhenMatched(
628+
mm.ThenUpdate(
629+
mm.SetCol("name").ToExpr(psql.Quote("s", "name")),
630+
),
631+
),
632+
mm.WhenNotMatchedBySource(
633+
mm.ThenDelete(),
634+
),
635+
mm.Returning(psql.Quote("target", "id")),
636+
)
637+
638+
sql, args, err := bob.Build(ctx, q)
639+
if err != nil {
640+
t.Fatalf("error: %v", err)
641+
}
642+
643+
expectedSQL := `MERGE INTO target USING source AS "s" ON "s"."id" = "target"."id" WHEN MATCHED THEN UPDATE SET "name" = "s"."name" WHEN NOT MATCHED BY SOURCE THEN DELETE RETURNING "target"."id"`
644+
diff, err := testutils.QueryDiff(expectedSQL, sql, formatter)
645+
if err != nil {
646+
t.Fatalf("error: %v", err)
647+
}
648+
if diff != "" {
649+
t.Errorf("SQL mismatch:\n%s\nGot: %s", diff, sql)
650+
}
651+
if len(args) != 0 {
652+
t.Errorf("expected no args, got %v", args)
653+
}
654+
})
655+
}

dialect/psql/table.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,9 @@ func (t *Table[T, Tslice, Tset, C]) Delete(queryMods ...bob.Mod[*dialect.DeleteQ
146146

147147
// Starts a Merge query for this table
148148
// The caller must provide USING and WHEN clauses via queryMods
149-
// Note: RETURNING clause is NOT added automatically because it requires PostgreSQL 17+
150-
// Use mm.Returning() explicitly if needed and you are on PostgreSQL 17+
149+
// RETURNING clause is automatically added if version >= 17 is set in context.
150+
// Use psql.SetVersion(ctx, 17) to enable automatic RETURNING for MERGE.
151+
// For older versions, use mm.Returning() explicitly if needed.
151152
func (t *Table[T, Tslice, Tset, C]) Merge(queryMods ...bob.Mod[*dialect.MergeQuery]) *ormMergeQuery[T, Tslice] {
152153
q := &ormMergeQuery[T, Tslice]{
153154
ExecQuery: orm.ExecQuery[*dialect.MergeQuery]{
@@ -157,6 +158,16 @@ func (t *Table[T, Tslice, Tset, C]) Merge(queryMods ...bob.Mod[*dialect.MergeQue
157158
Scanner: t.scanner,
158159
}
159160

161+
q.Expression.AppendContextualModFunc(
162+
func(ctx context.Context, q *dialect.MergeQuery) (context.Context, error) {
163+
// RETURNING in MERGE requires version 17+
164+
if VersionAtLeast(ctx, 17) && !q.HasReturning() {
165+
q.AppendReturning(t.Columns)
166+
}
167+
return ctx, nil
168+
},
169+
)
170+
160171
q.Apply(queryMods...)
161172

162173
return q

dialect/psql/version.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package psql
2+
3+
import "context"
4+
5+
// VersionKey is a context key for storing the database version.
6+
// Use SetVersion to set the version in context.
7+
type VersionKey struct{}
8+
9+
// SetVersion sets the major version (e.g., 15, 16, 17) in the context.
10+
// This is used to enable version-specific features like MERGE with RETURNING (version 17+).
11+
//
12+
// Example:
13+
//
14+
// ctx := psql.SetVersion(ctx, 17)
15+
func SetVersion(ctx context.Context, version int) context.Context {
16+
return context.WithValue(ctx, VersionKey{}, version)
17+
}
18+
19+
// GetVersion returns the major version from the context.
20+
// Returns 0 if the version is not set.
21+
func GetVersion(ctx context.Context) int {
22+
if v, ok := ctx.Value(VersionKey{}).(int); ok {
23+
return v
24+
}
25+
return 0
26+
}
27+
28+
// VersionAtLeast checks if the version in context is at least the given version.
29+
// Returns false if version is not set in context.
30+
func VersionAtLeast(ctx context.Context, minVersion int) bool {
31+
return GetVersion(ctx) >= minVersion
32+
}

0 commit comments

Comments
 (0)