diff --git a/abcd/__init__.py b/abcd/__init__.py
index 52279171..b8008379 100644
--- a/abcd/__init__.py
+++ b/abcd/__init__.py
@@ -14,7 +14,7 @@ class ABCD(object):
@classmethod
def from_config(cls, config):
# Factory method
- url = config['url']
+ url = config["url"]
return ABCD.from_url(url)
@classmethod
@@ -23,35 +23,38 @@ def from_url(cls, url, **kwargs):
r = parse.urlparse(url)
logger.info(r)
- if r.scheme == 'mongodb':
+ if r.scheme == "mongodb":
conn_settings = {
- 'host': r.hostname,
- 'port': r.port,
- 'username': r.username,
- 'password': r.password,
- 'authSource': 'admin',
+ "host": r.hostname,
+ "port": r.port,
+ "username": r.username,
+ "password": r.password,
+ "authSource": "admin",
}
- db = r.path.split('/')[1] if r.path else None
- db = db if db else 'abcd'
+ db = r.path.split("/")[1] if r.path else None
+ db = db if db else "abcd"
from abcd.backends.atoms_pymongo import MongoDatabase
+
return MongoDatabase(db_name=db, **conn_settings, **kwargs)
- elif r.scheme == 'http' or r.scheme == 'https':
- raise NotImplementedError('http not yet supported! soon...')
- elif r.scheme == 'ssh':
- raise NotImplementedError('ssh not yet supported! soon...')
+ elif r.scheme == "http" or r.scheme == "https":
+ raise NotImplementedError("http not yet supported! soon...")
+ elif r.scheme == "ssh":
+ raise NotImplementedError("ssh not yet supported! soon...")
else:
- raise NotImplementedError('Unable to recognise the type of connection. (url: {})'.format(url))
+ raise NotImplementedError(
+ "Unable to recognise the type of connection. (url: {})".format(url)
+ )
-if __name__ == '__main__':
+if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
# url = 'mongodb://mongoadmin:secret@localhost:27017'
- url = 'mongodb://mongoadmin:secret@localhost:27017/abcd_new'
+ url = "mongodb://mongoadmin:secret@localhost:27017/abcd_new"
abcd = ABCD.from_url(url)
abcd.print_info()
diff --git a/abcd/backends/atoms_http.py b/abcd/backends/atoms_http.py
index 1ff6e95f..ee62e61a 100644
--- a/abcd/backends/atoms_http.py
+++ b/abcd/backends/atoms_http.py
@@ -13,16 +13,15 @@
class Atoms(ase.Atoms):
-
@classmethod
def from_dict(cls, data):
- return cls(numbers=data['numbers'], positions=data['positions'])
+ return cls(numbers=data["numbers"], positions=data["positions"])
class HttpDatabase(Database):
"""client/local interface"""
- def __init__(self, url='http://localhost'):
+ def __init__(self, url="http://localhost"):
super().__init__()
self.url = url
@@ -30,7 +29,9 @@ def __init__(self, url='http://localhost'):
def push(self, atoms: ase.Atoms):
# todo: list of Atoms, metadata(user, project, tags)
- message = requests.put(self.url + '/calculation', json=DictEncoder().encode(atoms))
+ message = requests.put(
+ self.url + "/calculation", json=DictEncoder().encode(atoms)
+ )
# message = json.dumps(atoms)
# message_hash = hashlib.md5(message.encode('utf-8')).hexdigest()
logger.info(message)
@@ -44,29 +45,33 @@ def query(self, query_string):
pass
def search(self, query_string: str) -> List[str]:
- results = requests.get(self.url + '/calculation').json()
+ results = requests.get(self.url + "/calculation").json()
return results
def get_atoms(self, id: str) -> Atoms:
- data = requests.get(self.url + '/calculation/{}'.format(id)).json()
+ data = requests.get(self.url + "/calculation/{}".format(id)).json()
atoms = Atoms.from_dict(data)
return atoms
def __repr__(self):
- return 'ABCD(type={}, url={}, ...)'.format(self.__class__.__name__, self.url)
+ return "ABCD(type={}, url={}, ...)".format(self.__class__.__name__, self.url)
def _repr_html_(self):
"""jupyter notebook representation"""
- return 'ABCD database'
+ return "ABCD database"
def print_info(self):
"""shows basic information about the connected database"""
- out = linesep.join([
- '{:=^50}'.format(' ABCD Database '),
- '{:>10}: {}'.format('type', 'remote (http/https)'),
- linesep.join('{:>10}: {}'.format(k, v) for k, v in self.db.info().items())
- ])
+ out = linesep.join(
+ [
+ "{:=^50}".format(" ABCD Database "),
+ "{:>10}: {}".format("type", "remote (http/https)"),
+ linesep.join(
+ "{:>10}: {}".format(k, v) for k, v in self.db.info().items()
+ ),
+ ]
+ )
print(out)
@@ -80,6 +85,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
pass
-if __name__ == '__main__':
- abcd = HttpDatabase(url='http://localhost:8080/api')
+if __name__ == "__main__":
+ abcd = HttpDatabase(url="http://localhost:8080/api")
abcd.print_info()
diff --git a/abcd/backends/atoms_pymongo.py b/abcd/backends/atoms_pymongo.py
index 217e1c52..993c6eb0 100644
--- a/abcd/backends/atoms_pymongo.py
+++ b/abcd/backends/atoms_pymongo.py
@@ -26,12 +26,12 @@
logger = logging.getLogger(__name__)
map_types = {
- bool: 'bool',
- float: 'float',
- int: 'int',
- str: 'str',
- datetime: 'date',
- dict: 'dict'
+ bool: "bool",
+ float: "float",
+ int: "int",
+ str: "str",
+ datetime: "date",
+ dict: "dict",
}
@@ -49,15 +49,14 @@ def from_atoms(cls, collection, atoms: Atoms, extra_info=None, store_calc=True):
@property
def _id(self):
- return self.get('_id', None)
+ return self.get("_id", None)
def save(self):
if not self._id:
self._collection.insert_one(self)
else:
- new_values = { "$set": self }
- self._collection.update_one(
- {"_id": ObjectId(self._id)}, new_values)
+ new_values = {"$set": self}
+ self._collection.update_one({"_id": ObjectId(self._id)}, new_values)
def remove(self):
if self._id:
@@ -66,28 +65,27 @@ def remove(self):
class MongoQuery(AbstractQuerySet):
-
def __init__(self):
pass
def visit(self, syntax_tree):
op, *args = syntax_tree
try:
- fun = self.__getattribute__('visit_' + op.lower())
+ fun = self.__getattribute__("visit_" + op.lower())
return fun(*args)
except KeyError:
pass
def visit_name(self, field):
- return {field: {'$exists': True}}
+ return {field: {"$exists": True}}
def visit_not(self, value):
_, field = value
- return {field: {'$exists': False}}
+ return {field: {"$exists": False}}
def visit_and(self, *args):
print(args)
- return {'$and': [self.visit(arg) for arg in args]}
+ return {"$and": [self.visit(arg) for arg in args]}
# TODO recursively combining all the and statements
# out = {}
# for arg in args:
@@ -97,28 +95,28 @@ def visit_and(self, *args):
# return out
def visit_or(self, *args):
- return {'$or': [self.visit(arg) for arg in args]}
+ return {"$or": [self.visit(arg) for arg in args]}
def visit_eq(self, field, value):
return {field[1]: value[1]}
def visit_re(self, field, value):
- return {field[1]: {'$regex': value[1]}}
+ return {field[1]: {"$regex": value[1]}}
def visit_gt(self, field, value):
- return {field[1]: {'$gt': value[1]}}
+ return {field[1]: {"$gt": value[1]}}
def visit_gte(self, field, value):
- return {field[1]: {'$gte': value[1]}}
+ return {field[1]: {"$gte": value[1]}}
def visit_lt(self, field, value):
- return {field[1]: {'$lt': value[1]}}
+ return {field[1]: {"$lt": value[1]}}
def visit_lte(self, field, value):
- return {field[1]: {'$lte': value[1]}}
+ return {field[1]: {"$lte": value[1]}}
def visit_in(self, field, *values):
- return {field[1]: {'$in': [value[1] for value in values]}}
+ return {field[1]: {"$in": [value[1] for value in values]}}
def __enter__(self):
return self
@@ -127,12 +125,13 @@ def __exit__(self, exc_type, exc_val, exc_tb):
pass
def __call__(self, ast):
- logger.info('parsed ast: {}'.format(ast))
+ logger.info("parsed ast: {}".format(ast))
if isinstance(ast, dict):
return ast
elif isinstance(ast, str):
from abcd.parsers.queries import parser
+
p = parser(ast)
return self.visit(p)
@@ -156,20 +155,43 @@ def wrapper(*args, query=None, **kwargs):
class MongoDatabase(AbstractABCD):
"""Wrapper to make database operations easy"""
- def __init__(self, host='localhost', port=27017,
- db_name='abcd', collection_name='atoms',
- username=None, password=None, authSource='admin', **kwargs):
+ def __init__(
+ self,
+ host="localhost",
+ port=27017,
+ db_name="abcd",
+ collection_name="atoms",
+ username=None,
+ password=None,
+ authSource="admin",
+ **kwargs
+ ):
super().__init__()
- logger.info((host, port, db_name, collection_name, username, password, authSource, kwargs))
+ logger.info(
+ (
+ host,
+ port,
+ db_name,
+ collection_name,
+ username,
+ password,
+ authSource,
+ kwargs,
+ )
+ )
self.client = MongoClient(
- host=host, port=port, username=username, password=password,
- authSource=authSource)
+ host=host,
+ port=port,
+ username=username,
+ password=password,
+ authSource=authSource,
+ )
try:
info = self.client.server_info() # Forces a call.
- logger.info('DB info: {}'.format(info))
+ logger.info("DB info: {}".format(info))
except pymongo.errors.OperationFailure:
raise abcd.errors.AuthenticationError()
@@ -184,12 +206,12 @@ def info(self):
host, port = self.client.address
return {
- 'host': host,
- 'port': port,
- 'db': self.db.name,
- 'collection': self.collection.name,
- 'number of confs': self.collection.count_documents({}),
- 'type': 'mongodb'
+ "host": host,
+ "port": port,
+ "db": self.db.name,
+ "collection": self.collection.name,
+ "number of confs": self.collection.count_documents({}),
+ "type": "mongodb",
}
def delete(self, query=None):
@@ -205,14 +227,18 @@ def push(self, atoms: Union[Atoms, Iterable], extra_info=None, store_calc=True):
extra_info = extras.parser.parse(extra_info)
if isinstance(atoms, Atoms):
- data = AtomsModel.from_atoms(self.collection, atoms, extra_info=extra_info, store_calc=store_calc)
+ data = AtomsModel.from_atoms(
+ self.collection, atoms, extra_info=extra_info, store_calc=store_calc
+ )
data.save()
# self.collection.insert_one(data)
elif isinstance(atoms, types.GeneratorType) or isinstance(atoms, list):
for item in atoms:
- data = AtomsModel.from_atoms(self.collection, item, extra_info=extra_info, store_calc=store_calc)
+ data = AtomsModel.from_atoms(
+ self.collection, item, extra_info=extra_info, store_calc=store_calc
+ )
data.save()
def upload(self, file: Path, extra_infos=None, store_calc=True):
@@ -225,7 +251,7 @@ def upload(self, file: Path, extra_infos=None, store_calc=True):
for info in extra_infos:
extra_info.update(extras.parser.parse(info))
- extra_info['filename'] = str(file)
+ extra_info["filename"] = str(file)
data = iread(str(file))
self.push(data, extra_info, store_calc=store_calc)
@@ -243,7 +269,7 @@ def get_atoms(self, query=None):
def count(self, query=None):
query = parser(query)
- logger.info('query; {}'.format(query))
+ logger.info("query; {}".format(query))
if not query:
query = {}
@@ -254,61 +280,69 @@ def property(self, name, query=None):
query = parser(query)
pipeline = [
- {'$match': query},
- {'$match': {'{}'.format(name): {"$exists": True}}},
- {'$project': {'_id': False, 'data': '${}'.format(name)}}
+ {"$match": query},
+ {"$match": {"{}".format(name): {"$exists": True}}},
+ {"$project": {"_id": False, "data": "${}".format(name)}},
]
- return [val['data'] for val in self.db.atoms.aggregate(pipeline)]
+ return [val["data"] for val in self.db.atoms.aggregate(pipeline)]
def properties(self, query=None):
query = parser(query)
properties = {}
pipeline = [
- {'$match': query},
- {'$unwind': '$derived.info_keys'},
- {'$group': {'_id': '$derived.info_keys'}}
+ {"$match": query},
+ {"$unwind": "$derived.info_keys"},
+ {"$group": {"_id": "$derived.info_keys"}},
+ ]
+ properties["info"] = [
+ value["_id"] for value in self.db.atoms.aggregate(pipeline)
]
- properties['info'] = [value['_id'] for value in self.db.atoms.aggregate(pipeline)]
pipeline = [
- {'$match': query},
- {'$unwind': '$derived.arrays_keys'},
- {'$group': {'_id': '$derived.arrays_keys'}}
+ {"$match": query},
+ {"$unwind": "$derived.arrays_keys"},
+ {"$group": {"_id": "$derived.arrays_keys"}},
+ ]
+ properties["arrays"] = [
+ value["_id"] for value in self.db.atoms.aggregate(pipeline)
]
- properties['arrays'] = [value['_id'] for value in self.db.atoms.aggregate(pipeline)]
pipeline = [
- {'$match': query},
- {'$unwind': '$derived.derived_keys'},
- {'$group': {'_id': '$derived.derived_keys'}}
+ {"$match": query},
+ {"$unwind": "$derived.derived_keys"},
+ {"$group": {"_id": "$derived.derived_keys"}},
+ ]
+ properties["derived"] = [
+ value["_id"] for value in self.db.atoms.aggregate(pipeline)
]
- properties['derived'] = [value['_id'] for value in self.db.atoms.aggregate(pipeline)]
return properties
def get_type_of_property(self, prop, category):
# TODO: Probably it would be nicer to store the type info in the database from the beginning.
- atoms = self.db.atoms.find_one({prop: {'$exists': True}})
+ atoms = self.db.atoms.find_one({prop: {"$exists": True}})
data = atoms[prop]
- if category == 'arrays':
+ if category == "arrays":
if type(data[0]) == list:
- return 'array({}, N x {})'.format(map_types[type(data[0][0])], len(data[0]))
+ return "array({}, N x {})".format(
+ map_types[type(data[0][0])], len(data[0])
+ )
else:
- return 'vector({}, N)'.format(map_types[type(data[0])])
+ return "vector({}, N)".format(map_types[type(data[0])])
if type(data) == list:
if type(data[0]) == list:
if type(data[0][0]) == list:
- return 'list(list(...)'
+                return "list(list(...))"
else:
- return 'array({})'.format(map_types[type(data[0][0])])
+ return "array({})".format(map_types[type(data[0][0])])
else:
- return 'vector({})'.format(map_types[type(data[0])])
+ return "vector({})".format(map_types[type(data[0])])
else:
- return 'scalar({})'.format(map_types[type(data)])
+ return "scalar({})".format(map_types[type(data)])
def count_properties(self, query=None):
query = parser(query)
@@ -316,67 +350,69 @@ def count_properties(self, query=None):
properties = {}
pipeline = [
- {'$match': query},
- {'$unwind': '$derived.info_keys'},
- {'$group': {'_id': '$derived.info_keys', 'count': {'$sum': 1}}}
+ {"$match": query},
+ {"$unwind": "$derived.info_keys"},
+ {"$group": {"_id": "$derived.info_keys", "count": {"$sum": 1}}},
]
info_keys = self.db.atoms.aggregate(pipeline)
for val in info_keys:
- properties[val['_id']] = {
- 'count': val['count'],
- 'category': 'info',
- 'dtype': self.get_type_of_property(val['_id'], 'info')
+ properties[val["_id"]] = {
+ "count": val["count"],
+ "category": "info",
+ "dtype": self.get_type_of_property(val["_id"], "info"),
}
pipeline = [
- {'$match': query},
- {'$unwind': '$derived.arrays_keys'},
- {'$group': {'_id': '$derived.arrays_keys', 'count': {'$sum': 1}}}
+ {"$match": query},
+ {"$unwind": "$derived.arrays_keys"},
+ {"$group": {"_id": "$derived.arrays_keys", "count": {"$sum": 1}}},
]
arrays_keys = list(self.db.atoms.aggregate(pipeline))
for val in arrays_keys:
- properties[val['_id']] = {
- 'count': val['count'],
- 'category': 'arrays',
- 'dtype': self.get_type_of_property(val['_id'], 'arrays')
+ properties[val["_id"]] = {
+ "count": val["count"],
+ "category": "arrays",
+ "dtype": self.get_type_of_property(val["_id"], "arrays"),
}
pipeline = [
- {'$match': query},
- {'$unwind': '$derived.derived_keys'},
- {'$group': {'_id': '$derived.derived_keys', 'count': {'$sum': 1}}}
+ {"$match": query},
+ {"$unwind": "$derived.derived_keys"},
+ {"$group": {"_id": "$derived.derived_keys", "count": {"$sum": 1}}},
]
arrays_keys = list(self.db.atoms.aggregate(pipeline))
for val in arrays_keys:
- properties[val['_id']] = {
- 'count': val['count'],
- 'category': 'derived',
- 'dtype': self.get_type_of_property(val['_id'], 'derived')
+ properties[val["_id"]] = {
+ "count": val["count"],
+ "category": "derived",
+ "dtype": self.get_type_of_property(val["_id"], "derived"),
}
return properties
def add_property(self, data, query=None):
- logger.info('add: data={}, query={}'.format(data, query))
+ logger.info("add: data={}, query={}".format(data, query))
self.collection.update_many(
parser(query),
- {'$push': {'derived.info_keys': {'$each': list(data.keys())}},
- '$set': data})
+ {
+ "$push": {"derived.info_keys": {"$each": list(data.keys())}},
+ "$set": data,
+ },
+ )
def rename_property(self, name, new_name, query=None):
- logger.info('rename: query={}, old={}, new={}'.format(query, name, new_name))
+ logger.info("rename: query={}, old={}, new={}".format(query, name, new_name))
# TODO name in derived.info_keys OR name in derived.arrays_keys OR name in derived.derived_keys
self.collection.update_many(
- parser(query),
- {'$push': {'derived.info_keys': new_name}})
+ parser(query), {"$push": {"derived.info_keys": new_name}}
+ )
self.collection.update_many(
parser(query),
- {
- '$pull': {'derived.info_keys': name},
- '$rename': {name: new_name}})
+ {"$pull": {"derived.info_keys": name}, "$rename": {name: new_name}},
+ )
# self.collection.update_many(
# parser(query + ['arrays.{}'.format(name)]),
@@ -388,13 +424,15 @@ def rename_property(self, name, new_name, query=None):
# '$rename': {'arrays.{}'.format(name): 'arrays.{}'.format(new_name)}})
def delete_property(self, name, query=None):
- logger.info('delete: query={}, porperty={}'.format(name, query))
+        logger.info("delete: query={}, property={}".format(query, name))
self.collection.update_many(
parser(query),
- {'$pull': {'derived.info_keys': name,
- 'derived.arrays_keys': name},
- '$unset': {name: ''}})
+ {
+ "$pull": {"derived.info_keys": name, "derived.arrays_keys": name},
+ "$unset": {name: ""},
+ },
+ )
def hist(self, name, query=None, **kwargs):
@@ -411,21 +449,27 @@ def exec(self, code, query=None):
def __repr__(self):
host, port = self.client.address
- return '{}('.format(self.__class__.__name__) + \
- 'url={}:{}, '.format(host, port) + \
- 'db={}, '.format(self.db.name) + \
- 'collection={})'.format(self.collection.name)
+ return (
+ "{}(".format(self.__class__.__name__)
+ + "url={}:{}, ".format(host, port)
+ + "db={}, ".format(self.db.name)
+ + "collection={})".format(self.collection.name)
+ )
def _repr_html_(self):
"""Jupyter notebook representation"""
- return 'ABCD MongoDB database'
+ return "ABCD MongoDB database"
def print_info(self):
"""shows basic information about the connected database"""
- out = linesep.join(['{:=^50}'.format(' ABCD MongoDB '),
- '{:>10}: {}'.format('type', 'mongodb'),
- linesep.join('{:>10}: {}'.format(k, v) for k, v in self.info().items())])
+ out = linesep.join(
+ [
+ "{:=^50}".format(" ABCD MongoDB "),
+ "{:>10}: {}".format("type", "mongodb"),
+ linesep.join("{:>10}: {}".format(k, v) for k, v in self.info().items()),
+ ]
+ )
print(out)
@@ -449,26 +493,36 @@ def histogram(name, data, **kwargs):
return None
if ptype == float:
- bins = kwargs.get('bins', 10)
+ bins = kwargs.get("bins", 10)
return _hist_float(name, data, bins)
elif ptype == int:
- bins = kwargs.get('bins', 10)
+ bins = kwargs.get("bins", 10)
return _hist_int(name, data, bins)
elif ptype == str:
return _hist_str(name, data, **kwargs)
elif ptype == datetime:
- bins = kwargs.get('bins', 10)
+ bins = kwargs.get("bins", 10)
return _hist_date(name, data, bins)
else:
- print('{}: Histogram for list of {} types are not supported!'.format(name, type(data[0])))
- logger.info('{}: Histogram for list of {} types are not supported!'.format(name, type(data[0])))
+ print(
+ "{}: Histogram for list of {} types are not supported!".format(
+ name, type(data[0])
+ )
+ )
+ logger.info(
+ "{}: Histogram for list of {} types are not supported!".format(
+ name, type(data[0])
+ )
+ )
else:
- logger.info('{}: Histogram for {} types are not supported!'.format(name, type(data)))
+ logger.info(
+ "{}: Histogram for {} types are not supported!".format(name, type(data))
+ )
return None
@@ -477,16 +531,16 @@ def _hist_float(name, data, bins=10):
hist, bin_edges = np.histogram(data, bins=bins)
return {
- 'type': 'hist_float',
- 'name': name,
- 'bins': bins,
- 'edges': bin_edges,
- 'counts': hist,
- 'min': data.min(),
- 'max': data.max(),
- 'median': data.mean(),
- 'std': data.std(),
- 'var': data.var()
+ "type": "hist_float",
+ "name": name,
+ "bins": bins,
+ "edges": bin_edges,
+ "counts": hist,
+ "min": data.min(),
+ "max": data.max(),
+ "median": data.mean(),
+ "std": data.std(),
+ "var": data.var(),
}
@@ -497,16 +551,16 @@ def _hist_date(name, data, bins=10):
fromtimestamp = datetime.fromtimestamp
return {
- 'type': 'hist_date',
- 'name': name,
- 'bins': bins,
- 'edges': [fromtimestamp(d) for d in bin_edges],
- 'counts': hist,
- 'min': fromtimestamp(hist_data.min()),
- 'max': fromtimestamp(hist_data.max()),
- 'median': fromtimestamp(hist_data.mean()),
- 'std': fromtimestamp(hist_data.std()),
- 'var': fromtimestamp(hist_data.var())
+ "type": "hist_date",
+ "name": name,
+ "bins": bins,
+ "edges": [fromtimestamp(d) for d in bin_edges],
+ "counts": hist,
+ "min": fromtimestamp(hist_data.min()),
+ "max": fromtimestamp(hist_data.max()),
+ "median": fromtimestamp(hist_data.mean()),
+ "std": fromtimestamp(hist_data.std()),
+ "var": fromtimestamp(hist_data.var()),
}
@@ -520,16 +574,16 @@ def _hist_int(name, data, bins=10):
hist, bin_edges = np.histogram(data, bins=bins)
return {
- 'type': 'hist_int',
- 'name': name,
- 'bins': bins,
- 'edges': bin_edges,
- 'counts': hist,
- 'min': data.min(),
- 'max': data.max(),
- 'median': data.mean(),
- 'std': data.std(),
- 'var': data.var()
+ "type": "hist_int",
+ "name": name,
+ "bins": bins,
+ "edges": bin_edges,
+ "counts": hist,
+ "min": data.min(),
+ "max": data.max(),
+ "median": data.mean(),
+ "std": data.std(),
+ "var": data.var(),
}
@@ -538,7 +592,9 @@ def _hist_str(name, data, bins=10, truncate=20):
if truncate:
# data = (item[:truncate] for item in data)
- data = (item[:truncate] + '...' if len(item) > truncate else item for item in data)
+ data = (
+ item[:truncate] + "..." if len(item) > truncate else item for item in data
+ )
data = Counter(data)
@@ -548,27 +604,27 @@ def _hist_str(name, data, bins=10, truncate=20):
labels, counts = zip(*data.items())
return {
- 'type': 'hist_str',
- 'name': name,
- 'total': sum(data.values()),
- 'unique': n_unique,
- 'labels': labels[:bins],
- 'counts': counts[:bins]
+ "type": "hist_str",
+ "name": name,
+ "total": sum(data.values()),
+ "unique": n_unique,
+ "labels": labels[:bins],
+ "counts": counts[:bins],
}
-if __name__ == '__main__':
+if __name__ == "__main__":
# import json
# from ase.io import iread
# from pprint import pprint
# from server.styles.myjson import JSONEncoderOld, JSONDecoderOld, JSONEncoder
- print('hello')
- db = MongoDatabase(username='mongoadmin', password='secret')
+ print("hello")
+ db = MongoDatabase(username="mongoadmin", password="secret")
print(db.info())
print(db.count())
- print(db.hist('uploaded'))
+ print(db.hist("uploaded"))
# for atoms in iread('../../tutorials/data/bcc_bulk_54_expanded_2_high.xyz', index=slice(None)):
# # print(at)
diff --git a/abcd/frontends/commandline/commands.py b/abcd/frontends/commandline/commands.py
index ebe6a213..de158a5c 100644
--- a/abcd/frontends/commandline/commands.py
+++ b/abcd/frontends/commandline/commands.py
@@ -11,32 +11,37 @@
@init_config
def login(*, config, name, url, **kwargs):
logger.info(
- 'login args: \nconfig:{}, name:{}, url:{}, kwargs:{}'.format(config, name, url, kwargs))
+ "login args: \nconfig:{}, name:{}, url:{}, kwargs:{}".format(
+ config, name, url, kwargs
+ )
+ )
from abcd import ABCD
db = ABCD.from_url(url=url)
info = db.info()
- config['url'] = url
+ config["url"] = url
config.save()
- print('Successfully connected to the database!')
- print(" type: {type}\n"
- " hostname: {host}\n"
- " port: {port}\n"
- " database: {db}\n"
- " # of confs: {number of confs}".format(**info))
+ print("Successfully connected to the database!")
+ print(
+ " type: {type}\n"
+ " hostname: {host}\n"
+ " port: {port}\n"
+ " database: {db}\n"
+ " # of confs: {number of confs}".format(**info)
+ )
@init_config
@init_db
def download(*, db, query, fileformat, filename, **kwargs):
- logger.info('download\n kwargs: {}'.format(kwargs))
+ logger.info("download\n kwargs: {}".format(kwargs))
from ase.io import write
- if kwargs.pop('remote'):
- write('-', list(db.get_atoms(query=query)), format=fileformat)
+ if kwargs.pop("remote"):
+ write("-", list(db.get_atoms(query=query)), format=fileformat)
return
write(filename, list(db.get_atoms(query=query)), format=fileformat)
@@ -46,14 +51,18 @@ def download(*, db, query, fileformat, filename, **kwargs):
@init_db
@check_remote
def delete(*, db, query, yes, **kwargs):
- logger.info('delete\n kwargs: {}'.format(kwargs))
+ logger.info("delete\n kwargs: {}".format(kwargs))
if not yes:
- print('Please use --yes for deleting {} configurations'.format(db.count(query=query)))
+ print(
+ "Please use --yes for deleting {} configurations".format(
+ db.count(query=query)
+ )
+ )
exit(1)
count = db.delete(query=query)
- print('{} configuration has been deleted'.format(count))
+ print("{} configuration has been deleted".format(count))
@init_config
@@ -69,11 +78,11 @@ def upload(*, db, path, extra_infos, ignore_calc_results, **kwargs):
db.upload(path, extra_infos, store_calc=calculator)
elif path.is_dir():
- for file in path.glob('.xyz'):
- logger.info('Uploaded file: {}'.format(file))
+ for file in path.glob(".xyz"):
+ logger.info("Uploaded file: {}".format(file))
db.upload(file, extra_infos, store_calc=calculator)
else:
- logger.info('No file found: {}'.format(path))
+ logger.info("No file found: {}".format(path))
raise FileNotFoundError()
else:
@@ -83,8 +92,8 @@ def upload(*, db, path, extra_infos, ignore_calc_results, **kwargs):
@init_config
@init_db
def summary(*, db, query, print_all, bins, truncate, props, **kwargs):
- logger.info('summary\n kwargs: {}'.format(kwargs))
- logger.info('query: {}'.format(query))
+ logger.info("summary\n kwargs: {}".format(kwargs))
+ logger.info("query: {}".format(query))
if print_all:
truncate = None
@@ -97,15 +106,15 @@ def summary(*, db, query, print_all, bins, truncate, props, **kwargs):
props_list = []
for prop in props:
# TODO: Check that is this the right place?
- props_list.extend(re.split(r';\s*|,\s*|\s+', prop))
+ props_list.extend(re.split(r";\s*|,\s*|\s+", prop))
- if '*' in props_list:
- props_list = '*'
+ if "*" in props_list:
+ props_list = "*"
- logging.info('property list: {}'.format(props_list))
+ logging.info("property list: {}".format(props_list))
total = db.count(query)
- print('Total number of configurations: {}'.format(total))
+ print("Total number of configurations: {}".format(total))
if total == 0:
return
@@ -118,13 +127,13 @@ def summary(*, db, query, print_all, bins, truncate, props, **kwargs):
labels, categories, dtypes, counts = [], [], [], []
for k in sorted(props, key=str.lower):
labels.append(k)
- counts.append(props[k]['count'])
- categories.append(props[k]['category'])
- dtypes.append(props[k]['dtype'])
+ counts.append(props[k]["count"])
+ categories.append(props[k]["category"])
+ dtypes.append(props[k]["dtype"])
f.hist_labels(counts, categories, dtypes, labels)
- elif props_list == '*':
+ elif props_list == "*":
props = db.properties(query=query)
for ptype in props:
@@ -149,8 +158,8 @@ def summary(*, db, query, print_all, bins, truncate, props, **kwargs):
@init_config
@init_db
def show(*, db, query, print_all, props, **kwargs):
- logger.info('show\n kwargs: {}'.format(kwargs))
- logger.info('query: {}'.format(query))
+ logger.info("show\n kwargs: {}".format(kwargs))
+ logger.info("query: {}".format(query))
if not props:
print("Please define at least on property by using the -p option!")
@@ -162,7 +171,7 @@ def show(*, db, query, print_all, props, **kwargs):
for dct in islice(db.get_items(query), 0, limit):
print(" | ".join(str(dct.get(prop, None)) for prop in props))
- logging.info('property list: {}'.format(props))
+ logging.info("property list: {}".format(props))
@check_remote
@@ -171,17 +180,19 @@ def show(*, db, query, print_all, props, **kwargs):
def key_add(*, db, query, keys, **kwargs):
from abcd.parsers.extras import parser
- keys = ' '.join(keys)
+ keys = " ".join(keys)
data = parser.parse(keys)
if query:
- test = ('AND', query, ("OR", *(('NAME', key) for key in data.keys())))
+ test = ("AND", query, ("OR", *(("NAME", key) for key in data.keys())))
else:
- test = ("OR", *(('NAME', key) for key in data.keys()))
+ test = ("OR", *(("NAME", key) for key in data.keys()))
if db.count(query=test):
- print('The new key already exist for the given query! '
- 'Please make sure that the target key name don\'t exist')
+ print(
+ "The new key already exist for the given query! "
+            "Please make sure that the target key name doesn't exist"
+ )
exit(1)
db.add_property(data, query=query)
@@ -192,8 +203,10 @@ def key_add(*, db, query, keys, **kwargs):
@init_db
def key_rename(*, db, query, old_keys, new_keys, **kwargs):
if db.count(query=query + [old_keys, new_keys]):
- print('The new key already exist for the given query! '
- 'Please make sure that the target key name don\'t exist')
+ print(
+ "The new key already exist for the given query! "
+            "Please make sure that the target key name doesn't exist"
+ )
exit(1)
db.rename_property(old_keys, new_keys, query=query)
@@ -205,14 +218,17 @@ def key_rename(*, db, query, old_keys, new_keys, **kwargs):
def key_delete(*, db, query, yes, keys, **kwargs):
from abcd.parsers.extras import parser
- keys = ' '.join(keys)
+ keys = " ".join(keys)
data = parser.parse(keys)
- query = ('AND', query, ('OR', *(('NAME', key) for key in data.keys())))
+ query = ("AND", query, ("OR", *(("NAME", key) for key in data.keys())))
if not yes:
- print('Please use --yes for deleting keys from {} configurations'.format(
- db.count(query=query)))
+ print(
+ "Please use --yes for deleting keys from {} configurations".format(
+ db.count(query=query)
+ )
+ )
exit(1)
for k in keys:
@@ -224,8 +240,11 @@ def key_delete(*, db, query, yes, keys, **kwargs):
@init_db
def execute(*, db, query, yes, python_code, **kwargs):
if not yes:
- print('Please use --yes for executing code on {} configurations'.format(
- db.count(query=query)))
+ print(
+ "Please use --yes for executing code on {} configurations".format(
+ db.count(query=query)
+ )
+ )
exit(1)
db.exec(python_code, query)
@@ -236,7 +255,9 @@ def server(*, abcd_url, url, api_only, **kwargs):
from urllib.parse import urlparse
from abcd.server.app import create_app
- logger.info("SERVER - abcd: {}, url: {}, api_only:{}".format(abcd_url, url, api_only))
+ logger.info(
+ "SERVER - abcd: {}, url: {}, api_only:{}".format(abcd_url, url, api_only)
+ )
if api_only:
print("Not implemented yet!")
@@ -252,26 +273,37 @@ class Formater(object):
partialBlocks = ["▏", "▎", "▍", "▌", "▋", "▊", "▉", "█"] # char=pb
def title(self, title):
- print('', title, '=' * len(title), sep=os.linesep)
+ print("", title, "=" * len(title), sep=os.linesep)
def describe(self, data):
- if data['type'] == 'hist_float':
+ if data["type"] == "hist_float":
print(
- '{} count: {} min: {:11.4e} med: {:11.4e} max: {:11.4e} std: {:11.4e} var:{'
- ':11.4e}'.format(
- data["name"], sum(data["counts"]),
- data["min"], data["median"], data["max"],
- data["std"], data["var"])
+ "{} count: {} min: {:11.4e} med: {:11.4e} max: {:11.4e} std: {:11.4e} var:{"
+ ":11.4e}".format(
+ data["name"],
+ sum(data["counts"]),
+ data["min"],
+ data["median"],
+ data["max"],
+ data["std"],
+ data["var"],
+ )
)
- elif data['type'] == 'hist_int':
- print('{} count: {} '.format(data["name"], sum(data["counts"])),
- 'min: {:d} med: {:d} max: {:d} '.format(int(data["min"]), int(data["median"]),
- int(data["max"]))
- )
+ elif data["type"] == "hist_int":
+ print(
+ "{} count: {} ".format(data["name"], sum(data["counts"])),
+ "min: {:d} med: {:d} max: {:d} ".format(
+ int(data["min"]), int(data["median"]), int(data["max"])
+ ),
+ )
- elif data['type'] == 'hist_str':
- print('{} count: {} unique: {}'.format(data["name"], data["total"], data["unique"]))
+ elif data["type"] == "hist_str":
+ print(
+ "{} count: {} unique: {}".format(
+ data["name"], data["total"], data["unique"]
+ )
+ )
else:
pass
@@ -282,10 +314,11 @@ def hist_float(self, bin_edges, counts, width_hist=40):
for count, lower, upper in zip(counts, bin_edges[:-1], bin_edges[1:]):
scale = int(ratio * count)
- self.print('{:<{}} {:>{}d} [{: >11.4e}, {: >11.4f})'.format(
- "▉" * scale, width_hist,
- count, width_count,
- lower, upper))
+ self.print(
+ "{:<{}} {:>{}d} [{: >11.4e}, {: >11.4f})".format(
+ "▉" * scale, width_hist, count, width_count, lower, upper
+ )
+ )
def hist_int(self, bin_edges, counts, width_hist=40):
@@ -294,35 +327,50 @@ def hist_int(self, bin_edges, counts, width_hist=40):
for count, lower, upper in zip(counts, bin_edges[:-1], bin_edges[1:]):
scale = int(ratio * count)
- self.print('{:<{}} {:>{}d} [{:d}, {:d})'.format(
- "▉" * scale, width_hist,
- count, width_count,
- np.ceil(lower).astype(int), np.floor(upper).astype(int)))
+ self.print(
+ "{:<{}} {:>{}d} [{:d}, {:d})".format(
+ "▉" * scale,
+ width_hist,
+ count,
+ width_count,
+ np.ceil(lower).astype(int),
+ np.floor(upper).astype(int),
+ )
+ )
def hist_date(self, bin_edges, counts, width_hist=40):
- dateformat = '%y-%m-%d %H:%M'
+ dateformat = "%y-%m-%d %H:%M"
ratio = width_hist / max(counts)
width_count = len(str(max(counts)))
for count, lower, upper in zip(counts, bin_edges[:-1], bin_edges[1:]):
scale = int(ratio * count)
- self.print('{:<{}} {:>{}d} [{}, {})'.format(
- "▉" * scale, width_hist,
- count, width_count,
- lower.strftime(dateformat), upper.strftime(dateformat)))
+ self.print(
+ "{:<{}} {:>{}d} [{}, {})".format(
+ "▉" * scale,
+ width_hist,
+ count,
+ width_count,
+ lower.strftime(dateformat),
+ upper.strftime(dateformat),
+ )
+ )
def hist_str(self, total, counts, labels, width_hist=40):
remain = total - sum(counts)
if remain > 0:
counts = (*counts, remain)
- labels = (*labels, '...')
+ labels = (*labels, "...")
width_count = len(str(max(counts)))
ratio = width_hist / max(counts)
for label, count in zip(labels, counts):
scale = int(ratio * count)
self.print(
- '{:<{}} {:>{}d} {}'.format("▉" * scale, width_hist, count, width_count, label))
+ "{:<{}} {:>{}d} {}".format(
+ "▉" * scale, width_hist, count, width_count, label
+ )
+ )
def hist_labels(self, counts, categories, dtypes, labels, width_hist=40):
@@ -330,18 +378,21 @@ def hist_labels(self, counts, categories, dtypes, labels, width_hist=40):
ratio = width_hist / max(counts)
for label, count, dtype in zip(labels, counts, dtypes):
scale = int(ratio * count)
- self.print('{:<{}} {:<21} {:>{}d} {}'.format(
- "▉" * scale, width_hist, dtype, count, width_count, label))
+ self.print(
+ "{:<{}} {:<21} {:>{}d} {}".format(
+ "▉" * scale, width_hist, dtype, count, width_count, label
+ )
+ )
def hist(self, data: dict, width_hist=40):
- if data['type'] == 'hist_float':
- self.hist_float(data['edges'], data['counts'])
- elif data['type'] == 'hist_int':
- self.hist_int(data['edges'], data['counts'])
- elif data['type'] == 'hist_date':
- self.hist_date(data['edges'], data['counts'])
- elif data['type'] == 'hist_str':
- self.hist_str(data['total'], data['counts'], data['labels'])
+ if data["type"] == "hist_float":
+ self.hist_float(data["edges"], data["counts"])
+ elif data["type"] == "hist_int":
+ self.hist_int(data["edges"], data["counts"])
+ elif data["type"] == "hist_date":
+ self.hist_date(data["edges"], data["counts"])
+ elif data["type"] == "hist_str":
+ self.hist_str(data["total"], data["counts"], data["labels"])
else:
pass
@@ -350,4 +401,4 @@ def print(*args, **kwargs):
print(*args, **kwargs)
def _trunc(self, text, width=80):
- return text if len(text) < width else text[:width - 3] + '...'
+ return text if len(text) < width else text[: width - 3] + "..."
diff --git a/abcd/frontends/commandline/config.py b/abcd/frontends/commandline/config.py
index 96f1380c..3aa21bea 100644
--- a/abcd/frontends/commandline/config.py
+++ b/abcd/frontends/commandline/config.py
@@ -21,27 +21,33 @@ def from_json(cls, filename):
@classmethod
def load(cls):
- if os.environ.get('ABCD_CONFIG') and Path(os.environ.get('ABCD_CONFIG')).is_file():
- file = Path(os.environ.get('ABCD_CONFIG'))
- elif (Path.home() / '.abcd').is_file():
- file = Path.home() / '.abcd'
+ if (
+ os.environ.get("ABCD_CONFIG")
+ and Path(os.environ.get("ABCD_CONFIG")).is_file()
+ ):
+ file = Path(os.environ.get("ABCD_CONFIG"))
+ elif (Path.home() / ".abcd").is_file():
+ file = Path.home() / ".abcd"
else:
return cls()
- logger.info('Using config file: {}'.format(file))
+ logger.info("Using config file: {}".format(file))
config = cls.from_json(file)
return config
def save(self):
- file = Path(os.environ.get('ABCD_CONFIG')) if os.environ.get('ABCD_CONFIG') \
- else Path.home() / '.abcd'
+ file = (
+ Path(os.environ.get("ABCD_CONFIG"))
+ if os.environ.get("ABCD_CONFIG")
+ else Path.home() / ".abcd"
+ )
- logger.info('The saved config\'s file: {}'.format(file))
+ logger.info("The saved config's file: {}".format(file))
- with open(str(file), 'w') as file:
+ with open(str(file), "w") as file:
json.dump(self, file)
def __repr__(self):
- return '<{} {}>'.format(self.__class__.__name__, dict.__repr__(self))
+ return "<{} {}>".format(self.__class__.__name__, dict.__repr__(self))
diff --git a/abcd/frontends/commandline/decorators.py b/abcd/frontends/commandline/decorators.py
index d46e9ced..c2439be7 100644
--- a/abcd/frontends/commandline/decorators.py
+++ b/abcd/frontends/commandline/decorators.py
@@ -19,10 +19,10 @@ def wrapper(*args, **kwargs):
def init_db(func):
def wrapper(*args, config, **kwargs):
- url = config.get('url', None)
+ url = config.get("url", None)
if url is None:
- print('Please use abcd login first!')
+ print("Please use abcd login first!")
exit(1)
db = ABCD.from_url(url=url)
@@ -32,10 +32,10 @@ def wrapper(*args, config, **kwargs):
# TODO: better ast optimisation
query_list = []
- for q in kwargs.pop('default_query', []):
+ for q in kwargs.pop("default_query", []):
query_list.append(parser(q))
- for q in kwargs.pop('query', []):
+ for q in kwargs.pop("query", []):
query_list.append(parser(q))
if not query_list:
@@ -43,7 +43,7 @@ def wrapper(*args, config, **kwargs):
elif len(query_list) == 1:
query = query_list[0]
else:
- query = ('AND', *query_list)
+ query = ("AND", *query_list)
func(*args, db=db, query=query, **kwargs)
@@ -52,8 +52,8 @@ def wrapper(*args, config, **kwargs):
def check_remote(func):
def wrapper(*args, **kwargs):
- if kwargs.pop('remote'):
- print('In read only mode, you can\'t modify the data in the database')
+ if kwargs.pop("remote"):
+ print("In read only mode, you can't modify the data in the database")
exit(1)
func(*args, **kwargs)
diff --git a/abcd/frontends/commandline/parser.py b/abcd/frontends/commandline/parser.py
index 6856bd9e..9b2c1af2 100644
--- a/abcd/frontends/commandline/parser.py
+++ b/abcd/frontends/commandline/parser.py
@@ -5,135 +5,234 @@
logger = logging.getLogger(__name__)
-parser = ArgumentParser(description='Command line interface for ABCD database')
-parser.add_argument('-v', '--verbose', help='Enable verbose mode', action='store_true')
-parser.add_argument('-q', '--query', dest='default_query', action='append', help='Filtering extra quantities',
- default=[])
-
-parser.add_argument('--remote', help='Disables all the functions which would modify the database',
- action='store_true')
-
-subparsers = parser.add_subparsers(title='Commands', dest='command', parser_class=ArgumentParser)
-
-login_parser = subparsers.add_parser('login', help='login to the database')
+parser = ArgumentParser(description="Command line interface for ABCD database")
+parser.add_argument("-v", "--verbose", help="Enable verbose mode", action="store_true")
+parser.add_argument(
+ "-q",
+ "--query",
+ dest="default_query",
+ action="append",
+ help="Filtering extra quantities",
+ default=[],
+)
+
+parser.add_argument(
+ "--remote",
+ help="Disables all the functions which would modify the database",
+ action="store_true",
+)
+
+subparsers = parser.add_subparsers(
+ title="Commands", dest="command", parser_class=ArgumentParser
+)
+
+login_parser = subparsers.add_parser("login", help="login to the database")
login_parser.set_defaults(callback_func=commands.login)
-login_parser.add_argument('-n', '--name', help='name of the database', default='default')
-login_parser.add_argument(dest='url',
- help='url of abcd api (default: http://localhost)',
- default='http://localhost')
-
-download_parser = subparsers.add_parser('download', help='download data from the database')
+login_parser.add_argument(
+ "-n", "--name", help="name of the database", default="default"
+)
+login_parser.add_argument(
+ dest="url",
+ help="url of abcd api (default: http://localhost)",
+ default="http://localhost",
+)
+
+download_parser = subparsers.add_parser(
+ "download", help="download data from the database"
+)
download_parser.set_defaults(callback_func=commands.download)
-download_parser.add_argument('-q', '--query', action='append', help='Filtering extra quantities', default=[])
-download_parser.add_argument('-f', '--format', help='Valid ASE file format (optional)', dest='fileformat', default='extxyz')
-download_parser.add_argument(dest='filename', help='name of the file to store the configurations', nargs='?')
-
-upload_parser = subparsers.add_parser('upload', help='upload any ase supported files to the database')
-upload_parser.add_argument('-e', '--extra_infos', action='append', help='Adding extra quantities')
-upload_parser.add_argument('-i', '--ignore_calc_results', action='store_true',
- help='Ignore calculators results/parameters')
-upload_parser.add_argument('--upload-duplicates', action='store_true',
- help='Upload but still report all of the duplicates')
-upload_parser.add_argument('--upload-structure-duplicates', action='store_true',
- help='Ignore all the "exact" duplicates but store all of the structural duplicates')
-upload_parser.add_argument('--upload-duplicates-replace', action='store_true',
- help='Upload everything and duplicates overwrite previously existing data')
-upload_parser.add_argument(dest='path', help='Path to the file or folder.')
+download_parser.add_argument(
+ "-q", "--query", action="append", help="Filtering extra quantities", default=[]
+)
+download_parser.add_argument(
+ "-f",
+ "--format",
+ help="Valid ASE file format (optional)",
+ dest="fileformat",
+ default="extxyz",
+)
+download_parser.add_argument(
+ dest="filename", help="name of the file to store the configurations", nargs="?"
+)
+
+upload_parser = subparsers.add_parser(
+ "upload", help="upload any ase supported files to the database"
+)
+upload_parser.add_argument(
+ "-e", "--extra_infos", action="append", help="Adding extra quantities"
+)
+upload_parser.add_argument(
+ "-i",
+ "--ignore_calc_results",
+ action="store_true",
+ help="Ignore calculators results/parameters",
+)
+upload_parser.add_argument(
+ "--upload-duplicates",
+ action="store_true",
+ help="Upload but still report all of the duplicates",
+)
+upload_parser.add_argument(
+ "--upload-structure-duplicates",
+ action="store_true",
+ help='Ignore all the "exact" duplicates but store all of the structural duplicates',
+)
+upload_parser.add_argument(
+ "--upload-duplicates-replace",
+ action="store_true",
+ help="Upload everything and duplicates overwrite previously existing data",
+)
+upload_parser.add_argument(dest="path", help="Path to the file or folder.")
upload_parser.set_defaults(callback_func=commands.upload)
-summary_parser = subparsers.add_parser('summary', help='Discovery mode')
+summary_parser = subparsers.add_parser("summary", help="Discovery mode")
summary_parser.set_defaults(callback_func=commands.summary)
-summary_parser.add_argument('-q', '--query', action='append', help='Filtering extra quantities', default=[])
-summary_parser.add_argument('-p', '--props', action='append',
- help='Selecting properties for detailed description')
-summary_parser.add_argument('-a', '--all',
- help='Show everything without truncation of strings and limits of lines',
- action='store_true', dest='print_all')
-summary_parser.add_argument('-n', '--bins', help='The number of bins of the histogram', default=10, type=int)
-summary_parser.add_argument('-t', '--trunc',
- help='Length of string before truncation',
- default=20, type=int, dest='truncate')
-
-show_parser = subparsers.add_parser('show', help='shows the first 10 items')
+summary_parser.add_argument(
+ "-q", "--query", action="append", help="Filtering extra quantities", default=[]
+)
+summary_parser.add_argument(
+ "-p",
+ "--props",
+ action="append",
+ help="Selecting properties for detailed description",
+)
+summary_parser.add_argument(
+ "-a",
+ "--all",
+ help="Show everything without truncation of strings and limits of lines",
+ action="store_true",
+ dest="print_all",
+)
+summary_parser.add_argument(
+ "-n", "--bins", help="The number of bins of the histogram", default=10, type=int
+)
+summary_parser.add_argument(
+ "-t",
+ "--trunc",
+ help="Length of string before truncation",
+ default=20,
+ type=int,
+ dest="truncate",
+)
+
+show_parser = subparsers.add_parser("show", help="shows the first 10 items")
show_parser.set_defaults(callback_func=commands.show)
-show_parser.add_argument('-q', '--query', action='append', help='Filtering extra quantities', default=[])
-show_parser.add_argument('-p', '--props', action='append',
- help='Selecting properties for detailed description')
-show_parser.add_argument('-a', '--all',
- help='Show everything without truncation of strings and limits of lines',
- action='store_true', dest='print_all')
-
-delete_parser = subparsers.add_parser('delete', help='Delete configurations from the database')
+show_parser.add_argument(
+ "-q", "--query", action="append", help="Filtering extra quantities", default=[]
+)
+show_parser.add_argument(
+ "-p",
+ "--props",
+ action="append",
+ help="Selecting properties for detailed description",
+)
+show_parser.add_argument(
+ "-a",
+ "--all",
+ help="Show everything without truncation of strings and limits of lines",
+ action="store_true",
+ dest="print_all",
+)
+
+delete_parser = subparsers.add_parser(
+ "delete", help="Delete configurations from the database"
+)
delete_parser.set_defaults(callback_func=commands.delete)
-delete_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[])
-delete_parser.add_argument('-y', '--yes', action='store_true', help='Do the actual deletion.')
-
-key_add_parser = subparsers.add_parser('add-key', help='Adding new key value pairs for a given query')
+delete_parser.add_argument(
+ "-q", "--query", action="append", help="Filtering by a query", default=[]
+)
+delete_parser.add_argument(
+ "-y", "--yes", action="store_true", help="Do the actual deletion."
+)
+
+key_add_parser = subparsers.add_parser(
+ "add-key", help="Adding new key value pairs for a given query"
+)
key_add_parser.set_defaults(callback_func=commands.key_add)
-key_add_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[])
-key_add_parser.add_argument('-y', '--yes', action='store_true', help='Overwrite?')
-key_add_parser.add_argument('keys', help='keys(=value) pairs', nargs='+')
-
-key_rename_parser = subparsers.add_parser('rename-key', help='Rename a specific keys for a given query')
+key_add_parser.add_argument(
+ "-q", "--query", action="append", help="Filtering by a query", default=[]
+)
+key_add_parser.add_argument("-y", "--yes", action="store_true", help="Overwrite?")
+key_add_parser.add_argument("keys", help="keys(=value) pairs", nargs="+")
+
+key_rename_parser = subparsers.add_parser(
+ "rename-key", help="Rename a specific keys for a given query"
+)
key_rename_parser.set_defaults(callback_func=commands.key_rename)
-key_rename_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[])
-key_rename_parser.add_argument('-y', '--yes', action='store_true', help='Overwrite?')
-key_rename_parser.add_argument('old_keys', help='name of the old key')
-key_rename_parser.add_argument('new_keys', help='new name of the key')
-
-key_delete_parser = subparsers.add_parser('delete-key', help='Delete all the keys for a given query')
+key_rename_parser.add_argument(
+ "-q", "--query", action="append", help="Filtering by a query", default=[]
+)
+key_rename_parser.add_argument("-y", "--yes", action="store_true", help="Overwrite?")
+key_rename_parser.add_argument("old_keys", help="name of the old key")
+key_rename_parser.add_argument("new_keys", help="new name of the key")
+
+key_delete_parser = subparsers.add_parser(
+ "delete-key", help="Delete all the keys for a given query"
+)
key_delete_parser.set_defaults(callback_func=commands.key_delete)
-key_delete_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[])
-key_delete_parser.add_argument('-y', '--yes', action='store_true', help='Do the actual deletion.')
-key_delete_parser.add_argument('keys', help='keys(=value) data', nargs='+')
-
-exec_parser = subparsers.add_parser('exec', help='Running custom python code')
+key_delete_parser.add_argument(
+ "-q", "--query", action="append", help="Filtering by a query", default=[]
+)
+key_delete_parser.add_argument(
+ "-y", "--yes", action="store_true", help="Do the actual deletion."
+)
+key_delete_parser.add_argument("keys", help="keys(=value) data", nargs="+")
+
+exec_parser = subparsers.add_parser("exec", help="Running custom python code")
exec_parser.set_defaults(callback_func=commands.execute)
-exec_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[])
-exec_parser.add_argument('-y', '--yes', action='store_true', help='Do the actual execution.')
-exec_parser.add_argument('python_code', help='Selecting properties for detailed description')
-
-server = subparsers.add_parser('server', help='Running custom python code')
+exec_parser.add_argument(
+ "-q", "--query", action="append", help="Filtering by a query", default=[]
+)
+exec_parser.add_argument(
+ "-y", "--yes", action="store_true", help="Do the actual execution."
+)
+exec_parser.add_argument(
+ "python_code", help="Selecting properties for detailed description"
+)
+
+server = subparsers.add_parser("server", help="Running custom python code")
server.set_defaults(callback_func=commands.server)
-server.add_argument('abcd_url', help='Url for abcd database.')
-server.add_argument('--api-only', action='store_true', help='Running only the API.')
-server.add_argument('-u', '--url', help='Url to run the server.', default='http://localhost:5000')
+server.add_argument("abcd_url", help="Url for abcd database.")
+server.add_argument("--api-only", action="store_true", help="Running only the API.")
+server.add_argument(
+ "-u", "--url", help="Url to run the server.", default="http://localhost:5000"
+)
def main(args=None):
kwargs = parser.parse_args(args).__dict__
- if kwargs.pop('verbose'):
+ if kwargs.pop("verbose"):
# Remove all handlers associated with the root logger object.
# https://stackoverflow.com/questions/12158048/changing-loggings-basicconfig-which-is-already-set
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.basicConfig(level=logging.INFO)
- logger.info('Verbose mode is active')
+ logger.info("Verbose mode is active")
- if not kwargs.pop('command'):
+ if not kwargs.pop("command"):
print(parser.format_help())
return
try:
- callback_func = kwargs.pop('callback_func')
+ callback_func = kwargs.pop("callback_func")
callback_func(**kwargs)
except URLError:
- print('Wrong connection: Please check the parameters of the url!')
+ print("Wrong connection: Please check the parameters of the url!")
exit(1)
except AuthenticationError:
- print('Authentication failed: Please check the parameters of the connection!')
+ print("Authentication failed: Please check the parameters of the connection!")
exit(1)
except TimeoutError:
print("Timeout: Please check the parameters of the connection!")
exit(1)
-if __name__ == '__main__':
- main(['summary'])
- main('delete-key -q pbc pbc'.split())
+if __name__ == "__main__":
+ main(["summary"])
+ main("delete-key -q pbc pbc".split())
# main('upload -e cas -i ../../../tutorials/GB_alphaFe_001/tilt/00110391110_v6bxv2_tv0.4bxv0.2_d1.6z_traj.xyz'.split())
# main('summary -q formula~"Si2"'.split())
# main('upload -e cas -i ../../../tutorials/GB_alphaFe_001/tilt/00110391110_v6bxv2_tv0.4bxv0.2_d1.6z_traj.xyz'.split())
@@ -142,7 +241,7 @@ def main(args=None):
# main('-v summary -p energy'.split())
# main('-v summary -p *'.split())
# main('add-key -q cas selected user="cas"'.split())
- main('delete-key user'.split())
+ main("delete-key user".split())
# main(['summary', '-p', '*'])
# main(['summary', '-p', 'info.config_name, info.energy'])
# main(['summary', '-p', 'info.config_name, info.energy,info.energy;info.energy info.energy'])
diff --git a/abcd/model.py b/abcd/model.py
index c62d14ec..f4c87b61 100644
--- a/abcd/model.py
+++ b/abcd/model.py
@@ -18,13 +18,13 @@ def __init__(self, method=md5()):
def update(self, value):
if isinstance(value, int):
- self.update(str(value).encode('ascii'))
+ self.update(str(value).encode("ascii"))
elif isinstance(value, str):
- self.update(value.encode('utf-8'))
+ self.update(value.encode("utf-8"))
elif isinstance(value, float):
- self.update('{:.8e}'.format(value).encode('ascii'))
+ self.update("{:.8e}".format(value).encode("ascii"))
elif isinstance(value, (tuple, list)):
for e in value:
@@ -33,7 +33,7 @@ def update(self, value):
elif isinstance(value, (dict, UserDict)):
keys = value.keys()
for k in sorted(keys):
- self.update(k.encode('utf-8'))
+ self.update(k.encode("utf-8"))
self.update(value[k])
elif isinstance(value, datetime.datetime):
@@ -43,7 +43,9 @@ def update(self, value):
self.method.update(value)
else:
- raise ValueError("The {} type cannot be hashed! (Value: {})", format(type(value), value))
+ raise ValueError(
+ "The {} type cannot be hashed! (Value: {})", format(type(value), value)
+ )
def __call__(self):
"""Retrieve the digest of the hash."""
@@ -51,7 +53,14 @@ def __call__(self):
class AbstractModel(UserDict):
- reserved_keys = {'n_atoms', 'cell', 'pbc', 'calculator_name', 'calculator_parameters', 'derived'}
+ reserved_keys = {
+ "n_atoms",
+ "cell",
+ "pbc",
+ "calculator_name",
+ "calculator_parameters",
+ "derived",
+ }
def __init__(self, dict=None, **kwargs):
self.arrays_keys = []
@@ -64,22 +73,22 @@ def __init__(self, dict=None, **kwargs):
@property
def derived(self):
return {
- 'arrays_keys': self.arrays_keys,
- 'info_keys': self.info_keys,
- 'results_keys': self.results_keys,
- 'derived_keys': self.derived_keys
+ "arrays_keys": self.arrays_keys,
+ "info_keys": self.info_keys,
+ "results_keys": self.results_keys,
+ "derived_keys": self.derived_keys,
}
def __getitem__(self, key):
- if key == 'derived':
+ if key == "derived":
return self.derived
return super().__getitem__(key)
def __setitem__(self, key, value):
- if key == 'derived':
+ if key == "derived":
# raise KeyError('Please do not use "derived" as key because it is protected!')
# Silent return to avoid raising error in pymongo package
return
@@ -99,31 +108,31 @@ def convert(self, value):
def update_key_category(self, key, value):
- if key == '_id':
+ if key == "_id":
# raise KeyError('Please do not use "derived" as key because it is protected!')
return
- for category in ('arrays_keys', 'info_keys', 'results_keys', 'derived_keys'):
+ for category in ("arrays_keys", "info_keys", "results_keys", "derived_keys"):
if key in self.derived[category]:
return
- if key in ('positions', 'forces'):
- self.derived['arrays_keys'].append(key)
+ if key in ("positions", "forces"):
+ self.derived["arrays_keys"].append(key)
return
- if key in ('n_atoms', 'cell', 'pbc'):
- self.derived['info_keys'].append(key)
+ if key in ("n_atoms", "cell", "pbc"):
+ self.derived["info_keys"].append(key)
return
# Guess the category based in the shape of the value
- n_atoms = self['n_atoms']
+ n_atoms = self["n_atoms"]
if isinstance(value, (np.ndarray, list)) and len(value) == n_atoms:
- self.derived['arrays_keys'].append(key)
+ self.derived["arrays_keys"].append(key)
else:
- self.derived['info_keys'].append(key)
+ self.derived["info_keys"].append(key)
def __delitem__(self, key):
- for category in ('arrays_keys', 'info_keys', 'results_keys', 'derived_keys'):
+ for category in ("arrays_keys", "info_keys", "results_keys", "derived_keys"):
if key in self.derived[category]:
self.derived[category].remove(key)
break
@@ -133,34 +142,44 @@ def __delitem__(self, key):
def __iter__(self):
for item in super().__iter__():
yield item
- yield 'derived'
+ yield "derived"
@classmethod
def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
"""ASE's original implementation"""
- reserved_keys = {'n_atoms', 'cell', 'pbc', 'calculator_name', 'calculator_parameters', 'derived', 'formula'}
+ reserved_keys = {
+ "n_atoms",
+ "cell",
+ "pbc",
+ "calculator_name",
+ "calculator_parameters",
+ "derived",
+ "formula",
+ }
arrays_keys = set(atoms.arrays.keys())
info_keys = set(atoms.info.keys())
- results_keys = set(atoms.calc.results.keys()) if store_calc and atoms.calc else {}
+ results_keys = (
+ set(atoms.calc.results.keys()) if store_calc and atoms.calc else {}
+ )
all_keys = (reserved_keys, arrays_keys, info_keys, results_keys)
if len(set.union(*all_keys)) != sum(map(len, all_keys)):
print(all_keys)
- raise ValueError('All the keys must be unique!')
+ raise ValueError("All the keys must be unique!")
item = cls()
n_atoms = len(atoms)
dct = {
- 'n_atoms': n_atoms,
- 'cell': atoms.cell.tolist(),
- 'pbc': atoms.pbc.tolist(),
- 'formula': atoms.get_chemical_formula()
+ "n_atoms": n_atoms,
+ "cell": atoms.cell.tolist(),
+ "pbc": atoms.pbc.tolist(),
+ "formula": atoms.get_chemical_formula(),
}
- info_keys.update({'n_atoms', 'cell', 'pbc', 'formula'})
+ info_keys.update({"n_atoms", "cell", "pbc", "formula"})
for key, value in atoms.arrays.items():
if isinstance(value, np.ndarray):
@@ -175,9 +194,9 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True):
dct[key] = value
if store_calc and atoms.calc:
- dct['calculator_name'] = atoms.calc.__class__.__name__
- dct['calculator_parameters'] = atoms.calc.todict()
- info_keys.update({'calculator_name', 'calculator_parameters'})
+ dct["calculator_name"] = atoms.calc.__class__.__name__
+ dct["calculator_parameters"] = atoms.calc.todict()
+ info_keys.update({"calculator_name", "calculator_parameters"})
for key, value in atoms.calc.results.items():
@@ -205,26 +224,22 @@ def to_ase(self):
arrays_keys = set(self.arrays_keys)
info_keys = set(self.info_keys)
- cell = self.pop('cell', None)
- pbc = self.pop('pbc', None)
- numbers = self.pop('numbers', None)
- positions = self.pop('positions', None)
- results_keys = self.derived['results_keys']
+ cell = self.pop("cell", None)
+ pbc = self.pop("pbc", None)
+ numbers = self.pop("numbers", None)
+ positions = self.pop("positions", None)
+ results_keys = self.derived["results_keys"]
- info_keys -= {'cell', 'pbc'}
- arrays_keys -= {'numbers', 'positions'}
+ info_keys -= {"cell", "pbc"}
+ arrays_keys -= {"numbers", "positions"}
- atoms = Atoms(
- cell=cell,
- pbc=pbc,
- numbers=numbers,
- positions=positions)
+ atoms = Atoms(cell=cell, pbc=pbc, numbers=numbers, positions=positions)
- if 'calculator_name' in self:
+ if "calculator_name" in self:
# calculator_name = self['info'].pop('calculator_name')
# atoms.calc = get_calculator(data['results']['calculator_name'])(**params)
- params = self.pop('calculator_parameters', {})
+ params = self.pop("calculator_parameters", {})
atoms.calc = SinglePointCalculator(atoms, **params)
atoms.calc.results.update((key, self[key]) for key in results_keys)
@@ -235,38 +250,38 @@ def to_ase(self):
return atoms
def pre_save(self):
- self.derived_keys = ['elements', 'username', 'uploaded', 'modified']
+ self.derived_keys = ["elements", "username", "uploaded", "modified"]
- cell = self['cell']
+ cell = self["cell"]
if cell:
volume = abs(np.linalg.det(cell)) # atoms.get_volume()
- self['volume'] = volume
- self.derived_keys.append('volume')
+ self["volume"] = volume
+ self.derived_keys.append("volume")
- virial = self.get('virial')
+ virial = self.get("virial")
if virial:
# pressure P = -1/3 Tr(stress) = -1/3 Tr(virials/volume)
- self['pressure'] = -1 / 3 * np.trace(virial / volume)
- self.derived_keys.append('pressure')
+ self["pressure"] = -1 / 3 * np.trace(virial / volume)
+ self.derived_keys.append("pressure")
# 'elements': Counter(atoms.get_chemical_symbols()),
- self['elements'] = Counter(str(element) for element in self['numbers'])
+ self["elements"] = Counter(str(element) for element in self["numbers"])
- self['username'] = getpass.getuser()
+ self["username"] = getpass.getuser()
- if not self.get('uploaded'):
- self['uploaded'] = datetime.datetime.utcnow()
+ if not self.get("uploaded"):
+ self["uploaded"] = datetime.datetime.utcnow()
- self['modified'] = datetime.datetime.utcnow()
+ self["modified"] = datetime.datetime.utcnow()
m = Hasher()
- for key in ('numbers', 'positions', 'cell', 'pbc'):
+ for key in ("numbers", "positions", "cell", "pbc"):
m.update(self[key])
- self.derived_keys.append('hash_structure')
- self['hash_structure'] = m()
+ self.derived_keys.append("hash_structure")
+ self["hash_structure"] = m()
m = Hasher()
for key in self.arrays_keys:
@@ -274,11 +289,11 @@ def pre_save(self):
for key in self.info_keys:
m.update(self[key])
- self.derived_keys.append('hash')
- self['hash'] = m()
+ self.derived_keys.append("hash")
+ self["hash"] = m()
-if __name__ == '__main__':
+if __name__ == "__main__":
import io
from pprint import pprint
from ase.io import read
@@ -286,7 +301,7 @@ def pre_save(self):
logging.basicConfig(level=logging.INFO)
# from ase.io import jsonio
- atoms = read('test.xyz', format='xyz', index=0)
+ atoms = read("test.xyz", format="xyz", index=0)
atoms.set_cell([1, 1, 1])
print(atoms)
diff --git a/abcd/parsers/extras.py b/abcd/parsers/extras.py
index 939210bb..c007acf6 100644
--- a/abcd/parsers/extras.py
+++ b/abcd/parsers/extras.py
@@ -90,40 +90,43 @@ def string(self, s):
# parser = Lark(grammar, parser='lalr', lexer='contextual', debug=False)
-parser = Lark(grammar, parser='lalr', lexer='contextual', transformer=TreeToDict(), debug=False)
-
-if __name__ == '__main__':
- test_string = ' '.join([
- ' ' # start with a separator
- 'flag',
- 'quotedd_string="quoteddd value"',
- r'quotedddd_string_escaped="esc\"aped"',
- 'false_value = F',
- 'integer=22',
- 'floating=1.1',
- 'int_array={1 2 3}',
- 'scientific_float=1.2e7',
- 'scientific_float_2=5e-6',
- 'scientific_float_array="1.2 2.2e3 4e1 3.3e-1 2e-2"',
- 'not_array="1.2 3.4 text"',
- 'array_nested=[[1,2],[3,4]] ' # gets flattented if not 3x3
- 'array_many_other_quotes=({[4 8 12]})',
- 'array_boolean={T F T F}',
- 'array_boolean_2=" T, F, T " ' # leading spaces
- # 'not_bool_array=[T F S]',
- # # read and write
- # u'\xfcnicode_key=val\xfce',
- # # u'unquoted_special_value=a_to_Z_$%%^&*\xfc\u2615',
- # '2body=33.3',
- 'hyphen-ated',
- # # # parse only
- 'comma_separated="7, 4, -1"',
- 'array_bool_commas=[T, T, F, T]',
- # # 'Properties=species:S:1:pos:R:3',
- # 'double_equals=abc=xyz',
- 'multiple_separators ',
- 'trailing'
- ])
+parser = Lark(
+ grammar, parser="lalr", lexer="contextual", transformer=TreeToDict(), debug=False
+)
+
+if __name__ == "__main__":
+ test_string = " ".join(
+ [
+ " " "flag", # start with a separator
+ 'quotedd_string="quoteddd value"',
+ r'quotedddd_string_escaped="esc\"aped"',
+ "false_value = F",
+ "integer=22",
+ "floating=1.1",
+ "int_array={1 2 3}",
+ "scientific_float=1.2e7",
+ "scientific_float_2=5e-6",
+ 'scientific_float_array="1.2 2.2e3 4e1 3.3e-1 2e-2"',
+ 'not_array="1.2 3.4 text"',
+            "array_nested=[[1,2],[3,4]] "  # gets flattened if not 3x3
+ "array_many_other_quotes=({[4 8 12]})",
+ "array_boolean={T F T F}",
+ 'array_boolean_2=" T, F, T " ' # leading spaces
+ # 'not_bool_array=[T F S]',
+ # # read and write
+ # u'\xfcnicode_key=val\xfce',
+ # # u'unquoted_special_value=a_to_Z_$%%^&*\xfc\u2615',
+ # '2body=33.3',
+ "hyphen-ated",
+ # # # parse only
+ 'comma_separated="7, 4, -1"',
+ "array_bool_commas=[T, T, F, T]",
+ # # 'Properties=species:S:1:pos:R:3',
+ # 'double_equals=abc=xyz',
+ "multiple_separators ",
+ "trailing",
+ ]
+ )
j = parser.parse(test_string)
diff --git a/abcd/parsers/queries.py b/abcd/parsers/queries.py
index 707fd4bc..41ef1c33 100644
--- a/abcd/parsers/queries.py
+++ b/abcd/parsers/queries.py
@@ -90,17 +90,17 @@ class TreeTransformer(Transformer):
def array(self, *items):
return list(items)
- true = lambda _: ('VALUE', True)
- false = lambda _: ('VALUE', False)
+ true = lambda _: ("VALUE", True)
+ false = lambda _: ("VALUE", False)
def float(self, number):
- return 'NUMBER', float(number)
+ return "NUMBER", float(number)
def int(self, number):
- return 'NUMBER', int(number)
+ return "NUMBER", int(number)
def string(self, s):
- return 'STRING', s[1:-1].replace('\\"', '"')
+ return "STRING", s[1:-1].replace('\\"', '"')
def single_statement(self, expression):
return expression
@@ -109,24 +109,24 @@ def multi_statement(self, statement, operator, expression):
return operator.type, statement, expression
def single_expression(self, name):
- return 'NAME', str(name)
+ return "NAME", str(name)
def grouped_expression(self, statement):
# return statement
- return 'GROUP', statement
+ return "GROUP", statement
def operator_expression(self, name, operator, value):
- return operator.type, ('NAME', str(name)), value
+ return operator.type, ("NAME", str(name)), value
def reversed_expression(self, value, operator, name):
- return operator.type, ('NAME', str(name)), value
+ return operator.type, ("NAME", str(name)), value
def negation_expression(self, operator, expression):
return operator.type, expression
class Parser:
- parser = Lark(grammar, start='statement')
+ parser = Lark(grammar, start="statement")
transformer = TreeTransformer()
def parse(self, string):
@@ -138,29 +138,29 @@ def __call__(self, string):
parser = Parser()
-if __name__ == '__main__':
+if __name__ == "__main__":
# logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)
queries = (
- ' ',
- 'single',
- 'not single',
- 'operator_gt > 23 ',
- 'operator_gt > -2.31e-5 ',
+ " ",
+ "single",
+ "not single",
+ "operator_gt > 23 ",
+ "operator_gt > -2.31e-5 ",
'string = "some string"',
'regexp ~ ".*H"',
- 'aa & not bb',
- 'aa & bb > 23.54 | cc & dd',
+ "aa & not bb",
+ "aa & bb > 23.54 | cc & dd",
# 'aa bb > 22 cc > 33 dd > 44 ',
- 'aa and bb > 22 and cc > 33 and dd > 44 ',
- '((aa and bb > 22) and cc > 33) and dd > 44 ',
- '(aa and bb > 22) and (cc > 33 and dd > 44) ',
- '(aa and bb > 22 and cc > 33 and dd > 44) ',
- 'aa and bb > 23.54 or 22 in cc and dd',
- 'aa & bb > 23.54 | (22 in cc & dd)',
- 'aa and bb > 23.54 or (22 in cc and dd)',
- 'aa and not (bb > 23.54 or (22 in cc and dd))',
+ "aa and bb > 22 and cc > 33 and dd > 44 ",
+ "((aa and bb > 22) and cc > 33) and dd > 44 ",
+ "(aa and bb > 22) and (cc > 33 and dd > 44) ",
+ "(aa and bb > 22 and cc > 33 and dd > 44) ",
+ "aa and bb > 23.54 or 22 in cc and dd",
+ "aa & bb > 23.54 | (22 in cc & dd)",
+ "aa and bb > 23.54 or (22 in cc and dd)",
+ "aa and not (bb > 23.54 or (22 in cc and dd))",
# 'expression = (bb/3-1)*cc',
# 'energy/n_atoms > 3',
# '1=3',
@@ -176,7 +176,7 @@ def __call__(self, string):
# print(parser.parse(query).pretty())
try:
tree = parser.parse(query)
- logger.info('=> tree: {}'.format(tree))
- logger.info('==> ast: {}'.format(parser(query)))
+ logger.info("=> tree: {}".format(tree))
+ logger.info("==> ast: {}".format(parser(query)))
except LarkError:
raise NotImplementedError
diff --git a/abcd/parsers/queries_new.py b/abcd/parsers/queries_new.py
index 8f176d0c..22004d11 100644
--- a/abcd/parsers/queries_new.py
+++ b/abcd/parsers/queries_new.py
@@ -8,21 +8,21 @@
# TODO: Reversed operator in the grammar (value op prop VS prop op value VS IN)
-class DebugTransformer(Transformer): # pragma: no cover
+class DebugTransformer(Transformer): # pragma: no cover
def __init__(self):
super().__init__()
def __default__(self, data, children, meta):
- print('Node: ', data, children)
+ print("Node: ", data, children)
return data
-if __name__ == '__main__':
+if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
- with open('query_new.lark') as file:
- parser = Lark(file.read(), start='expression')
+ with open("query_new.lark") as file:
+ parser = Lark(file.read(), start="expression")
transformer = DebugTransformer()
@@ -72,30 +72,30 @@ def __default__(self, data, children, meta):
# )
queries = (
- ' ',
- 'single',
- 'not single',
- 'operator_gt > 23 ',
- 'operator_gt > -2.31e-5 ',
+ " ",
+ "single",
+ "not single",
+ "operator_gt > 23 ",
+ "operator_gt > -2.31e-5 ",
'string = "some string"',
'regexp ~ ".*H"',
- 'aa & not bb',
- 'aa & bb > 23.54 | cc & dd',
- 'aa and bb > 22 and cc > 33 and dd > 44 ',
- '((aa and bb > 22) and cc > 33) and dd > 44 ',
- '(aa and bb > 22) and (cc > 33 and dd > 44) ',
- '(aa and bb > 22 and cc > 33 and dd > 44) ',
- 'aa and bb > 23.54 or 22 in cc and dd',
- 'aa & bb > 23.54 | (22 in cc & dd)',
- 'aa and bb > 23.54 or (22 in cc and dd)',
- 'aa and not (bb > 23.54 or (22 in cc and dd))',
- 'expression = (bb/3-1)*cc',
- 'energy/n_atoms > 3',
- '1=3',
- 'all(aa) > 3',
- 'any(aa) > 3',
- 'aa = False',
- 'aa = [True, True, True]',
+ "aa & not bb",
+ "aa & bb > 23.54 | cc & dd",
+ "aa and bb > 22 and cc > 33 and dd > 44 ",
+ "((aa and bb > 22) and cc > 33) and dd > 44 ",
+ "(aa and bb > 22) and (cc > 33 and dd > 44) ",
+ "(aa and bb > 22 and cc > 33 and dd > 44) ",
+ "aa and bb > 23.54 or 22 in cc and dd",
+ "aa & bb > 23.54 | (22 in cc & dd)",
+ "aa and bb > 23.54 or (22 in cc and dd)",
+ "aa and not (bb > 23.54 or (22 in cc and dd))",
+ "expression = (bb/3-1)*cc",
+ "energy/n_atoms > 3",
+ "1=3",
+ "all(aa) > 3",
+ "any(aa) > 3",
+ "aa = False",
+ "aa = [True, True, True]",
)
for query in queries:
diff --git a/abcd/server/__init__.py b/abcd/server/__init__.py
index 86601b85..3b21e3bb 100644
--- a/abcd/server/__init__.py
+++ b/abcd/server/__init__.py
@@ -2,5 +2,5 @@
app = create_app()
-if __name__ == '__main__':
- app.run(host='0.0.0.0')
+if __name__ == "__main__":
+ app.run(host="0.0.0.0")
diff --git a/abcd/server/app/__init__.py b/abcd/server/app/__init__.py
index bc4614f6..d7176a42 100644
--- a/abcd/server/app/__init__.py
+++ b/abcd/server/app/__init__.py
@@ -11,34 +11,34 @@
def create_app(abcd_url=None):
# Define the WSGI application object
app = Flask(__name__)
- app.logger.info('Creating an application')
+ app.logger.info("Creating an application")
- app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', os.urandom(12).hex())
- app.config['ABCD_URL'] = os.getenv('ABCD_URL', 'mongodb://localhost:27017/abcd')
+ app.config["SECRET_KEY"] = os.getenv("SECRET_KEY", os.urandom(12).hex())
+ app.config["ABCD_URL"] = os.getenv("ABCD_URL", "mongodb://localhost:27017/abcd")
# Initialize extensions/add-ons/plugins.
nav.init_app(app)
- register_renderer(app, 'BootstrapRenderer', BootstrapRenderer)
- register_renderer(app, 'DatabaseNav', DatabaseNav)
+ register_renderer(app, "BootstrapRenderer", BootstrapRenderer)
+ register_renderer(app, "DatabaseNav", DatabaseNav)
db.init_app(app)
# Setup redirects and register blueprints.
app.register_blueprint(index.bp)
- app.register_blueprint(database.bp, url_prefix='/db')
- app.register_blueprint(api.bp, url_prefix='/api')
+ app.register_blueprint(database.bp, url_prefix="/db")
+ app.register_blueprint(api.bp, url_prefix="/api")
@app.errorhandler(404)
def not_found(error):
- return render_template('404.html'), 404
+ return render_template("404.html"), 404
return app
-if __name__ == '__main__':
+if __name__ == "__main__":
import logging
logging.basicConfig(level=logging.DEBUG)
app = create_app()
- app.run(host='0.0.0.0', debug=True)
+ app.run(host="0.0.0.0", debug=True)
diff --git a/abcd/server/app/nav.py b/abcd/server/app/nav.py
index 09af4385..ec8ea2d1 100644
--- a/abcd/server/app/nav.py
+++ b/abcd/server/app/nav.py
@@ -16,29 +16,29 @@ def __init__(self, title, *args, **kwargs):
@nav.navigation()
def main_navbar():
return TopNavbar(
- 'ABCD',
- View('Home', 'index.index'),
- View('API', 'api.index'),
- View('Databases', 'database.database', database_name='default'),
- Link('Docs', 'https://libatoms.github.io/abcd/'),
- Link('Github', 'https://github.com/libatoms/abcd'),
+ "ABCD",
+ View("Home", "index.index"),
+ View("API", "api.index"),
+ View("Databases", "database.database", database_name="default"),
+ Link("Docs", "https://libatoms.github.io/abcd/"),
+ Link("Github", "https://github.com/libatoms/abcd"),
)
@nav.navigation()
def database_navbar():
return Navbar(
- '',
- View('Database', 'database.database'),
- Link('Collections', '#'),
- Link('History', '#'),
- Link('Statistics', '#'),
- View('Settings', 'database.settings'),
+ "",
+ View("Database", "database.database"),
+ Link("Collections", "#"),
+ Link("History", "#"),
+ Link("Statistics", "#"),
+ View("Settings", "database.settings"),
)
class DatabaseNav(Renderer):
- def __init__(self, database_name='atoms'):
+ def __init__(self, database_name="atoms"):
self.database_name = database_name
def visit_Navbar(self, node):
@@ -50,22 +50,24 @@ def visit_Navbar(self, node):
return root
def visit_Text(self, node):
- return tags.li(tags.a(node.text, _class='nav-link disabled'), _class="nav-item")
+ return tags.li(tags.a(node.text, _class="nav-link disabled"), _class="nav-item")
def visit_Link(self, node):
- item = tags.li(_class='nav-item')
- item.add(tags.a(node.text, href=node.get_url(), _class='nav-link'))
+ item = tags.li(_class="nav-item")
+ item.add(tags.a(node.text, href=node.get_url(), _class="nav-link"))
return item
def visit_View(self, node):
# Dinamically modify the url
- node.url_for_kwargs.update({'database_name': self.database_name})
+ node.url_for_kwargs.update({"database_name": self.database_name})
- item = tags.li(_class='nav-item')
- item.add(tags.a(node.text, href=node.get_url(), title=node.text, _class='nav-link'))
+ item = tags.li(_class="nav-item")
+ item.add(
+ tags.a(node.text, href=node.get_url(), title=node.text, _class="nav-link")
+ )
if node.active:
- item['class'] = 'nav-item active'
+ item["class"] = "nav-item active"
return item
@@ -78,34 +80,41 @@ def __init__(self, nav_id=None):
def visit_Navbar(self, node):
node_id = self.id or sha1(str(id(node)).encode()).hexdigest()
- root = tags.nav(_class='navbar navbar-expand-md navbar-dark bg-dark fixed-top')
+ root = tags.nav(_class="navbar navbar-expand-md navbar-dark bg-dark fixed-top")
# title may also have a 'get_url()' method, in which case we render
# a brand-link
if node.title is not None:
- if hasattr(node.title, 'get_url'):
- root.add(tags.a(node.title.text, _class='navbar-brand',
- href=node.title.get_url()))
+ if hasattr(node.title, "get_url"):
+ root.add(
+ tags.a(
+ node.title.text,
+ _class="navbar-brand",
+ href=node.title.get_url(),
+ )
+ )
else:
- root.add(tags.span(node.title, _class='navbar-brand'))
+ root.add(tags.span(node.title, _class="navbar-brand"))
btn = root.add(tags.button())
- btn['type'] = 'button'
- btn['class'] = 'navbar-toggler'
- btn['data-toggle'] = 'collapse'
- btn['data-target'] = '#' + node_id
- btn['aria-controls'] = 'navbarCollapse'
- btn['aria-expanded'] = 'false'
- btn['aria-label'] = "Toggle navigation"
+ btn["type"] = "button"
+ btn["class"] = "navbar-toggler"
+ btn["data-toggle"] = "collapse"
+ btn["data-target"] = "#" + node_id
+ btn["aria-controls"] = "navbarCollapse"
+ btn["aria-expanded"] = "false"
+ btn["aria-label"] = "Toggle navigation"
- btn.add(tags.span('', _class='navbar-toggler-icon'))
+ btn.add(tags.span("", _class="navbar-toggler-icon"))
- bar = root.add(tags.div(
- _class='navbar-collapse collapse',
- id=node_id,
- ))
+ bar = root.add(
+ tags.div(
+ _class="navbar-collapse collapse",
+ id=node_id,
+ )
+ )
- bar_list = bar.add(tags.ul(_class='navbar-nav mr-auto'))
+ bar_list = bar.add(tags.ul(_class="navbar-nav mr-auto"))
for item in node.items:
bar_list.add(self.visit(item))
@@ -117,56 +126,58 @@ def visit_Navbar(self, node):
# search_input['aria-label'] = "Search"
search_btn = search_form.add(tags.button(_class="btn btn-success my-2 my-sm-0"))
- search_btn['type'] = "submit"
- search_btn.add_raw_string('+')
+ search_btn["type"] = "submit"
+ search_btn.add_raw_string("+")
search_btn = search_form.add(tags.button(_class="btn btn-success my-2 my-sm-0"))
- search_btn['type'] = "submit"
- search_btn.add_raw_string('Login')
+ search_btn["type"] = "submit"
+ search_btn.add_raw_string("Login")
return root
def visit_Text(self, node):
if self._in_dropdown:
- return tags.a(node.text, _class='dropdown-item disabled')
+ return tags.a(node.text, _class="dropdown-item disabled")
- return tags.li(tags.a(node.text, _class='nav-link disabled'), _class="nav-item")
+ return tags.li(tags.a(node.text, _class="nav-link disabled"), _class="nav-item")
def visit_Link(self, node):
if self._in_dropdown:
- return tags.a(node.text, href=node.get_url(), _class='dropdown-item')
+ return tags.a(node.text, href=node.get_url(), _class="dropdown-item")
- item = tags.li(_class='nav-item')
- item.add(tags.a(node.text, href=node.get_url(), _class='nav-link'))
+ item = tags.li(_class="nav-item")
+ item.add(tags.a(node.text, href=node.get_url(), _class="nav-link"))
return item
def visit_View(self, node):
if self._in_dropdown:
- return tags.a(node.text, href=node.get_url(), _class='dropdown-item')
+ return tags.a(node.text, href=node.get_url(), _class="dropdown-item")
- item = tags.li(_class='nav-item')
- item.add(tags.a(node.text, href=node.get_url(), title=node.text, _class='nav-link'))
+ item = tags.li(_class="nav-item")
+ item.add(
+ tags.a(node.text, href=node.get_url(), title=node.text, _class="nav-link")
+ )
if node.active:
- item['class'] = 'nav-item active'
+ item["class"] = "nav-item active"
return item
def visit_Subgroup(self, node):
if self._in_dropdown:
- raise RuntimeError('Cannot render nested Subgroups')
+ raise RuntimeError("Cannot render nested Subgroups")
- li = tags.li(_class='nav-item dropdown')
+ li = tags.li(_class="nav-item dropdown")
if node.active:
- li['class'] = 'nav-item dropdown active'
+ li["class"] = "nav-item dropdown active"
- a = li.add(tags.a(node.title, href='#', _class='nav-link dropdown-toggle'))
- a['data-toggle'] = 'dropdown'
- a['aria-haspopup'] = 'true'
- a['aria-expanded'] = 'false'
+ a = li.add(tags.a(node.title, href="#", _class="nav-link dropdown-toggle"))
+ a["data-toggle"] = "dropdown"
+ a["aria-haspopup"] = "true"
+ a["aria-expanded"] = "false"
- dropdown_div = li.add(tags.div(_class='dropdown-menu'))
+ dropdown_div = li.add(tags.div(_class="dropdown-menu"))
self._in_dropdown = True
for item in node.items:
@@ -177,6 +188,6 @@ def visit_Subgroup(self, node):
def visit_Separator(self, node):
if self._in_dropdown:
- return tags.div(_class='dropdown-divider')
+ return tags.div(_class="dropdown-divider")
- raise RuntimeError('Cannot render separator outside Subgroup.')
+ raise RuntimeError("Cannot render separator outside Subgroup.")
diff --git a/abcd/server/app/views/api.py b/abcd/server/app/views/api.py
index 7a02d307..b3aa6d25 100644
--- a/abcd/server/app/views/api.py
+++ b/abcd/server/app/views/api.py
@@ -1,21 +1,20 @@
from flask import Blueprint, Response, make_response, jsonify, request
-bp = Blueprint('api', __name__)
+bp = Blueprint("api", __name__)
-@bp.route('/')
+@bp.route("/")
def index():
- return Response('ok', status=200)
+ return Response("ok", status=200)
+
# endpoint to create new user
@bp.route("/calculation", methods=["POST"])
def query_calculation():
- response = {
- 'query': request.json,
- 'results': []
- }
+ response = {"query": request.json, "results": []}
return jsonify(response)
+
# # endpoint to show all users
# @bp.route("/calculation", methods=["GET"])
# def get_calculation():
diff --git a/abcd/server/app/views/database.py b/abcd/server/app/views/database.py
index 52dc0588..8569dcd2 100644
--- a/abcd/server/app/views/database.py
+++ b/abcd/server/app/views/database.py
@@ -1,31 +1,31 @@
from flask import Blueprint, render_template, request
from flask import Response
-bp = Blueprint('database', __name__)
+bp = Blueprint("database", __name__)
-@bp.route('/')
+@bp.route("/")
def index():
return Response(status=200)
# Our index-page just shows a quick explanation. Check out the template
# "templates/index.html" documentation for more details.
-@bp.route('//', methods=['GET'])
+@bp.route("//", methods=["GET"])
def database(database_name):
# data = Atoms.objects()
# list(Atoms.objects.aggregate({'$unwind': '$derived.arrays_keys'}, {'$group': {'_id': '$derived.arrays_keys', 'count': {'$sum': 1}}}))
- if request.method == 'POST':
- print('POST')
+ if request.method == "POST":
+ print("POST")
info = {
- 'name': database_name,
- 'description': 'Vivamus sagittis lacus vel augue laoreet rutrum faucibus dolor auctor. Duis mollis, est non commodo luctus.',
- 'columns': [
- {'slug': 'formula', 'name': 'Formula'},
- {'slug': 'energy', 'name': 'Energy'},
- {'slug': 'derived.n_atoms', 'name': "# of atoms"}
+ "name": database_name,
+ "description": "Vivamus sagittis lacus vel augue laoreet rutrum faucibus dolor auctor. Duis mollis, est non commodo luctus.",
+ "columns": [
+ {"slug": "formula", "name": "Formula"},
+ {"slug": "energy", "name": "Energy"},
+ {"slug": "derived.n_atoms", "name": "# of atoms"},
],
}
@@ -34,23 +34,23 @@ def database(database_name):
# page = request.args.get('page', 1, type=int)
# paginated_atoms = atoms.paginate(page, per_page=10)
- return render_template("database/database.html",
- info=info,
- atoms=paginated_atoms)
+ return render_template("database/database.html", info=info, atoms=paginated_atoms)
# Our index-page just shows a quick explanation. Check out the template
# "templates/index.html" documentation for more details.
-@bp.route('//settings')
+@bp.route("//settings")
def settings(database_name):
info = {
- 'name': database_name,
- 'description': 'Vivamus sagittis lacus vel augue laoreet rutrum faucibus dolor auctor. Duis mollis, est non commodo luctus.',
- 'columns': [
- {'slug': 'formula', 'name': 'Formula'},
- {'slug': 'energy', 'name': 'Energy'},
- {'slug': 'derived.n_atoms', 'name': "# of atoms"}
+ "name": database_name,
+ "description": "Vivamus sagittis lacus vel augue laoreet rutrum faucibus dolor auctor. Duis mollis, est non commodo luctus.",
+ "columns": [
+ {"slug": "formula", "name": "Formula"},
+ {"slug": "energy", "name": "Energy"},
+ {"slug": "derived.n_atoms", "name": "# of atoms"},
],
- 'public': True
+ "public": True,
}
- return render_template("database/settings.html", database_name=database_name, info=info)
+ return render_template(
+ "database/settings.html", database_name=database_name, info=info
+ )
diff --git a/abcd/server/app/views/index.py b/abcd/server/app/views/index.py
index b283d999..37faa52a 100644
--- a/abcd/server/app/views/index.py
+++ b/abcd/server/app/views/index.py
@@ -4,24 +4,24 @@
# from flask import jsonify
# import requests
-bp = Blueprint('index', __name__)
+bp = Blueprint("index", __name__)
-@bp.route('/')
+@bp.route("/")
def index():
return render_template("index.html")
-@bp.route('/login/')
+@bp.route("/login/")
def login():
return render_template("login.html")
-@bp.route('/new/')
+@bp.route("/new/")
def new():
return render_template("new.html")
-@bp.route('/graphql')
+@bp.route("/graphql")
def graphql():
- return render_template("graphql.html", url=url_for('api.graphql'))
+ return render_template("graphql.html", url=url_for("api.graphql"))
diff --git a/tests/test_database.py b/tests/test_database.py
index 09340849..82cddfab 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -8,9 +8,9 @@
@pytest.fixture
-@mongomock.patch(servers=(('localhost', 27017),))
+@mongomock.patch(servers=(("localhost", 27017),))
def abcd_mongodb():
- url = 'mongodb://localhost'
+ url = "mongodb://localhost"
abcd = ABCD.from_url(url)
abcd.print_info()
@@ -22,13 +22,15 @@ def test_thing(abcd_mongodb):
def test_push(abcd_mongodb):
- xyz = StringIO("""2
+ xyz = StringIO(
+ """2
Properties=species:S:1:pos:R:3 s="sadf" _vtk_test="t e s t _ s t r" pbc="F F F"
-Si 0.00000000 0.00000000 0.00000000
-Si 0.00000000 0.00000000 0.00000000
-""")
+Si 0.00000000 0.00000000 0.00000000
+Si 0.00000000 0.00000000 0.00000000
+"""
+ )
- atoms = read(xyz, format='extxyz')
+ atoms = read(xyz, format="extxyz")
atoms.set_cell([1, 1, 1])
abcd_mongodb.destroy()
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
index 02a6d98a..5a2211e0 100644
--- a/tests/test_parsers.py
+++ b/tests/test_parsers.py
@@ -4,7 +4,6 @@
class TestParsingExtras:
-
@pytest.fixture
def parser(self):
return extras_parser
@@ -140,16 +139,15 @@ def test_composite(self, parser):
'a3x3_array="1 4 7 2 5 8 3 6 9" ' # fortran ordering
'Lattice=" 4.3 0.0 0.0 0.0 3.3 0.0 0.0 0.0 7.0 " ' # spaces in array
'comma_separated="7, 4, -1"',
- 'array_boolean_2=" T, F, T " ' # leading spaces
- 'not_array="1.2 3.4 text"',
+ 'array_boolean_2=" T, F, T " ' 'not_array="1.2 3.4 text"', # leading spaces
"not_bool_array=[T F S]",
],
)
- def test_missing(self, string): ...
+ def test_missing(self, string):
+ ...
class TestParsingQueries:
-
@pytest.fixture
def parser(self):
return queries_parser
@@ -190,7 +188,8 @@ def test_combination(self, parser, string, expected):
("any(aa) > 3", {}),
],
)
- def test_expressions(self, case): ...
+ def test_expressions(self, case):
+ ...
@pytest.mark.skip("known issues / future features")
@pytest.mark.parametrize(
@@ -203,4 +202,5 @@ def test_expressions(self, case): ...
("aa and (bb > 23.54 or (22 in cc and dd))", {}),
],
)
- def test_expressions(self, case): ...
+ def test_expressions(self, case):
+ ...
diff --git a/tutorials/gb_upload.py b/tutorials/gb_upload.py
index 1b0680d8..3d276de2 100644
--- a/tutorials/gb_upload.py
+++ b/tutorials/gb_upload.py
@@ -1,23 +1,18 @@
import sys
from pathlib import Path
-sys.path.append('..')
+sys.path.append("..")
from abcd import ABCD
from utils.ext_xyz import XYZReader
-if __name__ == '__main__':
+if __name__ == "__main__":
- url = 'mongodb://localhost:27017'
+ url = "mongodb://localhost:27017"
abcd = ABCD(url)
- for file in Path('GB_alphaFe_001/tilt/').glob('*.xyz'):
+ for file in Path("GB_alphaFe_001/tilt/").glob("*.xyz"):
print(file)
- gb_params = {
- 'name': 'alphaFe',
- 'type': 'tilt',
- 'params': file.name[:-4]
-
- }
+ gb_params = {"name": "alphaFe", "type": "tilt", "params": file.name[:-4]}
with abcd as db, XYZReader(file) as reader:
- db.push(reader.read_atoms(), extra_info={'GB_params': gb_params})
+ db.push(reader.read_atoms(), extra_info={"GB_params": gb_params})
diff --git a/tutorials/scripts/Preprocess.py b/tutorials/scripts/Preprocess.py
index 6b29c474..49a315d8 100644
--- a/tutorials/scripts/Preprocess.py
+++ b/tutorials/scripts/Preprocess.py
@@ -4,6 +4,7 @@
from ase.io import read, write
from ase.geometry import crystal_structure_from_cell
import numpy as np
+
# import numpy.linalg as la
import matplotlib.pyplot as plt
@@ -12,7 +13,7 @@
class Calculation(object):
def __init__(self, *args, **kwargs):
- self.filepath = kwargs.pop('filepath', None)
+ self.filepath = kwargs.pop("filepath", None)
self.parameters = kwargs
def get_data(self, index=-1):
@@ -20,20 +21,19 @@ def get_data(self, index=-1):
@classmethod
def from_path(cls, path: Path):
- with (path / 'gb.json').open() as data_file:
+ with (path / "gb.json").open() as data_file:
gb_data = json.load(data_file)
- with (path / 'subgb.json').open() as data_file:
+ with (path / "subgb.json").open() as data_file:
subgb_data = json.load(data_file)
# print(gb_data['angle'])
- filename = subgb_data['name'] + "_traj.xyz"
+ filename = subgb_data["name"] + "_traj.xyz"
filepath = (path / filename).resolve()
# configuration = read(str((path / filename).resolve()), index=index)
# # gb = read(str((path / filename).resolve()), index=-1)
-
# print('{:=^60}'.format(' '+str(path)+' '))
#
# print('{:-^40}'.format(' gb.json '))
@@ -60,27 +60,34 @@ def from_path(cls, path: Path):
# print('Force mean: {:f}, std: {:f}'.format(force_mean, force_std))
# pprint(gb_final.calc.results)
-
return cls(**{**gb_data, **subgb_data}, filepath=filepath)
-if __name__ == '__main__':
+if __name__ == "__main__":
# Read grain boundary database
- dirpath = Path('../GB_alphaFe_001')
+ dirpath = Path("../GB_alphaFe_001")
calculations = {
- 'tilt': [Calculation.from_path(calc_dir) for calc_dir in (dirpath / 'tilt').iterdir() if calc_dir.is_dir()],
- 'twist': [Calculation.from_path(calc_dir) for calc_dir in (dirpath / 'twist').iterdir() if calc_dir.is_dir()]
+ "tilt": [
+ Calculation.from_path(calc_dir)
+ for calc_dir in (dirpath / "tilt").iterdir()
+ if calc_dir.is_dir()
+ ],
+ "twist": [
+ Calculation.from_path(calc_dir)
+ for calc_dir in (dirpath / "twist").iterdir()
+ if calc_dir.is_dir()
+ ],
}
# potential energy of the perfect crystal according to a specific potential
- potential_energy_per_atom = -4.01298214176 # alpha-Fe PotBH
+ potential_energy_per_atom = -4.01298214176 # alpha-Fe PotBH
eV = 1.6021766208e-19
- Angstrom = 1.e-10
+ Angstrom = 1.0e-10
angles, energies = [], []
- for calc in sorted(calculations['tilt'], key=lambda item: item.parameters['angle']):
+ for calc in sorted(calculations["tilt"], key=lambda item: item.parameters["angle"]):
# E_gb = calc.parameters.get('E_gb', None)
#
@@ -92,7 +99,7 @@ def from_path(cls, path: Path):
# energy = 16.02 / (2 * calc.parameters['A'] ) * \
# (E_gb - potential_energy_per_atom * calc.parameters['n_at'])
- if calc.parameters.get('converged', None):
+ if calc.parameters.get("converged", None):
# energy = 16.02 / (2 * calc.parameters['A'] ) * \
# (calc.parameters.get('E_gb') - potential_energy_per_atom * calc.parameters['n_at'])
#
@@ -101,24 +108,26 @@ def from_path(cls, path: Path):
A = cell[0, 0] * cell[1, 1]
energy = (
- eV / Angstrom**2 /
- (2 * A) *
- (atoms.get_potential_energy() - potential_energy_per_atom * len(atoms))
+ eV
+ / Angstrom**2
+ / (2 * A)
+ * (
+ atoms.get_potential_energy()
+ - potential_energy_per_atom * len(atoms)
+ )
)
write(calc.filepath.name, atoms)
-
# print(energy)
# print(calc.parameters['converged'])
# print(data.get_potential_energy()) # data.get_total_energy() == data.get_potential_energy()
# energies.append(calc.parameters['E_gb'] - data.get_total_energy())
energies.append(energy)
- angles.append(calc.parameters['angle'] * 180.0 / np.pi)
+ angles.append(calc.parameters["angle"] * 180.0 / np.pi)
else:
print("not converged: ", calc.filepath)
-
plt.bar(angles, energies)
# x_smooth = np.linspace(min(angles), max(angles), 1000, endpoint=True)
@@ -126,5 +135,3 @@ def from_path(cls, path: Path):
# plt.plot(x_smooth, f(x_smooth), '-')
plt.show()
-
-
diff --git a/tutorials/scripts/Reader.py b/tutorials/scripts/Reader.py
index 3bd0ff3f..83d6be52 100644
--- a/tutorials/scripts/Reader.py
+++ b/tutorials/scripts/Reader.py
@@ -4,6 +4,7 @@
from ase.io import read, write
from ase.geometry import crystal_structure_from_cell
import numpy as np
+
# import numpy.linalg as la
import matplotlib.pyplot as plt
@@ -12,7 +13,7 @@
class Calculation(object):
def __init__(self, *args, **kwargs):
- self.filepath = kwargs.pop('filepath', None)
+ self.filepath = kwargs.pop("filepath", None)
self.parameters = kwargs
def get_data(self, index=-1):
@@ -20,20 +21,19 @@ def get_data(self, index=-1):
@classmethod
def from_path(cls, path: Path, index=-1):
- with (path / 'gb.json').open() as data_file:
+ with (path / "gb.json").open() as data_file:
gb_data = json.load(data_file)
- with (path / 'subgb.json').open() as data_file:
+ with (path / "subgb.json").open() as data_file:
subgb_data = json.load(data_file)
# print(gb_data['angle'])
- filename = subgb_data['name'] + "_traj.xyz"
+ filename = subgb_data["name"] + "_traj.xyz"
filepath = (path / filename).resolve()
# configuration = read(str((path / filename).resolve()), index=index)
# # gb = read(str((path / filename).resolve()), index=-1)
-
# print('{:=^60}'.format(' '+str(path)+' '))
#
# print('{:-^40}'.format(' gb.json '))
@@ -60,41 +60,54 @@ def from_path(cls, path: Path, index=-1):
# print('Force mean: {:f}, std: {:f}'.format(force_mean, force_std))
# pprint(gb_final.calc.results)
-
return cls(**{**gb_data, **subgb_data}, filepath=filepath)
-if __name__ == '__main__':
+if __name__ == "__main__":
# Read grain boundary database
- dirpath = Path('../GB_alphaFe_001')
+ dirpath = Path("../GB_alphaFe_001")
calculations = {
- 'tilt': [Calculation.from_path(calc_dir) for calc_dir in (dirpath / 'tilt').iterdir() if calc_dir.is_dir()],
- 'twist': [Calculation.from_path(calc_dir) for calc_dir in (dirpath / 'twist').iterdir() if calc_dir.is_dir()]
+ "tilt": [
+ Calculation.from_path(calc_dir)
+ for calc_dir in (dirpath / "tilt").iterdir()
+ if calc_dir.is_dir()
+ ],
+ "twist": [
+ Calculation.from_path(calc_dir)
+ for calc_dir in (dirpath / "twist").iterdir()
+ if calc_dir.is_dir()
+ ],
}
# potential energy of the perfect crystal according to a specific potential
- potential_energy_per_atom = -4.01298214176 # alpha-Fe PotBH
+ potential_energy_per_atom = -4.01298214176 # alpha-Fe PotBH
eV = 1.6021766208e-19
- Angstrom = 1.e-10
+ Angstrom = 1.0e-10
angles, energies = [], []
- for calc in sorted(calculations['tilt'], key=lambda item: item.parameters['angle']):
- angles.append(calc.parameters['angle'] * 180.0 / np.pi)
-
+ for calc in sorted(calculations["tilt"], key=lambda item: item.parameters["angle"]):
+ angles.append(calc.parameters["angle"] * 180.0 / np.pi)
- energy = 16.02 / (2 * calc.parameters['A'] ) * \
- (calc.parameters['E_gb'] - potential_energy_per_atom * calc.parameters['n_at'])
+ energy = (
+ 16.02
+ / (2 * calc.parameters["A"])
+ * (
+ calc.parameters["E_gb"]
+ - potential_energy_per_atom * calc.parameters["n_at"]
+ )
+ )
atoms = calc.get_data()
cell = atoms.get_cell()
A = cell[0, 0] * cell[1, 1]
energy = (
- eV / Angstrom**2 /
- (2 * A) *
- (atoms.get_total_energy() - potential_energy_per_atom * len(atoms))
+ eV
+ / Angstrom**2
+ / (2 * A)
+ * (atoms.get_total_energy() - potential_energy_per_atom * len(atoms))
)
print(energy)
@@ -102,7 +115,6 @@ def from_path(cls, path: Path, index=-1):
# energies.append(calc.parameters['E_gb'] - data.get_total_energy())
energies.append(energy)
-
plt.bar(angles, energies)
# x_smooth = np.linspace(min(angles), max(angles), 1000, endpoint=True)
@@ -186,7 +198,7 @@ def from_path(cls, path: Path, index=-1):
# print at.get_potential_energy()
# print E_gb, 'eV/A^2'
# E_gb = 16.02*(at.get_potential_energy()-(at.n*(E_bulk)))/(2.*A)
- # print E_gb, 'J/m^2'
+# print E_gb, 'J/m^2'
# return E_gb
#
# relax.py
@@ -194,4 +206,4 @@ def from_path(cls, path: Path, index=-1):
# E_gb = grain.get_potential_energy()
#
#
-# ener_per_atom = -4.01298214176
\ No newline at end of file
+# ener_per_atom = -4.01298214176
diff --git a/tutorials/scripts/Visualise.py b/tutorials/scripts/Visualise.py
index 1ed65c2d..308e76fb 100644
--- a/tutorials/scripts/Visualise.py
+++ b/tutorials/scripts/Visualise.py
@@ -7,13 +7,12 @@
import matplotlib.pyplot as plt
import numpy as np
-
-@register_backend('ase')
+@register_backend("ase")
class MyASEStructure(Structure):
def __init__(self, atoms, bfactor=[], occupancy=[]):
# super(MyASEStructure, self).__init__()
- self.ext = 'pdb'
+ self.ext = "pdb"
self.params = {}
self._atoms = atoms
self.bfactor = bfactor # [min, max]
@@ -21,7 +20,7 @@ def __init__(self, atoms, bfactor=[], occupancy=[]):
self.id = str(uuid.uuid4())
def get_structure_string(self):
- """ PDB file format:
+ """PDB file format:
CRYST1 16.980 62.517 124.864 90.00 90.00 90.00 P 1
MODEL 1
ATOM 0 Fe MOL 1 15.431 60.277 6.801 1.00 0.00 FE
@@ -36,12 +35,12 @@ def get_structure_string(self):
if self._atoms.get_pbc().any():
cellpar = self._atoms.get_cell_lengths_and_angles()
- str_format = 'CRYST1' + '{:9.3f}' * 3 + '{:7.2f}' * 3 + ' P 1\n'
+ str_format = "CRYST1" + "{:9.3f}" * 3 + "{:7.2f}" * 3 + " P 1\n"
data += str_format.format(*cellpar.tolist())
- data += 'MODEL 1\n'
+ data += "MODEL 1\n"
- str_format = 'ATOM {:5d} {:>4s} MOL 1 {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} {:2s}\n'
+ str_format = "ATOM {:5d} {:>4s} MOL 1 {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} {:2s}\n"
for index, atom in enumerate(self._atoms):
data += str_format.format(
index,
@@ -51,10 +50,10 @@ def get_structure_string(self):
atom.position[2].tolist(),
self.occupancy[index] if index <= len(self.occupancy) - 1 else 1.0,
self.bfactor[index] if index <= len(self.bfactor) - 1 else 1.0,
- atom.symbol.upper()
+ atom.symbol.upper(),
)
- data += 'ENDMDL\n'
+ data += "ENDMDL\n"
return data
@@ -63,12 +62,11 @@ def ViewStructure(atoms):
import nglview
view = nglview.NGLWidget()
-
+
structure = MyASEStructure(atoms)
view.add_structure(structure)
-
- return view
+ return view
class AtomViewer(object):
@@ -76,38 +74,35 @@ def __init__(self, atoms, data=[], xsize=1000, ysize=500):
self.view = self._init_nglview(atoms, data, xsize, ysize)
self.widgets = {
- 'radius': FloatSlider(
- value=0.8, min=0.0, max=1.5, step=0.01,
- description='Ball size'
+ "radius": FloatSlider(
+ value=0.8, min=0.0, max=1.5, step=0.01, description="Ball size"
),
- 'color_scheme': Dropdown(description='Solor scheme:'),
- 'colorbar': Output()
+ "color_scheme": Dropdown(description="Color scheme:"),
+ "colorbar": Output(),
}
self.show_colorbar(data)
- self.widgets['radius'].observe(self._update_repr)
+ self.widgets["radius"].observe(self._update_repr)
- self.gui = VBox([
- self.view,
- self.widgets['colorbar'],
- self.widgets['radius']])
+ self.gui = VBox([self.view, self.widgets["colorbar"], self.widgets["radius"]])
def _update_repr(self, chg=None):
self.view.update_spacefill(
- radiusType='radius',
- radius=self.widgets['radius'].value
+ radiusType="radius", radius=self.widgets["radius"].value
)
def show_colorbar(self, data):
- with self.widgets['colorbar']:
+ with self.widgets["colorbar"]:
# Have colormaps separated into categories:
# http://matplotlib.org/examples/color/colormaps_reference.html
- cmap = 'rainbow'
+ cmap = "rainbow"
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 2))
- img = ax1.imshow([[min(data), max(data)]], aspect='auto', cmap=plt.get_cmap(cmap))
+ img = ax1.imshow(
+ [[min(data), max(data)]], aspect="auto", cmap=plt.get_cmap(cmap)
+ )
ax1.remove()
- cbar = fig.colorbar(img, cax=ax2, orientation='horizontal')
+ cbar = fig.colorbar(img, cax=ax2, orientation="horizontal")
plt.show()
@@ -117,15 +112,12 @@ def _init_nglview(atoms, data, xsize, ysize):
view = nglview.NGLWidget(gui=False)
view._remote_call(
- 'setSize',
- target='Widget',
- args=[
- '{:d}px'.format(xsize),
- '{:d}px'.format(ysize)
- ]
+ "setSize",
+ target="Widget",
+ args=["{:d}px".format(xsize), "{:d}px".format(ysize)],
)
- data = np.max(data)-data
+ data = np.max(data) - data
structure = MyASEStructure(atoms, bfactor=data)
view.add_structure(structure)
@@ -136,19 +128,15 @@ def _init_nglview(atoms, data, xsize, ysize):
view.add_spacefill(
# radiusType='radius',
# radius=1.0,
- color_scheme='bfactor',
- color_scale='rainbow'
- )
- view.update_spacefill(
- radiusType='radius',
- radius=1.0
+ color_scheme="bfactor",
+ color_scale="rainbow",
)
-
+ view.update_spacefill(radiusType="radius", radius=1.0)
# update camera type
view.control.spin([1, 0, 0], np.pi / 2)
view.control.spin([0, 0, 1], np.pi / 2)
- view.camera = 'orthographic'
+ view.camera = "orthographic"
view.center()
return view
diff --git a/tutorials/scripts/Visualise_quip.py b/tutorials/scripts/Visualise_quip.py
index 1ed65c2d..308e76fb 100644
--- a/tutorials/scripts/Visualise_quip.py
+++ b/tutorials/scripts/Visualise_quip.py
@@ -7,13 +7,12 @@
import matplotlib.pyplot as plt
import numpy as np
-
-@register_backend('ase')
+@register_backend("ase")
class MyASEStructure(Structure):
def __init__(self, atoms, bfactor=[], occupancy=[]):
# super(MyASEStructure, self).__init__()
- self.ext = 'pdb'
+ self.ext = "pdb"
self.params = {}
self._atoms = atoms
self.bfactor = bfactor # [min, max]
@@ -21,7 +20,7 @@ def __init__(self, atoms, bfactor=[], occupancy=[]):
self.id = str(uuid.uuid4())
def get_structure_string(self):
- """ PDB file format:
+ """PDB file format:
CRYST1 16.980 62.517 124.864 90.00 90.00 90.00 P 1
MODEL 1
ATOM 0 Fe MOL 1 15.431 60.277 6.801 1.00 0.00 FE
@@ -36,12 +35,12 @@ def get_structure_string(self):
if self._atoms.get_pbc().any():
cellpar = self._atoms.get_cell_lengths_and_angles()
- str_format = 'CRYST1' + '{:9.3f}' * 3 + '{:7.2f}' * 3 + ' P 1\n'
+ str_format = "CRYST1" + "{:9.3f}" * 3 + "{:7.2f}" * 3 + " P 1\n"
data += str_format.format(*cellpar.tolist())
- data += 'MODEL 1\n'
+ data += "MODEL 1\n"
- str_format = 'ATOM {:5d} {:>4s} MOL 1 {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} {:2s}\n'
+ str_format = "ATOM {:5d} {:>4s} MOL 1 {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} {:2s}\n"
for index, atom in enumerate(self._atoms):
data += str_format.format(
index,
@@ -51,10 +50,10 @@ def get_structure_string(self):
atom.position[2].tolist(),
self.occupancy[index] if index <= len(self.occupancy) - 1 else 1.0,
self.bfactor[index] if index <= len(self.bfactor) - 1 else 1.0,
- atom.symbol.upper()
+ atom.symbol.upper(),
)
- data += 'ENDMDL\n'
+ data += "ENDMDL\n"
return data
@@ -63,12 +62,11 @@ def ViewStructure(atoms):
import nglview
view = nglview.NGLWidget()
-
+
structure = MyASEStructure(atoms)
view.add_structure(structure)
-
- return view
+ return view
class AtomViewer(object):
@@ -76,38 +74,35 @@ def __init__(self, atoms, data=[], xsize=1000, ysize=500):
self.view = self._init_nglview(atoms, data, xsize, ysize)
self.widgets = {
- 'radius': FloatSlider(
- value=0.8, min=0.0, max=1.5, step=0.01,
- description='Ball size'
+ "radius": FloatSlider(
+ value=0.8, min=0.0, max=1.5, step=0.01, description="Ball size"
),
- 'color_scheme': Dropdown(description='Solor scheme:'),
- 'colorbar': Output()
+ "color_scheme": Dropdown(description="Color scheme:"),
+ "colorbar": Output(),
}
self.show_colorbar(data)
- self.widgets['radius'].observe(self._update_repr)
+ self.widgets["radius"].observe(self._update_repr)
- self.gui = VBox([
- self.view,
- self.widgets['colorbar'],
- self.widgets['radius']])
+ self.gui = VBox([self.view, self.widgets["colorbar"], self.widgets["radius"]])
def _update_repr(self, chg=None):
self.view.update_spacefill(
- radiusType='radius',
- radius=self.widgets['radius'].value
+ radiusType="radius", radius=self.widgets["radius"].value
)
def show_colorbar(self, data):
- with self.widgets['colorbar']:
+ with self.widgets["colorbar"]:
# Have colormaps separated into categories:
# http://matplotlib.org/examples/color/colormaps_reference.html
- cmap = 'rainbow'
+ cmap = "rainbow"
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 2))
- img = ax1.imshow([[min(data), max(data)]], aspect='auto', cmap=plt.get_cmap(cmap))
+ img = ax1.imshow(
+ [[min(data), max(data)]], aspect="auto", cmap=plt.get_cmap(cmap)
+ )
ax1.remove()
- cbar = fig.colorbar(img, cax=ax2, orientation='horizontal')
+ cbar = fig.colorbar(img, cax=ax2, orientation="horizontal")
plt.show()
@@ -117,15 +112,12 @@ def _init_nglview(atoms, data, xsize, ysize):
view = nglview.NGLWidget(gui=False)
view._remote_call(
- 'setSize',
- target='Widget',
- args=[
- '{:d}px'.format(xsize),
- '{:d}px'.format(ysize)
- ]
+ "setSize",
+ target="Widget",
+ args=["{:d}px".format(xsize), "{:d}px".format(ysize)],
)
- data = np.max(data)-data
+ data = np.max(data) - data
structure = MyASEStructure(atoms, bfactor=data)
view.add_structure(structure)
@@ -136,19 +128,15 @@ def _init_nglview(atoms, data, xsize, ysize):
view.add_spacefill(
# radiusType='radius',
# radius=1.0,
- color_scheme='bfactor',
- color_scale='rainbow'
- )
- view.update_spacefill(
- radiusType='radius',
- radius=1.0
+ color_scheme="bfactor",
+ color_scale="rainbow",
)
-
+ view.update_spacefill(radiusType="radius", radius=1.0)
# update camera type
view.control.spin([1, 0, 0], np.pi / 2)
view.control.spin([0, 0, 1], np.pi / 2)
- view.camera = 'orthographic'
+ view.camera = "orthographic"
view.center()
return view
diff --git a/tutorials/test_db.py b/tutorials/test_db.py
index 8f1fb542..8a8f34f8 100644
--- a/tutorials/test_db.py
+++ b/tutorials/test_db.py
@@ -3,7 +3,7 @@
from abcd import ABCD
-if __name__ == '__main__':
+if __name__ == "__main__":
# http requests
# url = 'http://localhost:5000/api'
@@ -12,9 +12,9 @@
# Mongoengine
# https://stackoverflow.com/questions/36200288/mongolab-pymongo-connection-error
- url = 'mongodb://root:example@localhost:27018/?authSource=admin'
+ url = "mongodb://root:example@localhost:27018/?authSource=admin"
- abcd = ABCD(url, db='abcd', collection='default')
+ abcd = ABCD(url, db="abcd", collection="default")
print(abcd)
abcd.print_info()
@@ -24,8 +24,8 @@
abcd.destroy()
- direcotry = Path('../tutorials/data/')
- file = direcotry / 'bcc_bulk_54_expanded_2_high.xyz'
+ direcotry = Path("../tutorials/data/")
+ file = direcotry / "bcc_bulk_54_expanded_2_high.xyz"
# file = direcotry / 'GAP_6.xyz'
traj = read(file.as_posix(), index=slice(None))
diff --git a/tutorials/test_upload.py b/tutorials/test_upload.py
index 0f3a34bd..c5fe8e91 100644
--- a/tutorials/test_upload.py
+++ b/tutorials/test_upload.py
@@ -3,7 +3,7 @@
from abcd import ABCD
-if __name__ == '__main__':
- abcd = ABCD(url='mongodb://localhost:27017')
+if __name__ == "__main__":
+ abcd = ABCD(url="mongodb://localhost:27017")
print(abcd)
abcd.print_info()