Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parameterization support #16

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions pypika/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,16 @@
CustomFunction,
EmptyCriterion,
Field,
FormatParameter,
Index,
Interval,
NamedParameter,
Not,
NullValue,
NumericParameter,
Parameter,
PyformatParameter,
QmarkParameter,
Parameterizer,
Rollup,
SystemTimeValue,
Tuple,
ValueWrapper,
)

NULL = NullValue()
Expand Down
25 changes: 15 additions & 10 deletions pypika/dialects/mssql.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from typing import Any
from typing import Any, cast

from pypika.enums import Dialects
from pypika.exceptions import QueryException
from pypika.queries import Query, QueryBuilder
from pypika.terms import ValueWrapper
from pypika.utils import builder


Expand Down Expand Up @@ -42,25 +43,29 @@ def top(self, value: str | int) -> MSSQLQueryBuilder: # type:ignore[return]
@builder
def fetch_next(self, limit: int) -> MSSQLQueryBuilder: # type:ignore[return]
# Overridden to provide a more domain-specific API for T-SQL users
self._limit = limit
self._limit = cast(ValueWrapper, self.wrap_constant(limit))

def _offset_sql(self) -> str:
def _offset_sql(self, **kwargs) -> str:
order_by = ""
if not self._orderbys:
order_by = "ORDER BY (SELECT 0)"
return order_by + " OFFSET {offset} ROWS".format(offset=self._offset or 0)
order_by = " ORDER BY (SELECT 0)"
return order_by + " OFFSET {offset} ROWS".format(
offset=self._offset.get_sql(**kwargs) if self._offset is not None else 0
)

def _limit_sql(self) -> str:
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit)
def _limit_sql(self, **kwargs) -> str:
if self._limit is None:
return ""
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs))

def _apply_pagination(self, querystring: str) -> str:
def _apply_pagination(self, querystring: str, **kwargs) -> str:
# Note: Overridden as MSSQL specifies offset before the fetch next limit
if self._limit is not None or self._offset:
# Offset has to be present if fetch next is specified in a MSSQL query
querystring += self._offset_sql()
querystring += self._offset_sql(**kwargs)

if self._limit is not None:
querystring += self._limit_sql()
querystring += self._limit_sql(**kwargs)

return querystring

Expand Down
14 changes: 9 additions & 5 deletions pypika/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ def get_sql(self, *args: Any, **kwargs: Any) -> str:
kwargs["groupby_alias"] = False
return super().get_sql(*args, **kwargs)

def _offset_sql(self) -> str:
return " OFFSET {offset} ROWS".format(offset=self._offset)

def _limit_sql(self) -> str:
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit)
def _offset_sql(self, **kwargs) -> str:
if self._offset is None:
return ""
return " OFFSET {offset} ROWS".format(offset=self._offset.get_sql(**kwargs))

def _limit_sql(self, **kwargs) -> str:
if self._limit is None:
return ""
return " FETCH NEXT {limit} ROWS ONLY".format(limit=self._limit.get_sql(**kwargs))
68 changes: 36 additions & 32 deletions pypika/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,8 +535,8 @@ def __init__(
self._set_operation = [(set_operation, set_operation_query)]
self._orderbys: list[tuple[Field, Order | None]] = []

self._limit: int | None = None
self._offset: int | None = None
self._limit: ValueWrapper | None = None
self._offset: ValueWrapper | None = None

self._wrapper_cls = wrapper_cls

Expand All @@ -553,11 +553,11 @@ def orderby(self, *fields: Field, **kwargs: Any) -> "Self": # type:ignore[retur

@builder
def limit(self, limit: int) -> "Self": # type:ignore[return]
self._limit = limit
self._limit = cast(ValueWrapper, self.wrap_constant(limit))

@builder
def offset(self, offset: int) -> "Self": # type:ignore[return]
self._offset = offset
self._offset = cast(ValueWrapper, self.wrap_constant(offset))

@builder
def union(self, other: Selectable) -> "Self": # type:ignore[return]
Expand Down Expand Up @@ -624,11 +624,8 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An
if self._orderbys:
querystring += self._orderby_sql(**kwargs)

if self._limit is not None:
querystring += self._limit_sql()

if self._offset:
querystring += self._offset_sql()
querystring += self._limit_sql(**kwargs)
querystring += self._offset_sql(**kwargs)

if subquery:
querystring = "({query})".format(query=querystring, **kwargs)
Expand Down Expand Up @@ -668,11 +665,15 @@ def _orderby_sql(self, quote_char: str | None = None, **kwargs: Any) -> str:

return " ORDER BY {orderby}".format(orderby=",".join(clauses))

def _offset_sql(self) -> str:
return " OFFSET {offset}".format(offset=self._offset)
def _offset_sql(self, **kwargs) -> str:
if self._offset is None:
return ""
return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs))

def _limit_sql(self) -> str:
return " LIMIT {limit}".format(limit=self._limit)
def _limit_sql(self, **kwargs) -> str:
if self._limit is None:
return ""
return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs))


class QueryBuilder(Selectable, Term): # type:ignore[misc]
Expand Down Expand Up @@ -725,8 +726,8 @@ def __init__(
self._joins: list[Join] = []
self._unions: list = []

self._limit: int | None = None
self._offset: int | None = None
self._limit: ValueWrapper | None = None
self._offset: ValueWrapper | None = None

self._updates: list[tuple] = []

Expand Down Expand Up @@ -1223,11 +1224,11 @@ def hash_join(self, item: Table | "QueryBuilder" | AliasedQuery) -> "Joiner":

@builder
def limit(self, limit: int) -> "Self": # type:ignore[return]
self._limit = limit
self._limit = cast(ValueWrapper, self.wrap_constant(limit))

@builder
def offset(self, offset: int) -> "Self": # type:ignore[return]
self._offset = offset
self._offset = cast(ValueWrapper, self.wrap_constant(offset))

@builder
def union(self, other: Self) -> _SetOperation:
Expand All @@ -1252,7 +1253,8 @@ def minus(self, other: Self) -> _SetOperation:
@builder
def set(self, field: Field | str, value: Any) -> "Self": # type:ignore[return]
field = Field(field) if not isinstance(field, Field) else field
self._updates.append((field, self._wrapper_cls(value)))
value = self.wrap_constant(value, wrapper_cls=self._wrapper_cls)
self._updates.append((field, value))

def __add__(self, other: Self) -> _SetOperation: # type:ignore[override]
return self.union(other)
Expand All @@ -1265,8 +1267,10 @@ def __sub__(self, other: Self) -> _SetOperation: # type:ignore[override]

@builder
def slice(self, slice: slice) -> "Self": # type:ignore[return]
self._offset = slice.start
self._limit = slice.stop
if slice.start is not None:
self._offset = cast(ValueWrapper, self.wrap_constant(slice.start))
if slice.stop is not None:
self._limit = cast(ValueWrapper, self.wrap_constant(slice.stop))

def __getitem__(self, item: Any) -> Self | Field: # type:ignore[override]
if not isinstance(item, slice):
Expand Down Expand Up @@ -1512,7 +1516,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An
if self._orderbys:
querystring += self._orderby_sql(**kwargs)

querystring = self._apply_pagination(querystring)
querystring = self._apply_pagination(querystring, **kwargs)

if self._for_update:
querystring += self._for_update_sql(**kwargs)
Expand All @@ -1532,13 +1536,9 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An

return querystring

def _apply_pagination(self, querystring: str) -> str:
if self._limit is not None:
querystring += self._limit_sql()

if self._offset:
querystring += self._offset_sql()

def _apply_pagination(self, querystring: str, **kwargs) -> str:
querystring += self._limit_sql(**kwargs)
querystring += self._offset_sql(**kwargs)
return querystring

def _with_sql(self, **kwargs: Any) -> str:
Expand Down Expand Up @@ -1750,11 +1750,15 @@ def _having_sql(self, quote_char: str | None = None, **kwargs: Any) -> str:
having = self._havings.get_sql(quote_char=quote_char, **kwargs) # type:ignore[union-attr]
return f" HAVING {having}"

def _offset_sql(self) -> str:
return " OFFSET {offset}".format(offset=self._offset)
def _offset_sql(self, **kwargs) -> str:
if self._offset is None:
return ""
return " OFFSET {offset}".format(offset=self._offset.get_sql(**kwargs))

def _limit_sql(self) -> str:
return " LIMIT {limit}".format(limit=self._limit)
def _limit_sql(self, **kwargs) -> str:
if self._limit is None:
return ""
return " LIMIT {limit}".format(limit=self._limit.get_sql(**kwargs))

def _set_sql(self, **kwargs: Any) -> str:
return " SET {set}".format(
Expand Down
Loading