From dfa4fef0b2680594bfa2c0aa9eaabefcf05bfaef Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Mon, 23 Dec 2024 12:31:42 +0200 Subject: [PATCH 1/2] feat(duckdb): Support simplified UNPIVOT statement --- sqlglot/expressions.py | 7 +++++++ sqlglot/generator.py | 14 ++++++++++--- sqlglot/parser.py | 39 ++++++++++++++++++++++++++++++----- tests/dialects/test_duckdb.py | 31 ++++++++++++++++++++++------ 4 files changed, 77 insertions(+), 14 deletions(-) diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index a9f6a091f6..88a15147bf 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4244,6 +4244,7 @@ class Pivot(Expression): "columns": False, "include_nulls": False, "default_on_null": False, + "into": False, } @property @@ -4251,6 +4252,12 @@ def unpivot(self) -> bool: return bool(self.args.get("unpivot")) +# https://duckdb.org/docs/sql/statements/unpivot#simplified-unpivot-syntax +# UNPIVOT ... INTO [NAME VALUE ][...,] +class UnpivotColumns(Expression): + arg_types = {"this": True, "expressions": True} + + class Window(Condition): arg_types = { "this": True, diff --git a/sqlglot/generator.py b/sqlglot/generator.py index d7e8b0e622..f30472fd0f 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -1988,6 +1988,7 @@ def tablesample_sql( def pivot_sql(self, expression: exp.Pivot) -> str: expressions = self.expressions(expression, flat=True) + direction = "UNPIVOT" if expression.unpivot else "PIVOT" if expression.this: this = self.sql(expression, "this") @@ -1995,14 +1996,15 @@ def pivot_sql(self, expression: exp.Pivot) -> str: return f"UNPIVOT {this}" on = f"{self.seg('ON')} {expressions}" + into = self.sql(expression, "into") + into = f"{self.seg('INTO')} {into}" if into else "" using = self.expressions(expression, key="using", flat=True) using = f"{self.seg('USING')} {using}" if using else "" group = self.sql(expression, "group") - return f"PIVOT {this}{on}{using}{group}" + return f"{direction} {this}{on}{into}{using}{group}" alias = self.sql(expression, "alias") alias = f" AS {alias}" if alias else "" - direction = self.seg("UNPIVOT" if expression.unpivot else "PIVOT") field = self.sql(expression, "field") @@ -2014,7 +2016,7 @@ def pivot_sql(self, expression: exp.Pivot) -> str: default_on_null = self.sql(expression, "default_on_null") default_on_null = f" DEFAULT ON NULL ({default_on_null})" if default_on_null else "" - return f"{direction}{nulls}({expressions} FOR {field}{default_on_null}){alias}" + return f"{self.seg(direction)}{nulls}({expressions} FOR {field}{default_on_null}){alias}" def version_sql(self, expression: exp.Version) -> str: this = f"FOR {expression.name}" @@ -4637,3 +4639,9 @@ def partitionbyrangepropertydynamic_sql( every.this.replace(exp.Literal.number(every.name)) return f"START {self.wrap(start)} END {self.wrap(end)} EVERY {self.wrap(self.sql(every))}" + + def unpivotcolumns_sql(self, expression: exp.UnpivotColumns) -> str: + name = self.sql(expression, "this") + values = self.expressions(expression, flat=True) + + return f"NAME {name} VALUE {values}" diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 4cbaf89a91..3f5dec8693 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -809,6 +809,7 @@ class Parser(metaclass=_Parser): TokenType.SET: lambda self: self._parse_set(), TokenType.TRUNCATE: lambda self: self._parse_truncate_table(), TokenType.UNCACHE: lambda self: self._parse_uncache(), + TokenType.UNPIVOT: lambda self: self._parse_simplified_pivot(is_unpivot=True), TokenType.UPDATE: lambda self: self._parse_update(), TokenType.USE: lambda self: self.expression( exp.Use, @@ -3056,8 +3057,10 @@ def _parse_select( this = self._parse_query_modifiers(this) elif (table or nested) and self._match(TokenType.L_PAREN): - if self._match(TokenType.PIVOT): - this = self._parse_simplified_pivot() + if self._match_set((TokenType.PIVOT, TokenType.UNPIVOT)): + this = self._parse_simplified_pivot( + is_unpivot=self._prev.token_type == TokenType.UNPIVOT + ) elif self._match(TokenType.FROM): this = exp.select("*").from_( t.cast(exp.From, self._parse_from(skip_from_token=True)) @@ -4000,20 +4003,46 @@ def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]: def _parse_joins(self) -> t.Iterator[exp.Join]: return iter(self._parse_join, None) + def _parse_unpivot_columns(self) -> t.Optional[exp.UnpivotColumns]: + if not self._match(TokenType.INTO): + return None + + return self.expression( + exp.UnpivotColumns, + this=self._match_text_seq("NAME") and self._parse_column(), + expressions=self._match_text_seq("VALUE") and self._parse_csv(self._parse_column), + ) + # https://duckdb.org/docs/sql/statements/pivot - def _parse_simplified_pivot(self) -> exp.Pivot: + def _parse_simplified_pivot(self, is_unpivot: t.Optional[bool] = None) -> exp.Pivot: def _parse_on() -> t.Optional[exp.Expression]: this = self._parse_bitwise() - return self._parse_in(this) if self._match(TokenType.IN) else this + + if self._match(TokenType.IN): + # PIVOT ... ON col IN (row_val1, row_val2) + return self._parse_in(this) + elif self._match(TokenType.ALIAS, advance=False): + # UNPIVOT ... ON (col1, col2, col3) AS row_val + return self._parse_alias(this) + + return this this = self._parse_table() expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on) + into = self._parse_unpivot_columns() using = self._match(TokenType.USING) and self._parse_csv( lambda: self._parse_alias(self._parse_function()) ) group = self._parse_group() + return self.expression( - exp.Pivot, this=this, expressions=expressions, using=using, group=group + exp.Pivot, + this=this, + expressions=expressions, + using=using, + group=group, + unpivot=is_unpivot, + into=into, ) def _parse_pivot_in(self) -> exp.In | exp.PivotAny: diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 09f0134dbc..62db38007f 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -276,12 +276,6 @@ def test_duckdb(self): self.validate_identity("SELECT UNNEST(col, recursive := TRUE) FROM t") self.validate_identity("VAR_POP(a)") self.validate_identity("SELECT * FROM foo ASOF LEFT JOIN bar ON a = b") - self.validate_identity("PIVOT Cities ON Year USING SUM(Population)") - self.validate_identity("PIVOT Cities ON Year USING FIRST(Population)") - self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country") - self.validate_identity("PIVOT Cities ON Country, Name USING SUM(Population)") - self.validate_identity("PIVOT Cities ON Country || '_' || Name USING SUM(Population)") - self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country, Name") self.validate_identity("SELECT {'a': 1} AS x") self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x") self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}") @@ -1415,3 +1409,28 @@ def test_attach_detach(self): self.validate_identity("DETACH IF EXISTS file") self.validate_identity("DETACH DATABASE db", "DETACH db") + + def test_simplified_pivot_unpivot(self): + self.validate_identity("PIVOT Cities ON Year USING SUM(Population)") + self.validate_identity("PIVOT Cities ON Year USING FIRST(Population)") + self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country") + self.validate_identity("PIVOT Cities ON Country, Name USING SUM(Population)") + self.validate_identity("PIVOT Cities ON Country || '_' || Name USING SUM(Population)") + self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country, Name") + + self.validate_identity("UNPIVOT (SELECT 1 AS col1, 2 AS col2) ON foo, bar") + self.validate_identity( + "UNPIVOT monthly_sales ON jan, feb, mar, apr, may, jun INTO NAME month VALUE sales" + ) + self.validate_identity( + "UNPIVOT monthly_sales ON COLUMNS(* EXCLUDE (empid, dept)) INTO NAME month VALUE sales" + ) + self.validate_identity( + "UNPIVOT monthly_sales ON (jan, feb, mar) AS q1, (apr, may, jun) AS q2 INTO NAME quarter VALUE month_1_sales, month_2_sales, month_3_sales" + ) + self.validate_identity( + "WITH unpivot_alias AS (UNPIVOT monthly_sales ON COLUMNS(* EXCLUDE (empid, dept)) INTO NAME month VALUE sales) SELECT * FROM unpivot_alias" + ) + self.validate_identity( + "SELECT * FROM (UNPIVOT monthly_sales ON COLUMNS(* EXCLUDE (empid, dept)) INTO NAME month VALUE sales) AS unpivot_alias" + ) From 71f3ca85b85a569d7ab556fb21008326d8e418f2 Mon Sep 17 00:00:00 2001 From: Jo <46752250+georgesittas@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:36:52 +0200 Subject: [PATCH 2/2] Update sqlglot/parser.py --- sqlglot/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 3f5dec8693..20ca97c4bd 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -4021,7 +4021,7 @@ def _parse_on() -> t.Optional[exp.Expression]: if self._match(TokenType.IN): # PIVOT ... ON col IN (row_val1, row_val2) return self._parse_in(this) - elif self._match(TokenType.ALIAS, advance=False): + if self._match(TokenType.ALIAS, advance=False): # UNPIVOT ... ON (col1, col2, col3) AS row_val return self._parse_alias(this)