Skip to content

Commit c787ad2

Browse files
authoredAug 12, 2024
Refactor code to enhance selection and operations (#9)
* WIP: refactoring enhancements * Update delete method * Fix query filter parse * Update tests * Update multiple_conditions to and * fix parse filters method * fix tests * Add arithmetic filters
1 parent b79b48e commit c787ad2

File tree

12 files changed

+875
-275
lines changed

12 files changed

+875
-275
lines changed
 

Diff for: ‎sqlalchemy_crud_plus/crud.py

+105-134
Original file line numberDiff line numberDiff line change
@@ -1,184 +1,130 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3-
from typing import Any, Generic, Iterable, Literal, Sequence, Type, TypeVar
3+
from typing import Any, Generic, Iterable, Sequence, Type
44

5-
from pydantic import BaseModel
6-
from sqlalchemy import Row, RowMapping, and_, asc, desc, or_, select
5+
from sqlalchemy import Row, RowMapping, select
76
from sqlalchemy import delete as sa_delete
87
from sqlalchemy import update as sa_update
98
from sqlalchemy.ext.asyncio import AsyncSession
109

11-
from sqlalchemy_crud_plus.errors import ModelColumnError, SelectExpressionError
10+
from sqlalchemy_crud_plus.errors import MultipleResultsError
11+
from sqlalchemy_crud_plus.types import CreateSchema, Model, UpdateSchema
12+
from sqlalchemy_crud_plus.utils import apply_sorting, count, parse_filters
1213

13-
_Model = TypeVar('_Model')
14-
_CreateSchema = TypeVar('_CreateSchema', bound=BaseModel)
15-
_UpdateSchema = TypeVar('_UpdateSchema', bound=BaseModel)
1614

17-
18-
class CRUDPlus(Generic[_Model]):
19-
def __init__(self, model: Type[_Model]):
15+
class CRUDPlus(Generic[Model]):
16+
def __init__(self, model: Type[Model]):
2017
self.model = model
2118

22-
async def create_model(self, session: AsyncSession, obj: _CreateSchema, commit: bool = False, **kwargs) -> _Model:
19+
async def create_model(self, session: AsyncSession, obj: CreateSchema, commit: bool = False, **kwargs) -> Model:
2320
"""
2421
Create a new instance of a model
2522
26-
:param session:
27-
:param obj:
28-
:param commit:
29-
:param kwargs:
23+
:param session: The SQLAlchemy async session.
24+
:param obj: The Pydantic schema containing data to be saved.
25+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
26+
:param kwargs: Additional model data not included in the pydantic schema.
3027
:return:
3128
"""
32-
if kwargs:
33-
ins = self.model(**obj.model_dump(), **kwargs)
34-
else:
29+
if not kwargs:
3530
ins = self.model(**obj.model_dump())
31+
else:
32+
ins = self.model(**obj.model_dump(), **kwargs)
3633
session.add(ins)
3734
if commit:
3835
await session.commit()
3936
return ins
4037

4138
async def create_models(
42-
self, session: AsyncSession, obj: Iterable[_CreateSchema], commit: bool = False
43-
) -> list[_Model]:
39+
self, session: AsyncSession, obj: Iterable[CreateSchema], commit: bool = False
40+
) -> list[Model]:
4441
"""
4542
Create new instances of a model
4643
47-
:param session:
48-
:param obj:
49-
:param commit:
44+
:param session: The SQLAlchemy async session.
45+
:param obj: The Pydantic schema list containing data to be saved.
46+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
5047
:return:
5148
"""
5249
ins_list = []
53-
for i in obj:
54-
ins_list.append(self.model(**i.model_dump()))
50+
for ins in obj:
51+
ins_list.append(self.model(**ins.model_dump()))
5552
session.add_all(ins_list)
5653
if commit:
5754
await session.commit()
5855
return ins_list
5956

60-
async def select_model_by_id(self, session: AsyncSession, pk: int) -> _Model | None:
57+
async def select_model(self, session: AsyncSession, pk: int) -> Model | None:
6158
"""
6259
Query by ID
6360
64-
:param session:
65-
:param pk:
61+
:param session: The SQLAlchemy async session.
62+
:param pk: The database primary key value.
6663
:return:
6764
"""
6865
stmt = select(self.model).where(self.model.id == pk)
6966
query = await session.execute(stmt)
7067
return query.scalars().first()
7168

72-
async def select_model_by_column(self, session: AsyncSession, column: str, column_value: Any) -> _Model | None:
69+
async def select_model_by_column(self, session: AsyncSession, **kwargs) -> Model | None:
7370
"""
7471
Query by column
7572
76-
:param session:
77-
:param column:
78-
:param column_value:
79-
:return:
80-
"""
81-
if hasattr(self.model, column):
82-
model_column = getattr(self.model, column)
83-
stmt = select(self.model).where(model_column == column_value) # type: ignore
84-
query = await session.execute(stmt)
85-
return query.scalars().first()
86-
else:
87-
raise ModelColumnError(f'Column {column} is not found in {self.model}')
88-
89-
async def select_model_by_columns(
90-
self, session: AsyncSession, expression: Literal['and', 'or'] = 'and', **conditions
91-
) -> _Model | None:
92-
"""
93-
Query by columns
94-
95-
:param session:
96-
:param expression:
97-
:param conditions: Query conditions, format:column1=value1, column2=value2
73+
:param session: The SQLAlchemy async session.
74+
:param kwargs: Query expressions.
9875
:return:
9976
"""
100-
where_list = []
101-
for column, value in conditions.items():
102-
if hasattr(self.model, column):
103-
model_column = getattr(self.model, column)
104-
where_list.append(model_column == value)
105-
else:
106-
raise ModelColumnError(f'Column {column} is not found in {self.model}')
107-
match expression:
108-
case 'and':
109-
stmt = select(self.model).where(and_(*where_list))
110-
query = await session.execute(stmt)
111-
case 'or':
112-
stmt = select(self.model).where(or_(*where_list))
113-
query = await session.execute(stmt)
114-
case _:
115-
raise SelectExpressionError(
116-
f'Select expression {expression} is not supported, only supports `and`, `or`'
117-
)
77+
filters = await parse_filters(self.model, **kwargs)
78+
stmt = select(self.model).where(*filters)
79+
query = await session.execute(stmt)
11880
return query.scalars().first()
11981

120-
async def select_models(self, session: AsyncSession) -> Sequence[Row[Any] | RowMapping | Any]:
82+
async def select_models(self, session: AsyncSession, **kwargs) -> Sequence[Row[Any] | RowMapping | Any]:
12183
"""
12284
Query all rows
12385
124-
:param session:
86+
:param session: The SQLAlchemy async session.
87+
:param kwargs: Query expressions.
12588
:return:
12689
"""
127-
stmt = select(self.model)
90+
filters = await parse_filters(self.model, **kwargs)
91+
stmt = select(self.model).where(*filters)
12892
query = await session.execute(stmt)
12993
return query.scalars().all()
13094

13195
async def select_models_order(
132-
self,
133-
session: AsyncSession,
134-
*columns,
135-
model_sort: Literal['asc', 'desc'] = 'desc',
96+
self, session: AsyncSession, sort_columns: str | list[str], sort_orders: str | list[str] | None = None, **kwargs
13697
) -> Sequence[Row | RowMapping | Any] | None:
13798
"""
138-
Query all rows asc or desc
99+
Query all rows and sort by columns
139100
140-
:param session:
141-
:param columns:
142-
:param model_sort:
101+
:param session: The SQLAlchemy async session.
102+
:param sort_columns: more details see apply_sorting
103+
:param sort_orders: more details see apply_sorting
143104
:return:
144105
"""
145-
sort_list = []
146-
for column in columns:
147-
if hasattr(self.model, column):
148-
model_column = getattr(self.model, column)
149-
sort_list.append(model_column)
150-
else:
151-
raise ModelColumnError(f'Column {column} is not found in {self.model}')
152-
match model_sort:
153-
case 'asc':
154-
query = await session.execute(select(self.model).order_by(asc(*sort_list)))
155-
case 'desc':
156-
query = await session.execute(select(self.model).order_by(desc(*sort_list)))
157-
case _:
158-
raise SelectExpressionError(
159-
f'Select sort expression {model_sort} is not supported, only supports `asc`, `desc`'
160-
)
106+
filters = await parse_filters(self.model, **kwargs)
107+
stmt = select(self.model).where(*filters)
108+
stmt_sort = await apply_sorting(self.model, stmt, sort_columns, sort_orders)
109+
query = await session.execute(stmt_sort)
161110
return query.scalars().all()
162111

163112
async def update_model(
164-
self, session: AsyncSession, pk: int, obj: _UpdateSchema | dict[str, Any], commit: bool = False, **kwargs
113+
self, session: AsyncSession, pk: int, obj: UpdateSchema | dict[str, Any], commit: bool = False
165114
) -> int:
166115
"""
167-
Update an instance of model's primary key
116+
Update an instance by model's primary key
168117
169-
:param session:
170-
:param pk:
171-
:param obj:
172-
:param commit:
173-
:param kwargs:
118+
:param session: The SQLAlchemy async session.
119+
:param pk: The database primary key value.
120+
:param obj: A pydantic schema or dictionary containing the update data
121+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
174122
:return:
175123
"""
176124
if isinstance(obj, dict):
177125
instance_data = obj
178126
else:
179127
instance_data = obj.model_dump(exclude_unset=True)
180-
if kwargs:
181-
instance_data.update(kwargs)
182128
stmt = sa_update(self.model).where(self.model.id == pk).values(**instance_data)
183129
result = await session.execute(stmt)
184130
if commit:
@@ -188,55 +134,80 @@ async def update_model(
188134
async def update_model_by_column(
189135
self,
190136
session: AsyncSession,
191-
column: str,
192-
column_value: Any,
193-
obj: _UpdateSchema | dict[str, Any],
137+
obj: UpdateSchema | dict[str, Any],
138+
allow_multiple: bool = False,
194139
commit: bool = False,
195140
**kwargs,
196141
) -> int:
197142
"""
198-
Update an instance of model column
143+
Update an instance by model column
199144
200-
:param session:
201-
:param column:
202-
:param column_value:
203-
:param obj:
204-
:param commit:
205-
:param kwargs:
145+
:param session: The SQLAlchemy async session.
146+
:param obj: A pydantic schema or dictionary containing the update data
147+
:param allow_multiple: If `True`, allows updating multiple records that match the filters.
148+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
149+
:param kwargs: Query expressions.
206150
:return:
207151
"""
152+
filters = await parse_filters(self.model, **kwargs)
153+
total_count = await count(session, self.model, filters)
154+
if not allow_multiple and total_count > 1:
155+
raise MultipleResultsError(f'Only one record is expected to be update, found {total_count} records.')
208156
if isinstance(obj, dict):
209157
instance_data = obj
210158
else:
211159
instance_data = obj.model_dump(exclude_unset=True)
212-
if kwargs:
213-
instance_data.update(kwargs)
214-
if hasattr(self.model, column):
215-
model_column = getattr(self.model, column)
216-
else:
217-
raise ModelColumnError(f'Column {column} is not found in {self.model}')
218-
stmt = sa_update(self.model).where(model_column == column_value).values(**instance_data) # type: ignore
160+
stmt = sa_update(self.model).where(*filters).values(**instance_data) # type: ignore
219161
result = await session.execute(stmt)
220162
if commit:
221163
await session.commit()
222164
return result.rowcount # type: ignore
223165

224-
async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False, **kwargs) -> int:
166+
async def delete_model(self, session: AsyncSession, pk: int, commit: bool = False) -> int:
225167
"""
226-
Delete an instance of a model
168+
Delete an instance by model's primary key
227169
228-
:param session:
229-
:param pk:
230-
:param commit:
231-
:param kwargs: for soft deletion only
170+
:param session: The SQLAlchemy async session.
171+
:param pk: The database primary key value.
172+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
232173
:return:
233174
"""
234-
if not kwargs:
235-
stmt = sa_delete(self.model).where(self.model.id == pk)
236-
result = await session.execute(stmt)
237-
else:
238-
stmt = sa_update(self.model).where(self.model.id == pk).values(**kwargs)
239-
result = await session.execute(stmt)
175+
stmt = sa_delete(self.model).where(self.model.id == pk)
176+
result = await session.execute(stmt)
240177
if commit:
241178
await session.commit()
242179
return result.rowcount # type: ignore
180+
181+
async def delete_model_by_column(
182+
self,
183+
session: AsyncSession,
184+
allow_multiple: bool = False,
185+
logical_deletion: bool = False,
186+
deleted_flag_column: str = 'del_flag',
187+
commit: bool = False,
188+
**kwargs,
189+
) -> int:
190+
"""
191+
Delete
192+
193+
:param session: The SQLAlchemy async session.
194+
:param commit: If `True`, commits the transaction immediately. Default is `False`.
195+
:param kwargs: Query expressions.
196+
:param allow_multiple: If `True`, allows deleting multiple records that match the filters.
197+
:param logical_deletion: If `True`, enable logical deletion instead of physical deletion
198+
:param deleted_flag_column: Specify the flag column for logical deletion
199+
:return:
200+
"""
201+
filters = await parse_filters(self.model, **kwargs)
202+
total_count = await count(session, self.model, filters)
203+
if not allow_multiple and total_count > 1:
204+
raise MultipleResultsError(f'Only one record is expected to be delete, found {total_count} records.')
205+
if logical_deletion:
206+
deleted_flag = {deleted_flag_column: True}
207+
stmt = sa_update(self.model).where(*filters).values(**deleted_flag)
208+
else:
209+
stmt = sa_delete(self.model).where(*filters)
210+
await session.execute(stmt)
211+
if commit:
212+
await session.commit()
213+
return total_count

Diff for: ‎sqlalchemy_crud_plus/errors.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,22 @@ def __init__(self, msg: str) -> None:
1717
super().__init__(msg)
1818

1919

20-
class SelectExpressionError(SQLAlchemyCRUDPlusException):
20+
class SelectOperatorError(SQLAlchemyCRUDPlusException):
2121
"""Error raised when a select expression is invalid."""
2222

2323
def __init__(self, msg: str) -> None:
2424
super().__init__(msg)
25+
26+
27+
class ColumnSortError(SQLAlchemyCRUDPlusException):
28+
"""Error raised when a column sorting is invalid."""
29+
30+
def __init__(self, msg: str) -> None:
31+
super().__init__(msg)
32+
33+
34+
class MultipleResultsError(SQLAlchemyCRUDPlusException):
35+
"""Error raised when multiple results are invalid."""
36+
37+
def __init__(self, msg: str) -> None:
38+
super().__init__(msg)

Diff for: ‎sqlalchemy_crud_plus/types.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
from typing import TypeVar
4+
5+
from pydantic import BaseModel
6+
7+
Model = TypeVar('Model')
8+
9+
CreateSchema = TypeVar('CreateSchema', bound=BaseModel)
10+
UpdateSchema = TypeVar('UpdateSchema', bound=BaseModel)

Diff for: ‎sqlalchemy_crud_plus/utils.py

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
import warnings
4+
5+
from typing import Any, Callable, Type
6+
7+
from sqlalchemy import ColumnElement, Select, and_, asc, desc, func, or_, select
8+
from sqlalchemy.ext.asyncio import AsyncSession
9+
from sqlalchemy.orm.util import AliasedClass
10+
11+
from sqlalchemy_crud_plus.errors import ColumnSortError, ModelColumnError, SelectOperatorError
12+
from sqlalchemy_crud_plus.types import Model
13+
14+
_SUPPORTED_FILTERS = {
15+
# Comparison: https://docs.sqlalchemy.org/en/20/core/operators.html#comparison-operators
16+
'gt': lambda column: column.__gt__,
17+
'lt': lambda column: column.__lt__,
18+
'ge': lambda column: column.__ge__,
19+
'le': lambda column: column.__le__,
20+
'eq': lambda column: column.__eq__,
21+
'ne': lambda column: column.__ne__,
22+
'between': lambda column: column.between,
23+
# IN: https://docs.sqlalchemy.org/en/20/core/operators.html#in-comparisons
24+
'in': lambda column: column.in_,
25+
'not_in': lambda column: column.not_in,
26+
# Identity: https://docs.sqlalchemy.org/en/20/core/operators.html#identity-comparisons
27+
'is': lambda column: column.is_,
28+
'is_not': lambda column: column.is_not,
29+
'is_distinct_from': lambda column: column.is_distinct_from,
30+
'is_not_distinct_from': lambda column: column.is_not_distinct_from,
31+
# String: https://docs.sqlalchemy.org/en/20/core/operators.html#string-comparisons
32+
'like': lambda column: column.like,
33+
'not_like': lambda column: column.not_like,
34+
'ilike': lambda column: column.ilike,
35+
'not_ilike': lambda column: column.not_ilike,
36+
# String Containment: https://docs.sqlalchemy.org/en/20/core/operators.html#string-containment
37+
'startswith': lambda column: column.startswith,
38+
'endswith': lambda column: column.endswith,
39+
'contains': lambda column: column.contains,
40+
# String matching: https://docs.sqlalchemy.org/en/20/core/operators.html#string-matching
41+
'match': lambda column: column.match,
42+
# String Alteration: https://docs.sqlalchemy.org/en/20/core/operators.html#string-alteration
43+
'concat': lambda column: column.concat,
44+
# Arithmetic: https://docs.sqlalchemy.org/en/20/core/operators.html#arithmetic-operators
45+
'add': lambda column: column.__add__,
46+
'radd': lambda column: column.__radd__,
47+
'sub': lambda column: column.__sub__,
48+
'rsub': lambda column: column.__rsub__,
49+
'mul': lambda column: column.__mul__,
50+
'rmul': lambda column: column.__rmul__,
51+
'truediv': lambda column: column.__truediv__,
52+
'rtruediv': lambda column: column.__rtruediv__,
53+
'floordiv': lambda column: column.__floordiv__,
54+
'rfloordiv': lambda column: column.__rfloordiv__,
55+
'mod': lambda column: column.__mod__,
56+
'rmod': lambda column: column.__rmod__,
57+
}
58+
59+
60+
async def get_sqlalchemy_filter(
61+
operator: str, value: Any, allow_arithmetic: bool = True
62+
) -> Callable[[str], Callable] | None:
63+
if operator in ['in', 'not_in', 'between']:
64+
if not isinstance(value, (tuple, list, set)):
65+
raise SelectOperatorError(f'The value of the <{operator}> filter must be tuple, list or set')
66+
67+
if (
68+
operator
69+
in ['add', 'radd', 'sub', 'rsub', 'mul', 'rmul', 'truediv', 'rtruediv', 'floordiv', 'rfloordiv', 'mod', 'rmod']
70+
and not allow_arithmetic
71+
):
72+
raise SelectOperatorError(f'Nested arithmetic operations are not allowed: {operator}')
73+
74+
sqlalchemy_filter = _SUPPORTED_FILTERS.get(operator)
75+
if sqlalchemy_filter is None:
76+
warnings.warn(
77+
f'The operator <{operator}> is not yet supported, only {", ".join(_SUPPORTED_FILTERS.keys())}.',
78+
SyntaxWarning,
79+
)
80+
return None
81+
82+
return sqlalchemy_filter
83+
84+
85+
async def get_column(model: Type[Model] | AliasedClass, field_name: str):
86+
column = getattr(model, field_name, None)
87+
if column is None:
88+
raise ModelColumnError(f'Column {field_name} is not found in {model}')
89+
return column
90+
91+
92+
async def parse_filters(model: Type[Model] | AliasedClass, **kwargs) -> list[ColumnElement]:
93+
filters = []
94+
95+
for key, value in kwargs.items():
96+
if '__' in key:
97+
field_name, op = key.rsplit('__', 1)
98+
column = await get_column(model, field_name)
99+
if op == 'or':
100+
or_filters = [
101+
sqlalchemy_filter(column)(or_value)
102+
for or_op, or_value in value.items()
103+
if (sqlalchemy_filter := await get_sqlalchemy_filter(or_op, or_value)) is not None
104+
]
105+
filters.append(or_(*or_filters))
106+
elif isinstance(value, dict) and {'value', 'condition'}.issubset(value):
107+
advanced_value = value['value']
108+
condition = value['condition']
109+
sqlalchemy_filter = await get_sqlalchemy_filter(op, advanced_value)
110+
if sqlalchemy_filter is not None:
111+
condition_filters = []
112+
for cond_op, cond_value in condition.items():
113+
condition_filter = await get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
114+
condition_filters.append(
115+
condition_filter(sqlalchemy_filter(column)(advanced_value))(cond_value)
116+
if cond_op != 'between'
117+
else condition_filter(sqlalchemy_filter(column)(advanced_value))(*cond_value)
118+
)
119+
filters.append(and_(*condition_filters))
120+
else:
121+
sqlalchemy_filter = await get_sqlalchemy_filter(op, value)
122+
if sqlalchemy_filter is not None:
123+
filters.append(
124+
sqlalchemy_filter(column)(value) if op != 'between' else sqlalchemy_filter(column)(*value)
125+
)
126+
else:
127+
column = await get_column(model, key)
128+
filters.append(column == value)
129+
130+
return filters
131+
132+
133+
async def apply_sorting(
134+
model: Type[Model] | AliasedClass,
135+
stmt: Select,
136+
sort_columns: str | list[str],
137+
sort_orders: str | list[str] | None = None,
138+
) -> Select:
139+
"""
140+
Apply sorting to a SQLAlchemy query based on specified column names and sort orders.
141+
142+
:param model: The SQLAlchemy model.
143+
:param stmt: The SQLAlchemy `Select` statement to which sorting will be applied.
144+
:param sort_columns: A single column name or list of column names to sort the query results by.
145+
Must be used in conjunction with sort_orders.
146+
:param sort_orders: A single sort order ("asc" or "desc") or a list of sort orders, corresponding to each
147+
column in sort_columns. If not specified, defaults to ascending order for all sort_columns.
148+
:return:
149+
"""
150+
if sort_orders and not sort_columns:
151+
raise ValueError('Sort orders provided without corresponding sort columns.')
152+
153+
if sort_columns:
154+
if not isinstance(sort_columns, list):
155+
sort_columns = [sort_columns]
156+
157+
if sort_orders:
158+
if not isinstance(sort_orders, list):
159+
sort_orders = [sort_orders] * len(sort_columns)
160+
161+
if len(sort_columns) != len(sort_orders):
162+
raise ColumnSortError('The length of sort_columns and sort_orders must match.')
163+
164+
for order in sort_orders:
165+
if order not in ['asc', 'desc']:
166+
raise SelectOperatorError(
167+
f'Select sort operator {order} is not supported, only supports `asc`, `desc`'
168+
)
169+
170+
validated_sort_orders = ['asc'] * len(sort_columns) if not sort_orders else sort_orders
171+
172+
for idx, column_name in enumerate(sort_columns):
173+
column = await get_column(model, column_name)
174+
order = validated_sort_orders[idx]
175+
stmt = stmt.order_by(asc(column) if order == 'asc' else desc(column))
176+
177+
return stmt
178+
179+
180+
async def count(
181+
session: AsyncSession,
182+
model: Type[Model] | AliasedClass,
183+
filters: list[ColumnElement],
184+
) -> int:
185+
"""
186+
Counts records that match specified filters.
187+
188+
:param session: The sqlalchemy session to use for the operation.
189+
:param model: The SQLAlchemy model.
190+
:param filters: Filters to apply for the count.
191+
:return:
192+
"""
193+
stmt = select(func.count()).select_from(model)
194+
if filters:
195+
stmt = stmt.where(*filters)
196+
query = await session.execute(stmt)
197+
total_count = query.scalar()
198+
return total_count if total_count is not None else 0

Diff for: ‎tests/conftest.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,31 @@
11
#!/usr/bin/env python3
22
# -*- coding: utf-8 -*-
3+
34
import pytest_asyncio
45

56
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
67

7-
from tests.model import Base
8+
from tests.model import Base, Ins
89

9-
async_engine = create_async_engine('sqlite+aiosqlite:///:memory:', future=True, pool_pre_ping=True)
10-
async_db_session = async_sessionmaker(async_engine, autoflush=False, expire_on_commit=False)
10+
_async_engine = create_async_engine('sqlite+aiosqlite:///:memory:', future=True)
11+
_async_session = async_sessionmaker(_async_engine, autoflush=False, expire_on_commit=False)
1112

1213

1314
@pytest_asyncio.fixture(scope='function', autouse=True)
1415
async def init_db():
15-
async with async_engine.begin() as conn:
16+
async with _async_engine.begin() as conn:
1617
await conn.run_sync(Base.metadata.create_all)
1718
yield
1819
await conn.run_sync(Base.metadata.drop_all)
20+
21+
22+
@pytest_asyncio.fixture
23+
async def async_db_session():
24+
yield _async_session
25+
26+
27+
@pytest_asyncio.fixture
28+
async def create_test_model():
29+
async with _async_session.begin() as session:
30+
data = [Ins(name=f'name_{i}') for i in range(1, 10)]
31+
session.add_all(data)

Diff for: ‎tests/model.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ class Ins(Base):
1717

1818
id: Mapped[int] = mapped_column(init=False, primary_key=True, index=True, autoincrement=True)
1919
name: Mapped[str] = mapped_column(String(64))
20+
del_flag: Mapped[bool] = mapped_column(default=False)
2021
created_time: Mapped[datetime] = mapped_column(init=False, default_factory=datetime.now)
2122
updated_time: Mapped[datetime | None] = mapped_column(init=False, onupdate=datetime.now)

Diff for: ‎tests/schema.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
from pydantic import BaseModel
4+
5+
6+
class ModelTest(BaseModel):
7+
name: str

Diff for: ‎tests/test_create.py

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
import pytest
4+
5+
from sqlalchemy import select
6+
7+
from sqlalchemy_crud_plus import CRUDPlus
8+
from tests.model import Ins
9+
from tests.schema import ModelTest
10+
11+
12+
@pytest.mark.asyncio
13+
async def test_create_model(async_db_session):
14+
async with async_db_session.begin() as session:
15+
crud = CRUDPlus(Ins)
16+
for i in range(1, 10):
17+
data = ModelTest(name=f'name_{i}')
18+
await crud.create_model(session, data)
19+
async with async_db_session() as session:
20+
for i in range(1, 10):
21+
query = await session.scalar(select(Ins).where(Ins.id == i))
22+
assert query.name == f'name_{i}'
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_create_models(async_db_session):
27+
async with async_db_session.begin() as session:
28+
crud = CRUDPlus(Ins)
29+
data = []
30+
for i in range(1, 10):
31+
data.append(ModelTest(name=f'name_{i}'))
32+
await crud.create_models(session, data)
33+
async with async_db_session() as session:
34+
for i in range(1, 10):
35+
query = await session.scalar(select(Ins).where(Ins.id == i))
36+
assert query.name == f'name_{i}'

Diff for: ‎tests/test_crud.py

-136
This file was deleted.

Diff for: ‎tests/test_delete.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
import pytest
4+
5+
from sqlalchemy_crud_plus import CRUDPlus
6+
from tests.model import Ins
7+
8+
9+
@pytest.mark.asyncio
10+
async def test_delete_model(create_test_model, async_db_session):
11+
async with async_db_session.begin() as session:
12+
crud = CRUDPlus(Ins)
13+
result = await crud.delete_model(session, 1)
14+
assert result == 1
15+
16+
17+
@pytest.mark.asyncio
18+
async def test_delete_model_by_column(create_test_model, async_db_session):
19+
async with async_db_session.begin() as session:
20+
crud = CRUDPlus(Ins)
21+
result = await crud.delete_model_by_column(session, name='name_1')
22+
assert result == 1
23+
24+
25+
@pytest.mark.asyncio
26+
async def test_delete_model_by_column_with_and(create_test_model, async_db_session):
27+
async with async_db_session.begin() as session:
28+
crud = CRUDPlus(Ins)
29+
result = await crud.delete_model_by_column(session, id=1, name='name_1')
30+
assert result == 1
31+
32+
33+
@pytest.mark.asyncio
34+
async def test_delete_model_by_column_logical(create_test_model, async_db_session):
35+
async with async_db_session.begin() as session:
36+
crud = CRUDPlus(Ins)
37+
result = await crud.delete_model_by_column(session, logical_deletion=True, name='name_1')
38+
assert result == 1
39+
40+
41+
@pytest.mark.asyncio
42+
async def test_delete_model_by_column_allow_multiple(create_test_model, async_db_session):
43+
async with async_db_session.begin() as session:
44+
crud = CRUDPlus(Ins)
45+
result = await crud.delete_model_by_column(session, allow_multiple=True, name__startswith='name')
46+
assert result == 9
47+
48+
49+
@pytest.mark.asyncio
50+
async def test_delete_model_by_column_logical_with_multiple(create_test_model, async_db_session):
51+
async with async_db_session.begin() as session:
52+
crud = CRUDPlus(Ins)
53+
result = await crud.delete_model_by_column(
54+
session, allow_multiple=True, logical_deletion=True, name__startswith='name'
55+
)
56+
assert result == 9

Diff for: ‎tests/test_select.py

+368
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,368 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
import pytest
4+
5+
from sqlalchemy_crud_plus import CRUDPlus
6+
from tests.model import Ins
7+
8+
9+
@pytest.mark.asyncio
10+
async def test_select_model(create_test_model, async_db_session):
11+
async with async_db_session() as session:
12+
crud = CRUDPlus(Ins)
13+
for i in range(1, 10):
14+
result = await crud.select_model(session, i)
15+
assert result.name == f'name_{i}'
16+
17+
18+
@pytest.mark.asyncio
19+
async def test_select_model_by_column(create_test_model, async_db_session):
20+
async with async_db_session() as session:
21+
crud = CRUDPlus(Ins)
22+
for i in range(1, 10):
23+
result = await crud.select_model_by_column(session, name=f'name_{i}')
24+
assert result.name == f'name_{i}'
25+
26+
27+
@pytest.mark.asyncio
28+
async def test_select_model_by_column_with_and(create_test_model, async_db_session):
29+
async with async_db_session() as session:
30+
crud = CRUDPlus(Ins)
31+
for i in range(1, 10):
32+
result = await crud.select_model_by_column(session, id=i, name=f'name_{i}')
33+
assert result.name == f'name_{i}'
34+
35+
36+
@pytest.mark.asyncio
37+
async def test_select_model_by_column_with_gt(create_test_model, async_db_session):
38+
async with async_db_session() as session:
39+
crud = CRUDPlus(Ins)
40+
result = await crud.select_model_by_column(session, id__gt=1)
41+
assert result.id == 2
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_select_model_by_column_with_lt(create_test_model, async_db_session):
46+
async with async_db_session() as session:
47+
crud = CRUDPlus(Ins)
48+
result = await crud.select_model_by_column(session, id__lt=1)
49+
assert result is None
50+
51+
52+
@pytest.mark.asyncio
53+
async def test_select_model_by_column_with_gte(create_test_model, async_db_session):
54+
async with async_db_session() as session:
55+
crud = CRUDPlus(Ins)
56+
result = await crud.select_model_by_column(session, id__ge=1)
57+
assert result.id == 1
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_select_model_by_column_with_lte(create_test_model, async_db_session):
62+
async with async_db_session() as session:
63+
crud = CRUDPlus(Ins)
64+
result = await crud.select_model_by_column(session, id__le=1)
65+
assert result.id == 1
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_select_model_by_column_with_eq(create_test_model, async_db_session):
70+
async with async_db_session() as session:
71+
crud = CRUDPlus(Ins)
72+
result = await crud.select_model_by_column(session, id__eq=1)
73+
assert result.id == 1
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_select_model_by_column_with_ne(create_test_model, async_db_session):
78+
async with async_db_session() as session:
79+
crud = CRUDPlus(Ins)
80+
result = await crud.select_model_by_column(session, id__ne=1)
81+
assert result.id == 2
82+
83+
84+
@pytest.mark.asyncio
85+
async def test_select_model_by_column_with_between(create_test_model, async_db_session):
86+
async with async_db_session() as session:
87+
crud = CRUDPlus(Ins)
88+
result = await crud.select_model_by_column(session, id__between=(0, 11))
89+
assert result.id == 1
90+
91+
92+
@pytest.mark.asyncio
93+
async def test_select_model_by_column_with_is(create_test_model, async_db_session):
94+
async with async_db_session() as session:
95+
crud = CRUDPlus(Ins)
96+
result = await crud.select_model_by_column(session, del_flag__is=False)
97+
assert result.id == 1
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_select_model_by_column_with_is_not(create_test_model, async_db_session):
102+
async with async_db_session() as session:
103+
crud = CRUDPlus(Ins)
104+
result = await crud.select_model_by_column(session, del_flag__is_not=True)
105+
assert result.id == 1
106+
107+
108+
@pytest.mark.asyncio
109+
async def test_select_model_by_column_with_is_distinct_from(create_test_model, async_db_session):
110+
async with async_db_session() as session:
111+
crud = CRUDPlus(Ins)
112+
result = await crud.select_model_by_column(session, del_flag__is_distinct_from=True)
113+
assert result.id == 1
114+
115+
116+
@pytest.mark.asyncio
117+
async def test_select_model_by_column_with_is_not_distinct_from(create_test_model, async_db_session):
118+
async with async_db_session() as session:
119+
crud = CRUDPlus(Ins)
120+
result = await crud.select_model_by_column(session, del_flag__is_not_distinct_from=True)
121+
assert result is None
122+
123+
124+
@pytest.mark.asyncio
125+
async def test_select_model_by_column_with_like(create_test_model, async_db_session):
126+
async with async_db_session() as session:
127+
crud = CRUDPlus(Ins)
128+
result = await crud.select_model_by_column(session, name__like='name%')
129+
assert result.id == 1
130+
131+
132+
@pytest.mark.asyncio
133+
async def test_select_model_by_column_with_not_like(create_test_model, async_db_session):
134+
async with async_db_session() as session:
135+
crud = CRUDPlus(Ins)
136+
result = await crud.select_model_by_column(session, name__not_like='name%')
137+
assert result is None
138+
139+
140+
@pytest.mark.asyncio
141+
async def test_select_model_by_column_with_ilike(create_test_model, async_db_session):
142+
async with async_db_session() as session:
143+
crud = CRUDPlus(Ins)
144+
result = await crud.select_model_by_column(session, name__ilike='NAME%')
145+
assert result.id == 1
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_select_model_by_column_with_not_ilike(create_test_model, async_db_session):
150+
async with async_db_session() as session:
151+
crud = CRUDPlus(Ins)
152+
result = await crud.select_model_by_column(session, name__not_ilike='NAME%')
153+
assert result is None
154+
155+
156+
@pytest.mark.asyncio
157+
async def test_select_model_by_column_with_startwith(create_test_model, async_db_session):
158+
async with async_db_session() as session:
159+
crud = CRUDPlus(Ins)
160+
result = await crud.select_model_by_column(session, name__startswith='name')
161+
assert result.id == 1
162+
163+
164+
@pytest.mark.asyncio
165+
async def test_select_model_by_column_with_endwith(create_test_model, async_db_session):
166+
async with async_db_session() as session:
167+
crud = CRUDPlus(Ins)
168+
result = await crud.select_model_by_column(session, name__endswith='1')
169+
assert result.id == 1
170+
171+
172+
@pytest.mark.asyncio
173+
async def test_select_model_by_column_with_contains(create_test_model, async_db_session):
174+
async with async_db_session() as session:
175+
crud = CRUDPlus(Ins)
176+
result = await crud.select_model_by_column(session, name__contains='name')
177+
assert result.id == 1
178+
179+
180+
@pytest.mark.asyncio
181+
@pytest.mark.skip(reason='match not available in sqlite')
182+
async def test_select_model_by_column_with_match(create_test_model, async_db_session):
183+
async with async_db_session() as session:
184+
crud = CRUDPlus(Ins)
185+
result = await crud.select_model_by_column(session, name__match='name')
186+
assert result.id == 1
187+
188+
189+
@pytest.mark.asyncio
190+
async def test_select_model_by_column_with_concat(create_test_model, async_db_session):
191+
async with async_db_session() as session:
192+
crud = CRUDPlus(Ins)
193+
result = await crud.select_model_by_column(
194+
session, name__concat={'value': '_concat', 'condition': {'eq': 'name_1_concat'}}
195+
)
196+
assert result is not None
197+
198+
199+
@pytest.mark.asyncio
200+
async def test_select_model_by_column_with_add_string(create_test_model, async_db_session):
201+
async with async_db_session() as session:
202+
crud = CRUDPlus(Ins)
203+
result = await crud.select_model_by_column(
204+
session, name__add={'value': '_add', 'condition': {'eq': 'name_1_add'}}
205+
)
206+
assert result is not None
207+
208+
209+
@pytest.mark.asyncio
210+
async def test_select_model_by_column_with_radd_string(create_test_model, async_db_session):
211+
async with async_db_session() as session:
212+
crud = CRUDPlus(Ins)
213+
result = await crud.select_model_by_column(
214+
session, name__radd={'value': 'radd_', 'condition': {'eq': 'radd_name_1'}}
215+
)
216+
assert result is not None
217+
218+
219+
@pytest.mark.asyncio
220+
async def test_select_model_by_column_with_add_number(create_test_model, async_db_session):
221+
async with async_db_session() as session:
222+
crud = CRUDPlus(Ins)
223+
result = await crud.select_model_by_column(session, id__add={'value': 1, 'condition': {'eq': 2}})
224+
assert result is not None
225+
226+
227+
@pytest.mark.asyncio
228+
async def test_select_model_by_column_with_radd_number(create_test_model, async_db_session):
229+
async with async_db_session() as session:
230+
crud = CRUDPlus(Ins)
231+
result = await crud.select_model_by_column(session, id__radd={'value': 1, 'condition': {'eq': 2}})
232+
assert result is not None
233+
234+
235+
@pytest.mark.asyncio
236+
async def test_select_model_by_column_with_sub(create_test_model, async_db_session):
237+
async with async_db_session() as session:
238+
crud = CRUDPlus(Ins)
239+
result = await crud.select_model_by_column(session, id__sub={'value': 1, 'condition': {'eq': 0}})
240+
assert result is not None
241+
242+
243+
@pytest.mark.asyncio
244+
async def test_select_model_by_column_with_rsub(create_test_model, async_db_session):
245+
async with async_db_session() as session:
246+
crud = CRUDPlus(Ins)
247+
result = await crud.select_model_by_column(session, id__rsub={'value': 2, 'condition': {'eq': 1}})
248+
assert result is not None
249+
250+
251+
@pytest.mark.asyncio
252+
async def test_select_model_by_column_with_mul(create_test_model, async_db_session):
253+
async with async_db_session() as session:
254+
crud = CRUDPlus(Ins)
255+
result = await crud.select_model_by_column(session, id__mul={'value': 1, 'condition': {'eq': 1}})
256+
assert result is not None
257+
258+
259+
@pytest.mark.asyncio
260+
async def test_select_model_by_column_with_rmul(create_test_model, async_db_session):
261+
async with async_db_session() as session:
262+
crud = CRUDPlus(Ins)
263+
result = await crud.select_model_by_column(session, id__rmul={'value': 1, 'condition': {'eq': 1}})
264+
assert result is not None
265+
266+
267+
@pytest.mark.asyncio
268+
async def test_select_model_by_column_with_truediv(create_test_model, async_db_session):
269+
async with async_db_session() as session:
270+
crud = CRUDPlus(Ins)
271+
result = await crud.select_model_by_column(session, id__truediv={'value': 1, 'condition': {'eq': 1}})
272+
assert result is not None
273+
274+
275+
@pytest.mark.asyncio
276+
async def test_select_model_by_column_with_rtruediv(create_test_model, async_db_session):
277+
async with async_db_session() as session:
278+
crud = CRUDPlus(Ins)
279+
result = await crud.select_model_by_column(session, id__rtruediv={'value': 1, 'condition': {'eq': 1}})
280+
assert result is not None
281+
282+
283+
@pytest.mark.asyncio
284+
async def test_select_model_by_column_with_floordiv(create_test_model, async_db_session):
285+
async with async_db_session() as session:
286+
crud = CRUDPlus(Ins)
287+
result = await crud.select_model_by_column(session, id__floordiv={'value': 1, 'condition': {'eq': 1}})
288+
assert result is not None
289+
290+
291+
@pytest.mark.asyncio
292+
async def test_select_model_by_column_with_rfloordiv(create_test_model, async_db_session):
293+
async with async_db_session() as session:
294+
crud = CRUDPlus(Ins)
295+
result = await crud.select_model_by_column(session, id__rfloordiv={'value': 1, 'condition': {'eq': 1}})
296+
assert result is not None
297+
298+
299+
@pytest.mark.asyncio
300+
async def test_select_model_by_column_with_mod(create_test_model, async_db_session):
301+
async with async_db_session() as session:
302+
crud = CRUDPlus(Ins)
303+
result = await crud.select_model_by_column(session, id__mod={'value': 1, 'condition': {'eq': 0}})
304+
assert result.id == 1
305+
306+
307+
@pytest.mark.asyncio
308+
async def test_select_model_by_column_with_rmod(create_test_model, async_db_session):
309+
async with async_db_session() as session:
310+
crud = CRUDPlus(Ins)
311+
result = await crud.select_model_by_column(session, id__rmod={'value': 1, 'condition': {'eq': 0}})
312+
assert result.id == 1
313+
314+
315+
@pytest.mark.asyncio
316+
async def test_select_model_by_column_with_in(create_test_model, async_db_session):
317+
async with async_db_session() as session:
318+
crud = CRUDPlus(Ins)
319+
result = await crud.select_model_by_column(session, id__in=(1, 2, 3, 4, 5, 6, 7, 8, 9))
320+
assert result.id == 1
321+
322+
323+
@pytest.mark.asyncio
324+
async def test_select_model_by_column_with_not_in(create_test_model, async_db_session):
325+
async with async_db_session() as session:
326+
crud = CRUDPlus(Ins)
327+
result = await crud.select_model_by_column(session, id__not_in=(1, 2, 3, 4, 5, 6, 7, 8, 9))
328+
assert result is None
329+
330+
331+
@pytest.mark.asyncio
332+
async def test_select_model_by_column_with_or(create_test_model, async_db_session):
333+
async with async_db_session() as session:
334+
crud = CRUDPlus(Ins)
335+
result = await crud.select_model_by_column(session, id__or={'le': 1, 'eq': 1})
336+
assert result.id == 1
337+
338+
339+
@pytest.mark.asyncio
340+
async def test_select_models(create_test_model, async_db_session):
341+
async with async_db_session.begin() as session:
342+
crud = CRUDPlus(Ins)
343+
result = await crud.select_models(session)
344+
assert len(result) == 9
345+
346+
347+
@pytest.mark.asyncio
348+
async def test_select_models_order_default_asc(create_test_model, async_db_session):
349+
async with async_db_session() as session:
350+
crud = CRUDPlus(Ins)
351+
result = await crud.select_models_order(session, ['id', 'name'])
352+
assert result[0].id == 1
353+
354+
355+
@pytest.mark.asyncio
356+
async def test_select_models_order_desc(create_test_model, async_db_session):
357+
async with async_db_session() as session:
358+
crud = CRUDPlus(Ins)
359+
result = await crud.select_models_order(session, ['id', 'name'], ['desc', 'desc'])
360+
assert result[0].id == 9
361+
362+
363+
@pytest.mark.asyncio
364+
async def test_select_models_order_asc_and_desc(create_test_model, async_db_session):
365+
async with async_db_session() as session:
366+
crud = CRUDPlus(Ins)
367+
result = await crud.select_models_order(session, ['id', 'name'], ['asc', 'desc'])
368+
assert result[0].id == 1

Diff for: ‎tests/test_update.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
import pytest
4+
5+
from sqlalchemy_crud_plus import CRUDPlus
6+
from tests.model import Ins
7+
from tests.schema import ModelTest
8+
9+
10+
@pytest.mark.asyncio
11+
async def test_update_model(create_test_model, async_db_session):
12+
async with async_db_session.begin() as session:
13+
crud = CRUDPlus(Ins)
14+
data = ModelTest(name='name_update_1')
15+
result = await crud.update_model(session, 1, data)
16+
assert result == 1
17+
result = await session.get(Ins, 1)
18+
assert result.name == 'name_update_1'
19+
20+
21+
@pytest.mark.asyncio
22+
async def test_update_model_by_column(create_test_model, async_db_session):
23+
async with async_db_session.begin() as session:
24+
crud = CRUDPlus(Ins)
25+
data = ModelTest(name='name_update_1')
26+
result = await crud.update_model_by_column(session, data, name='name_1')
27+
assert result == 1
28+
result = await session.get(Ins, 1)
29+
assert result.name == 'name_update_1'
30+
31+
32+
@pytest.mark.asyncio
33+
async def test_update_model_by_column_with_and(create_test_model, async_db_session):
34+
async with async_db_session.begin() as session:
35+
crud = CRUDPlus(Ins)
36+
data = ModelTest(name='name_update_1')
37+
result = await crud.update_model_by_column(session, data, id=1, name='name_1')
38+
assert result == 1
39+
result = await session.get(Ins, 1)
40+
assert result.name == 'name_update_1'
41+
42+
43+
@pytest.mark.asyncio
44+
async def test_update_model_by_column_with_filter(create_test_model, async_db_session):
45+
async with async_db_session.begin() as session:
46+
crud = CRUDPlus(Ins)
47+
data = ModelTest(name='name_update_1')
48+
result = await crud.update_model_by_column(session, data, id__eq=1)
49+
assert result == 1
50+
result = await session.get(Ins, 1)
51+
assert result.name == 'name_update_1'
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_update_model_by_column_allow_multiple(create_test_model, async_db_session):
56+
async with async_db_session.begin() as session:
57+
crud = CRUDPlus(Ins)
58+
data = ModelTest(name='name_update_1')
59+
result = await crud.update_model_by_column(session, data, allow_multiple=True, name__startswith='name')
60+
assert result == 9
61+
result = await session.get(Ins, 1)
62+
assert result.name == 'name_update_1'

0 commit comments

Comments
 (0)
Please sign in to comment.