Skip to content

Commit b7ab3f1

Browse files
Fix!: exp.Merge condition for Trino/Postgres (#4596)
* Fix!: exp.Merge condition for Trino/Postgres * address PR review comment * Fixups --------- Co-authored-by: Jo <[email protected]>
1 parent 199508a commit b7ab3f1

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

sqlglot/dialects/dialect.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,19 +1570,25 @@ def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
15701570
targets.add(normalize(alias.this))
15711571

15721572
for when in expression.args["whens"].expressions:
1573-
# only remove the target names from the THEN clause
1574-
# theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
1575-
# ref: https://github.com/TobikoData/sqlmesh/issues/2934
1576-
then = when.args.get("then")
1573+
# only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED
1574+
# they are still valid in the <condition>, the right hand side of each UPDATE and the VALUES part
1575+
# (not the column list) of the INSERT
1576+
then: exp.Insert | exp.Update | None = when.args.get("then")
15771577
if then:
1578-
then.transform(
1579-
lambda node: (
1580-
exp.column(node.this)
1581-
if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
1582-
else node
1583-
),
1584-
copy=False,
1585-
)
1578+
if isinstance(then, exp.Update):
1579+
for equals in then.find_all(exp.EQ):
1580+
equal_lhs = equals.this
1581+
if (
1582+
isinstance(equal_lhs, exp.Column)
1583+
and normalize(equal_lhs.args.get("table")) in targets
1584+
):
1585+
equal_lhs.replace(exp.column(equal_lhs.this))
1586+
if isinstance(then, exp.Insert):
1587+
column_list = then.this
1588+
if isinstance(column_list, exp.Tuple):
1589+
for column in column_list.expressions:
1590+
if normalize(column.args.get("table")) in targets:
1591+
column.replace(exp.column(column.this))
15861592

15871593
return self.merge_sql(expression)
15881594

tests/dialects/test_dialect.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,6 +2336,17 @@ def test_merge(self):
23362336
},
23372337
)
23382338

2339+
# needs to preserve the target alias in then WHEN condition and function but not in the THEN clause
2340+
self.validate_all(
2341+
"""MERGE INTO foo AS target USING (SELECT a, b FROM tbl) AS src ON src.a = target.a
2342+
WHEN MATCHED THEN UPDATE SET target.b = COALESCE(src.b, target.b)
2343+
WHEN NOT MATCHED THEN INSERT (target.a, target.b) VALUES (src.a, src.b)""",
2344+
write={
2345+
"trino": """MERGE INTO foo AS target USING (SELECT a, b FROM tbl) AS src ON src.a = target.a WHEN MATCHED THEN UPDATE SET b = COALESCE(src.b, target.b) WHEN NOT MATCHED THEN INSERT (a, b) VALUES (src.a, src.b)""",
2346+
"postgres": """MERGE INTO foo AS target USING (SELECT a, b FROM tbl) AS src ON src.a = target.a WHEN MATCHED THEN UPDATE SET b = COALESCE(src.b, target.b) WHEN NOT MATCHED THEN INSERT (a, b) VALUES (src.a, src.b)""",
2347+
},
2348+
)
2349+
23392350
def test_substring(self):
23402351
self.validate_all(
23412352
"SUBSTR('123456', 2, 3)",

0 commit comments

Comments
 (0)