Skip to content

Commit

Permalink
Tidy code
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Jun 11, 2024
1 parent c06fe41 commit d9a6bcc
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 37 deletions.
60 changes: 32 additions & 28 deletions abcd/backends/atoms_opensearch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from typing import Iterable, Optional, Union
import logging
Expand Down Expand Up @@ -40,20 +40,20 @@
class OpenSearchQuery(AbstractQuerySet):
"""Class to parse and build queries for OpenSearch."""

def __call__(self, query: Union[dict, str, list, None]) -> Union[dict, None]:
def __call__(self, query: Optional[Union[dict, str, list]]) -> Optional[dict]:
"""
Parses and builds queries for OpenSearch.
Parameters
----------
query: Union[dict, str, list, None]
query: Optional[Union[dict, str, list]]
Query to be parsed for OpenSearch. If passed as a dictionary, the query is
left unchanged. If passed a string or list, the query is treated as a query
string, based on Lucene query syntax.
Returns
-------
Union[dict, None]
Optional[dict]
The parsed query for OpenSearch.
"""
if not query:
Expand Down Expand Up @@ -327,6 +327,7 @@ def delete(self, query: Optional[Union[dict, str]] = None):
Query to filter documents to be deleted. Default is `None`.
"""
query = self.parser(query)
logger.info("parsed query: %s", query)
body = {"query": query}

self.client.delete_by_query(
Expand Down Expand Up @@ -354,7 +355,7 @@ def refresh(self):
"""
self.client.indices.refresh(index=self.index_name)

def save_bulk(self, actions: Iterable, **kwargs):
def save_bulk(self, actions: Iterable[dict], **kwargs):
"""
Save a collection of documents in bulk.
Expand Down Expand Up @@ -410,7 +411,7 @@ def push(
)
data.save()

elif isinstance(atoms, Generator) or isinstance(atoms, list):
elif isinstance(atoms, Iterator) or isinstance(atoms, list):
actions = []
for i, item in enumerate(atoms):
if isinstance(extra_info, list):
Expand All @@ -431,7 +432,7 @@ def push(
def upload(
self,
file: Path,
extra_infos: Optional[Union[Iterable, dict]] = None,
extra_infos: Union[Iterable, dict] = (),
store_calc: bool = True,
):
"""
Expand All @@ -441,9 +442,9 @@ def upload(
----------
file: Path
Path to file to be uploaded
extra_infos: Optional[Union[Iterable, dict]]
extra_infos: Union[Iterable, dict]
Extra information to store in the document with the atoms data.
Default is `None`.
Default is `()`.
store_calc: bool, optional
Whether to store data from the calculator attached to atoms.
Default is `True`.
Expand All @@ -452,19 +453,14 @@ def upload(
if isinstance(file, str):
file = Path(file)

extra_info = {}
if extra_infos:
for info in extra_infos:
extra_info.update(extras.parser.parse(info)) # type: ignore
extra_info = dict(map(extras.parser.parse, extra_infos))

extra_info["filename"] = str(file)

data = iread(str(file))
self.push(data, extra_info, store_calc=store_calc)

def get_items(
self, query: Optional[Union[dict, str]] = None
) -> Generator[dict, None, None]:
def get_items(self, query: Optional[Union[dict, str]] = None) -> Iterator[dict]:
"""
Get data as a dictionary from documents in the database.
Expand All @@ -475,9 +471,11 @@ def get_items(
Returns
-------
Generator for dictionary of data.
Iterator[dict]
Iterator for dictionary of data.
"""
query = self.parser(query)
logger.info("parsed query: %s", query)
query = {
"query": query,
}
Expand All @@ -489,9 +487,7 @@ def get_items(
):
yield {"_id": hit["_id"], **hit["_source"]}

def get_atoms(
self, query: Optional[Union[dict, str]] = None
) -> Generator[Atoms, None, None]:
def get_atoms(self, query: Optional[Union[dict, str]] = None) -> Iterator[Atoms]:
"""
Get data as Atoms object from documents in the database.
Expand All @@ -502,9 +498,11 @@ def get_atoms(
Returns
-------
Generator for AtomsModel object of data.
Iterator[Atoms]
Generator for AtomsModel object of data.
"""
query = self.parser(query)
logger.info("parsed query: %s", query)
query = {
"query": query,
}
Expand All @@ -514,7 +512,7 @@ def get_atoms(
index=self.index_name,
query=query,
):
yield AtomsModel(None, None, hit["_source"]).to_ase()
yield AtomsModel(dict=hit["_source"]).to_ase()

def count(self, query: Optional[Union[dict, str]] = None, timeout=30.0) -> int:
"""
Expand Down Expand Up @@ -603,6 +601,7 @@ def property(
if only one property is given.
"""
query = self.parser(query)
logger.info("parsed query: %s", query)
query = {
"query": query,
}
Expand Down Expand Up @@ -662,6 +661,7 @@ def count_property(self, name, query: Optional[Union[dict, str]] = None) -> dict
matching documents.
"""
query = self.parser(query)
logger.info("parsed query: %s", query)

body = {
"size": 0,
Expand All @@ -678,9 +678,9 @@ def count_property(self, name, query: Optional[Union[dict, str]] = None) -> dict

prop = {}

for val in self.client.search(
index=self.index_name, body=body
)["aggregations"][format(name)]["buckets"]:
for val in self.client.search(index=self.index_name, body=body)["aggregations"][
format(name)
]["buckets"]:
prop[val["key"]] = val["doc_count"]

return prop
Expand All @@ -702,6 +702,7 @@ def properties(self, query: Optional[Union[dict, str]] = None) -> dict:
the properties of that type.
"""
query = self.parser(query)
logger.info("parsed query: %s", query)

properties = {}

Expand Down Expand Up @@ -793,6 +794,7 @@ def count_properties(self, query: Optional[Union[dict, str]] = None) -> dict:
corresponding to their counts, categories and data types.
"""
query = self.parser(query)
logger.info("parsed query: %s", query)
properties = {}

try:
Expand Down Expand Up @@ -942,7 +944,7 @@ def delete_property(self, name: str, query: Optional[Union[dict, str]] = None):

def hist(
self, name: str, query: Optional[Union[dict, str]] = None, **kwargs
) -> Union[dict, None]:
) -> Optional[dict]:
"""
Calculate histogram statistics for a property from all matching documents.
Expand All @@ -955,10 +957,12 @@ def hist(
Returns
-------
Dictionary containing histogram statistics, including the number of
bins, edges, counts, min, max, and standard deviation.
Optional[dict]
Dictionary containing histogram statistics, including the number of
bins, edges, counts, min, max, and standard deviation.
"""
query = self.parser(query)
logger.info("parsed query: %s", query)

data = self.property(name, query)
return utils.histogram(name, data, **kwargs)
Expand Down
7 changes: 2 additions & 5 deletions abcd/backends/atoms_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(
"`struct_name_label` must be specified if store_struct_file is"
" True."
)
self.struct_name_label = struct_name_label
self.struct_name_label = struct_name_label
self.set_struct_files()

def _separate_units(self):
Expand Down Expand Up @@ -175,15 +175,12 @@ def get_struct_file(self, struct_name: str) -> str:
-------
Filename for the current structure.
"""
if struct_name is None:
raise ValueError("`struct_name` must be specified")
if "{struct_name}" not in self.struct_file_template:
raise ValueError(
"'struct_name' must be a variable in the template file: "
f"{self.struct_file_template}"
)
else:
return eval(f"f'{self.struct_file_template}'")
return eval(f"f'{self.struct_file_template}'")

def to_list(self) -> list[dict]:
"""
Expand Down
2 changes: 1 addition & 1 deletion abcd/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def histogram(name, data, **kwargs):
if not data:
return None

if data and isinstance(data, list):
if isinstance(data, list):
ptype = type(data[0])

if not all(isinstance(x, ptype) for x in data):
Expand Down
4 changes: 1 addition & 3 deletions abcd/frontends/commandline/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,7 @@ def server(*, abcd_url, url, api_only, **kwargs):
from urllib.parse import urlparse
from abcd.server.app import create_app

logger.info(
"SERVER - abcd: %s, url: %s, api_only: %s", abcd_url, url, api_only
)
logger.info("SERVER - abcd: %s, url: %s, api_only: %s", abcd_url, url, api_only)

if api_only:
print("Not implemented yet!")
Expand Down

0 comments on commit d9a6bcc

Please sign in to comment.