Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(duckdb): Support simplified UNPIVOT statement #4545

Merged
merged 2 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4244,13 +4244,20 @@ class Pivot(Expression):
"columns": False,
"include_nulls": False,
"default_on_null": False,
"into": False,
}

@property
def unpivot(self) -> bool:
    """Report whether this node's "unpivot" arg is set to a truthy value."""
    flag = self.args.get("unpivot")
    return bool(flag)


# https://duckdb.org/docs/sql/statements/unpivot#simplified-unpivot-syntax
# UNPIVOT ... INTO [NAME <col_name> VALUE <col_value>][...,]
class UnpivotColumns(Expression):
    # "this" carries the column that follows NAME and "expressions" the column list that
    # follows VALUE in an UNPIVOT ... INTO clause.
    # NOTE(review): True presumably marks an arg as mandatory -- confirm against Expression.
    arg_types = {"this": True, "expressions": True}


class Window(Condition):
arg_types = {
"this": True,
Expand Down
14 changes: 11 additions & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1988,21 +1988,23 @@ def tablesample_sql(

def pivot_sql(self, expression: exp.Pivot) -> str:
expressions = self.expressions(expression, flat=True)
direction = "UNPIVOT" if expression.unpivot else "PIVOT"

if expression.this:
this = self.sql(expression, "this")
if not expressions:
return f"UNPIVOT {this}"

on = f"{self.seg('ON')} {expressions}"
into = self.sql(expression, "into")
into = f"{self.seg('INTO')} {into}" if into else ""
using = self.expressions(expression, key="using", flat=True)
using = f"{self.seg('USING')} {using}" if using else ""
group = self.sql(expression, "group")
return f"PIVOT {this}{on}{using}{group}"
return f"{direction} {this}{on}{into}{using}{group}"

alias = self.sql(expression, "alias")
alias = f" AS {alias}" if alias else ""
direction = self.seg("UNPIVOT" if expression.unpivot else "PIVOT")

field = self.sql(expression, "field")

Expand All @@ -2014,7 +2016,7 @@ def pivot_sql(self, expression: exp.Pivot) -> str:

default_on_null = self.sql(expression, "default_on_null")
default_on_null = f" DEFAULT ON NULL ({default_on_null})" if default_on_null else ""
return f"{direction}{nulls}({expressions} FOR {field}{default_on_null}){alias}"
return f"{self.seg(direction)}{nulls}({expressions} FOR {field}{default_on_null}){alias}"

def version_sql(self, expression: exp.Version) -> str:
this = f"FOR {expression.name}"
Expand Down Expand Up @@ -4637,3 +4639,9 @@ def partitionbyrangepropertydynamic_sql(
every.this.replace(exp.Literal.number(every.name))

return f"START {self.wrap(start)} END {self.wrap(end)} EVERY {self.wrap(self.sql(every))}"

def unpivotcolumns_sql(self, expression: exp.UnpivotColumns) -> str:
    """Render an UnpivotColumns node as ``NAME <this> VALUE <expressions>``.

    The "this" arg is generated through ``self.sql`` and the "expressions"
    arg through ``self.expressions`` in flat mode.
    """
    name_sql = self.sql(expression, "this")
    value_sql = self.expressions(expression, flat=True)

    return f"NAME {name_sql} VALUE {value_sql}"
39 changes: 34 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,7 @@ class Parser(metaclass=_Parser):
TokenType.SET: lambda self: self._parse_set(),
TokenType.TRUNCATE: lambda self: self._parse_truncate_table(),
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.UNPIVOT: lambda self: self._parse_simplified_pivot(is_unpivot=True),
TokenType.UPDATE: lambda self: self._parse_update(),
TokenType.USE: lambda self: self.expression(
exp.Use,
Expand Down Expand Up @@ -3056,8 +3057,10 @@ def _parse_select(

this = self._parse_query_modifiers(this)
elif (table or nested) and self._match(TokenType.L_PAREN):
if self._match(TokenType.PIVOT):
this = self._parse_simplified_pivot()
if self._match_set((TokenType.PIVOT, TokenType.UNPIVOT)):
this = self._parse_simplified_pivot(
is_unpivot=self._prev.token_type == TokenType.UNPIVOT
)
elif self._match(TokenType.FROM):
this = exp.select("*").from_(
t.cast(exp.From, self._parse_from(skip_from_token=True))
Expand Down Expand Up @@ -4000,20 +4003,46 @@ def _parse_pivots(self) -> t.Optional[t.List[exp.Pivot]]:
def _parse_joins(self) -> t.Iterator[exp.Join]:
    """Return an iterator that calls ``self._parse_join`` until it yields None."""
    # Two-argument iter(): the callable is invoked on every step and iteration stops as
    # soon as its result equals the None sentinel, so joins are parsed lazily.
    return iter(self._parse_join, None)

def _parse_unpivot_columns(self) -> t.Optional[exp.UnpivotColumns]:
    """Parse an ``INTO NAME <col> VALUE <col>[, ...]`` clause.

    Returns None, building nothing, when the next token is not INTO.
    Otherwise the column after NAME becomes "this" and the comma-separated
    columns after VALUE become "expressions" of the resulting node.
    """
    if not self._match(TokenType.INTO):
        return None

    # Order matters: the NAME part must be consumed before the VALUE part.
    name_column = self._match_text_seq("NAME") and self._parse_column()
    value_columns = self._match_text_seq("VALUE") and self._parse_csv(self._parse_column)

    return self.expression(
        exp.UnpivotColumns,
        this=name_column,
        expressions=value_columns,
    )

# https://duckdb.org/docs/sql/statements/pivot
def _parse_simplified_pivot(self, is_unpivot: t.Optional[bool] = None) -> exp.Pivot:
    """Parse the body of a simplified PIVOT / UNPIVOT statement.

    Clauses are consumed in this order: the source table, the ON list (when an
    ON token follows), the INTO clause via ``_parse_unpivot_columns``, the USING
    list (when a USING token follows) and whatever ``_parse_group`` returns.

    Args:
        is_unpivot: Stored verbatim as the "unpivot" arg of the resulting node.
            Defaults to None, which keeps the plain PIVOT behaviour for callers
            that pass nothing.

    Returns:
        The ``exp.Pivot`` node built from the parsed clauses.
    """

    def _parse_on() -> t.Optional[exp.Expression]:
        # One entry of the ON list: an expression optionally followed by IN (...)
        # or by an alias.
        this = self._parse_bitwise()

        if self._match(TokenType.IN):
            # PIVOT ... ON col IN (row_val1, row_val2)
            return self._parse_in(this)
        elif self._match(TokenType.ALIAS, advance=False):
            # UNPIVOT ... ON (col1, col2, col3) AS row_val
            # advance=False leaves the alias token in place for _parse_alias.
            return self._parse_alias(this)

        return this

    this = self._parse_table()
    expressions = self._match(TokenType.ON) and self._parse_csv(_parse_on)
    into = self._parse_unpivot_columns()
    using = self._match(TokenType.USING) and self._parse_csv(
        lambda: self._parse_alias(self._parse_function())
    )
    group = self._parse_group()

    return self.expression(
        exp.Pivot,
        this=this,
        expressions=expressions,
        using=using,
        group=group,
        unpivot=is_unpivot,
        into=into,
    )

def _parse_pivot_in(self) -> exp.In | exp.PivotAny:
Expand Down
31 changes: 25 additions & 6 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,6 @@ def test_duckdb(self):
self.validate_identity("SELECT UNNEST(col, recursive := TRUE) FROM t")
self.validate_identity("VAR_POP(a)")
self.validate_identity("SELECT * FROM foo ASOF LEFT JOIN bar ON a = b")
self.validate_identity("PIVOT Cities ON Year USING SUM(Population)")
self.validate_identity("PIVOT Cities ON Year USING FIRST(Population)")
self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country")
self.validate_identity("PIVOT Cities ON Country, Name USING SUM(Population)")
self.validate_identity("PIVOT Cities ON Country || '_' || Name USING SUM(Population)")
self.validate_identity("PIVOT Cities ON Year USING SUM(Population) GROUP BY Country, Name")
self.validate_identity("SELECT {'a': 1} AS x")
self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x")
self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}")
Expand Down Expand Up @@ -1415,3 +1409,28 @@ def test_attach_detach(self):
self.validate_identity("DETACH IF EXISTS file")

self.validate_identity("DETACH DATABASE db", "DETACH db")

def test_simplified_pivot_unpivot(self):
    # Each statement is handed to validate_identity on its own, in the order listed.
    pivot_statements = (
        "PIVOT Cities ON Year USING SUM(Population)",
        "PIVOT Cities ON Year USING FIRST(Population)",
        "PIVOT Cities ON Year USING SUM(Population) GROUP BY Country",
        "PIVOT Cities ON Country, Name USING SUM(Population)",
        "PIVOT Cities ON Country || '_' || Name USING SUM(Population)",
        "PIVOT Cities ON Year USING SUM(Population) GROUP BY Country, Name",
    )
    for sql in pivot_statements:
        self.validate_identity(sql)

    # Simplified UNPIVOT: bare, with INTO NAME/VALUE, with aliased column groups,
    # and nested inside a CTE or a subquery.
    unpivot_statements = (
        "UNPIVOT (SELECT 1 AS col1, 2 AS col2) ON foo, bar",
        "UNPIVOT monthly_sales ON jan, feb, mar, apr, may, jun INTO NAME month VALUE sales",
        "UNPIVOT monthly_sales ON COLUMNS(* EXCLUDE (empid, dept)) INTO NAME month VALUE sales",
        "UNPIVOT monthly_sales ON (jan, feb, mar) AS q1, (apr, may, jun) AS q2 INTO NAME quarter VALUE month_1_sales, month_2_sales, month_3_sales",
        "WITH unpivot_alias AS (UNPIVOT monthly_sales ON COLUMNS(* EXCLUDE (empid, dept)) INTO NAME month VALUE sales) SELECT * FROM unpivot_alias",
        "SELECT * FROM (UNPIVOT monthly_sales ON COLUMNS(* EXCLUDE (empid, dept)) INTO NAME month VALUE sales) AS unpivot_alias",
    )
    for sql in unpivot_statements:
        self.validate_identity(sql)
Loading