Skip to content

Commit 60e4ee7

Browse files
Minimize use of to_model, especially when fetching many
1 parent b476be9 commit 60e4ee7

File tree

7 files changed

+84
-63
lines changed

7 files changed

+84
-63
lines changed

aiida_restapi/routers/computers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
@read_router.get(
2424
'/schema',
25-
response_model=dict,
25+
response_model=dict[str, t.Any],
2626
responses={
2727
422: {'model': errors.RequestValidationError},
2828
},
@@ -32,7 +32,7 @@ async def get_computers_schema(
3232
t.Literal['get', 'post'],
3333
Query(description='Type of schema to retrieve: "get" or "post"'),
3434
] = 'get',
35-
) -> dict:
35+
) -> dict[str, t.Any]:
3636
"""Get JSON schema for AiiDA computers."""
3737
return service.get_schema(which=which)
3838

@@ -61,7 +61,7 @@ async def get_computers(
6161
query.QueryParams,
6262
Depends(query.query_params),
6363
],
64-
) -> PaginatedResults[orm.Computer.Model]:
64+
) -> PaginatedResults[dict[str, t.Any]]:
6565
"""Get AiiDA computers with optional filtering, sorting, and/or pagination."""
6666
return service.get_many(query_params)
6767

@@ -78,7 +78,7 @@ async def get_computers(
7878
},
7979
)
8080
@with_dbenv()
81-
async def get_computer(pk: str) -> orm.Computer.Model:
81+
async def get_computer(pk: str) -> dict[str, t.Any]:
8282
"""Get AiiDA computer by pk."""
8383
return service.get_one(pk)
8484

@@ -112,6 +112,6 @@ async def get_computer_metadata(pk: str) -> dict[str, t.Any]:
112112
async def create_computer(
113113
computer_model: orm.Computer.CreateModel,
114114
current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
115-
) -> orm.Computer.Model:
115+
) -> dict[str, t.Any]:
116116
"""Create new AiiDA computer."""
117117
return service.add_one(computer_model)

aiida_restapi/routers/groups.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
@read_router.get(
2424
'/schema',
25-
response_model=dict,
25+
response_model=dict[str, t.Any],
2626
responses={
2727
422: {'model': errors.RequestValidationError},
2828
},
@@ -32,7 +32,7 @@ async def get_groups_schema(
3232
t.Literal['get', 'post'],
3333
Query(description='Type of schema to retrieve: "get" or "post"'),
3434
] = 'get',
35-
) -> dict:
35+
) -> dict[str, t.Any]:
3636
"""Get JSON schema for AiiDA groups."""
3737
return service.get_schema(which=which)
3838

@@ -61,7 +61,7 @@ async def get_groups(
6161
query.QueryParams,
6262
Depends(query.query_params),
6363
],
64-
) -> PaginatedResults[orm.Group.Model]:
64+
) -> PaginatedResults[dict[str, t.Any]]:
6565
"""Get AiiDA groups with optional filtering, sorting, and/or pagination."""
6666
return service.get_many(query_params)
6767

@@ -78,7 +78,7 @@ async def get_groups(
7878
},
7979
)
8080
@with_dbenv()
81-
async def get_group(uuid: str) -> orm.Group.Model:
81+
async def get_group(uuid: str) -> dict[str, t.Any]:
8282
"""Get AiiDA group by uuid."""
8383
return service.get_one(uuid)
8484

@@ -95,9 +95,9 @@ async def get_group(uuid: str) -> orm.Group.Model:
9595
},
9696
)
9797
@with_dbenv()
98-
async def get_group_user(uuid: str) -> orm.User.Model:
98+
async def get_group_user(uuid: str) -> dict[str, t.Any]:
9999
"""Get the user associated with a group."""
100-
return t.cast(orm.User.Model, service.get_related_one(uuid, orm.User))
100+
return service.get_related_one(uuid, orm.User)
101101

102102

103103
@read_router.get(
@@ -118,7 +118,7 @@ async def get_group_nodes(
118118
query.QueryParams,
119119
Depends(query.query_params),
120120
],
121-
) -> PaginatedResults[orm.Node.Model]:
121+
) -> PaginatedResults[dict[str, t.Any]]:
122122
"""Get the nodes of a group."""
123123
return service.get_related_many(uuid, orm.Node, query_params)
124124

@@ -152,6 +152,6 @@ async def get_group_extras(uuid: str) -> dict[str, t.Any]:
152152
async def create_group(
153153
group_model: orm.Group.CreateModel,
154154
current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
155-
) -> orm.Group.Model:
155+
) -> dict[str, t.Any]:
156156
"""Create new AiiDA group."""
157157
return service.add_one(group_model)

aiida_restapi/routers/nodes.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
@read_router.get(
4141
'/schema',
42-
response_model=dict,
42+
response_model=dict[str, t.Any],
4343
responses={
4444
422: {'model': errors.InvalidNodeTypeError},
4545
},
@@ -56,7 +56,7 @@ async def get_nodes_schema(
5656
t.Literal['get', 'post'],
5757
Query(description='Type of schema to retrieve'),
5858
] = 'get',
59-
) -> dict:
59+
) -> dict[str, t.Any]:
6060
"""Get JSON schema for the base AiiDA node 'get' model."""
6161
if not node_type:
6262
return orm.Node.Model.model_json_schema()
@@ -120,7 +120,7 @@ async def get_nodes(
120120
query.QueryParams,
121121
Depends(query.query_params),
122122
],
123-
) -> PaginatedResults[orm.Node.Model]:
123+
) -> PaginatedResults[dict[str, t.Any]]:
124124
"""Get AiiDA nodes with optional filtering, sorting, and/or pagination."""
125125
return service.get_many(query_params)
126126

@@ -158,7 +158,7 @@ async def get_node_types() -> list:
158158
},
159159
)
160160
@with_dbenv()
161-
async def get_node(uuid: str) -> orm.Node.Model:
161+
async def get_node(uuid: str) -> dict[str, t.Any]:
162162
"""Get AiiDA node by uuid."""
163163
return service.get_one(uuid)
164164

@@ -175,9 +175,9 @@ async def get_node(uuid: str) -> orm.Node.Model:
175175
},
176176
)
177177
@with_dbenv()
178-
async def get_node_user(uuid: str) -> orm.User.Model:
178+
async def get_node_user(uuid: str) -> dict[str, t.Any]:
179179
"""Get the user associated with a node."""
180-
return t.cast(orm.User.Model, service.get_related_one(uuid, orm.User))
180+
return service.get_related_one(uuid, orm.User)
181181

182182

183183
@read_router.get(
@@ -192,9 +192,9 @@ async def get_node_user(uuid: str) -> orm.User.Model:
192192
},
193193
)
194194
@with_dbenv()
195-
async def get_node_computer(uuid: str) -> orm.Computer.Model:
195+
async def get_node_computer(uuid: str) -> dict[str, t.Any]:
196196
"""Get the computer associated with a node."""
197-
return t.cast(orm.Computer.Model, service.get_related_one(uuid, orm.Computer))
197+
return service.get_related_one(uuid, orm.Computer)
198198

199199

200200
@read_router.get(
@@ -215,7 +215,7 @@ async def get_node_groups(
215215
query.QueryParams,
216216
Depends(query.query_params),
217217
],
218-
) -> PaginatedResults[orm.Group.Model]:
218+
) -> PaginatedResults[dict[str, t.Any]]:
219219
"""Get the groups of a node."""
220220
return service.get_related_many(uuid, orm.Group, query_params)
221221

aiida_restapi/routers/submit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def validate_inputs(cls, inputs: dict[str, t.Any]) -> dict[str, t.Any]:
8484
async def submit_process(
8585
process: ProcessSubmitModel,
8686
current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
87-
) -> orm.Node.Model:
87+
) -> dict[str, t.Any]:
8888
"""Submit new AiiDA process."""
8989
try:
9090
entry_point_process = load_entry_point_from_string(process.entry_point)
@@ -94,4 +94,4 @@ async def submit_process(
9494
process_node = engine.submit(entry_point_process, **process.inputs)
9595
except Exception as exception:
9696
raise exceptions.InputValidationError(str(exception)) from exception
97-
return t.cast(orm.Node.Model, process_node.to_model())
97+
return process_node.serialize(minimal=True)

aiida_restapi/routers/users.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
@read_router.get(
2424
'/schema',
25-
response_model=dict,
25+
response_model=dict[str, t.Any],
2626
responses={
2727
422: {'model': errors.RequestValidationError},
2828
},
@@ -32,7 +32,7 @@ async def get_users_schema(
3232
t.Literal['get', 'post'],
3333
Query(description='Type of schema to retrieve: "get" or "post"'),
3434
] = 'get',
35-
) -> dict:
35+
) -> dict[str, t.Any]:
3636
"""Get JSON schema for AiiDA users."""
3737
return service.get_schema(which=which)
3838

@@ -61,7 +61,7 @@ async def get_users(
6161
query.QueryParams,
6262
Depends(query.query_params),
6363
],
64-
) -> PaginatedResults[orm.User.Model]:
64+
) -> PaginatedResults[dict[str, t.Any]]:
6565
"""Get AiiDA users with optional filtering, sorting, and/or pagination."""
6666
return service.get_many(query_params)
6767

@@ -76,7 +76,7 @@ async def get_users(
7676
},
7777
)
7878
@with_dbenv()
79-
async def get_user(pk: int) -> orm.User.Model:
79+
async def get_user(pk: int) -> dict[str, t.Any]:
8080
"""Get AiiDA user by pk."""
8181
return service.get_one(pk)
8282

@@ -95,6 +95,6 @@ async def get_user(pk: int) -> orm.User.Model:
9595
async def create_user(
9696
user_model: orm.User.CreateModel,
9797
current_user: t.Annotated[UserInDB, Depends(get_current_active_user)],
98-
) -> orm.User.Model:
98+
) -> dict[str, t.Any]:
9999
"""Create new AiiDA user."""
100100
return service.add_one(user_model)

0 commit comments

Comments
 (0)