diff --git a/abcd/__init__.py b/abcd/__init__.py index 52279171..b8008379 100644 --- a/abcd/__init__.py +++ b/abcd/__init__.py @@ -14,7 +14,7 @@ class ABCD(object): @classmethod def from_config(cls, config): # Factory method - url = config['url'] + url = config["url"] return ABCD.from_url(url) @classmethod @@ -23,35 +23,38 @@ def from_url(cls, url, **kwargs): r = parse.urlparse(url) logger.info(r) - if r.scheme == 'mongodb': + if r.scheme == "mongodb": conn_settings = { - 'host': r.hostname, - 'port': r.port, - 'username': r.username, - 'password': r.password, - 'authSource': 'admin', + "host": r.hostname, + "port": r.port, + "username": r.username, + "password": r.password, + "authSource": "admin", } - db = r.path.split('/')[1] if r.path else None - db = db if db else 'abcd' + db = r.path.split("/")[1] if r.path else None + db = db if db else "abcd" from abcd.backends.atoms_pymongo import MongoDatabase + return MongoDatabase(db_name=db, **conn_settings, **kwargs) - elif r.scheme == 'http' or r.scheme == 'https': - raise NotImplementedError('http not yet supported! soon...') - elif r.scheme == 'ssh': - raise NotImplementedError('ssh not yet supported! soon...') + elif r.scheme == "http" or r.scheme == "https": + raise NotImplementedError("http not yet supported! soon...") + elif r.scheme == "ssh": + raise NotImplementedError("ssh not yet supported! soon...") else: - raise NotImplementedError('Unable to recognise the type of connection. (url: {})'.format(url)) + raise NotImplementedError( + "Unable to recognise the type of connection. (url: {})".format(url) + ) -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) # url = 'mongodb://mongoadmin:secret@localhost:27017' - url = 'mongodb://mongoadmin:secret@localhost:27017/abcd_new' + url = "mongodb://mongoadmin:secret@localhost:27017/abcd_new" abcd = ABCD.from_url(url) abcd.print_info() diff --git a/abcd/backends/atoms_http.py b/abcd/backends/atoms_http.py index 1ff6e95f..ee62e61a 100644 --- a/abcd/backends/atoms_http.py +++ b/abcd/backends/atoms_http.py @@ -13,16 +13,15 @@ class Atoms(ase.Atoms): - @classmethod def from_dict(cls, data): - return cls(numbers=data['numbers'], positions=data['positions']) + return cls(numbers=data["numbers"], positions=data["positions"]) class HttpDatabase(Database): """client/local interface""" - def __init__(self, url='http://localhost'): + def __init__(self, url="http://localhost"): super().__init__() self.url = url @@ -30,7 +29,9 @@ def __init__(self, url='http://localhost'): def push(self, atoms: ase.Atoms): # todo: list of Atoms, metadata(user, project, tags) - message = requests.put(self.url + '/calculation', json=DictEncoder().encode(atoms)) + message = requests.put( + self.url + "/calculation", json=DictEncoder().encode(atoms) + ) # message = json.dumps(atoms) # message_hash = hashlib.md5(message.encode('utf-8')).hexdigest() logger.info(message) @@ -44,29 +45,33 @@ def query(self, query_string): pass def search(self, query_string: str) -> List[str]: - results = requests.get(self.url + '/calculation').json() + results = requests.get(self.url + "/calculation").json() return results def get_atoms(self, id: str) -> Atoms: - data = requests.get(self.url + '/calculation/{}'.format(id)).json() + data = requests.get(self.url + "/calculation/{}".format(id)).json() atoms = Atoms.from_dict(data) return atoms def __repr__(self): - return 'ABCD(type={}, url={}, ...)'.format(self.__class__.__name__, self.url) + return "ABCD(type={}, url={}, ...)".format(self.__class__.__name__, self.url) def _repr_html_(self): """jupyter notebook representation""" - return 'ABCD database' + return "ABCD database" def print_info(self): """shows basic information about the connected database""" - out = linesep.join([ - '{:=^50}'.format(' ABCD Database '), - '{:>10}: {}'.format('type', 'remote (http/https)'), - linesep.join('{:>10}: {}'.format(k, v) for k, v in self.db.info().items()) - ]) + out = linesep.join( + [ + "{:=^50}".format(" ABCD Database "), + "{:>10}: {}".format("type", "remote (http/https)"), + linesep.join( + "{:>10}: {}".format(k, v) for k, v in self.db.info().items() + ), + ] + ) print(out) @@ -80,6 +85,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass -if __name__ == '__main__': - abcd = HttpDatabase(url='http://localhost:8080/api') +if __name__ == "__main__": + abcd = HttpDatabase(url="http://localhost:8080/api") abcd.print_info() diff --git a/abcd/backends/atoms_pymongo.py b/abcd/backends/atoms_pymongo.py index 217e1c52..993c6eb0 100644 --- a/abcd/backends/atoms_pymongo.py +++ b/abcd/backends/atoms_pymongo.py @@ -26,12 +26,12 @@ logger = logging.getLogger(__name__) map_types = { - bool: 'bool', - float: 'float', - int: 'int', - str: 'str', - datetime: 'date', - dict: 'dict' + bool: "bool", + float: "float", + int: "int", + str: "str", + datetime: "date", + dict: "dict", } @@ -49,15 +49,14 @@ def from_atoms(cls, collection, atoms: Atoms, extra_info=None, store_calc=True): @property def _id(self): - return self.get('_id', None) + return self.get("_id", None) def save(self): if not self._id: self._collection.insert_one(self) else: - new_values = { "$set": self } - self._collection.update_one( - {"_id": ObjectId(self._id)}, new_values) + new_values = {"$set": self} + self._collection.update_one({"_id": ObjectId(self._id)}, new_values) def remove(self): if self._id: @@ -66,28 +65,27 @@ def remove(self): class MongoQuery(AbstractQuerySet): - def __init__(self): pass def visit(self, syntax_tree): op, *args = syntax_tree try: - fun = self.__getattribute__('visit_' + op.lower()) + fun = self.__getattribute__("visit_" + op.lower()) return fun(*args) except KeyError: pass def visit_name(self, field): - return {field: {'$exists': True}} + return {field: {"$exists": True}} def visit_not(self, value): _, field = value - return {field: {'$exists': False}} + return {field: {"$exists": False}} def visit_and(self, *args): print(args) - return {'$and': [self.visit(arg) for arg in args]} + return {"$and": [self.visit(arg) for arg in args]} # TODO recursively combining all the and statements # out = {} # for arg in args: @@ -97,28 +95,28 @@ def visit_and(self, *args): # return out def visit_or(self, *args): - return {'$or': [self.visit(arg) for arg in args]} + return {"$or": [self.visit(arg) for arg in args]} def visit_eq(self, field, value): return {field[1]: value[1]} def visit_re(self, field, value): - return {field[1]: {'$regex': value[1]}} + return {field[1]: {"$regex": value[1]}} def visit_gt(self, field, value): - return {field[1]: {'$gt': value[1]}} + return {field[1]: {"$gt": value[1]}} def visit_gte(self, field, value): - return {field[1]: {'$gte': value[1]}} + return {field[1]: {"$gte": value[1]}} def visit_lt(self, field, value): - return {field[1]: {'$lt': value[1]}} + return {field[1]: {"$lt": value[1]}} def visit_lte(self, field, value): - return {field[1]: {'$lte': value[1]}} + return {field[1]: {"$lte": value[1]}} def visit_in(self, field, *values): - return {field[1]: {'$in': [value[1] for value in values]}} + return {field[1]: {"$in": [value[1] for value in values]}} def __enter__(self): return self @@ -127,12 +125,13 @@ def __exit__(self, exc_type, exc_val, exc_tb): pass def __call__(self, ast): - logger.info('parsed ast: {}'.format(ast)) + logger.info("parsed ast: {}".format(ast)) if isinstance(ast, dict): return ast elif isinstance(ast, str): from abcd.parsers.queries import parser + p = parser(ast) return self.visit(p) @@ -156,20 +155,43 @@ def wrapper(*args, query=None, **kwargs): class MongoDatabase(AbstractABCD): """Wrapper to make database operations easy""" - def __init__(self, host='localhost', port=27017, - db_name='abcd', collection_name='atoms', - username=None, password=None, authSource='admin', **kwargs): + def __init__( + self, + host="localhost", + port=27017, + db_name="abcd", + collection_name="atoms", + username=None, + password=None, + authSource="admin", + **kwargs + ): super().__init__() - logger.info((host, port, db_name, collection_name, username, password, authSource, kwargs)) + logger.info( + ( + host, + port, + db_name, + collection_name, + username, + password, + authSource, + kwargs, + ) + ) self.client = MongoClient( - host=host, port=port, username=username, password=password, - authSource=authSource) + host=host, + port=port, + username=username, + password=password, + authSource=authSource, + ) try: info = self.client.server_info() # Forces a call. - logger.info('DB info: {}'.format(info)) + logger.info("DB info: {}".format(info)) except pymongo.errors.OperationFailure: raise abcd.errors.AuthenticationError() @@ -184,12 +206,12 @@ def info(self): host, port = self.client.address return { - 'host': host, - 'port': port, - 'db': self.db.name, - 'collection': self.collection.name, - 'number of confs': self.collection.count_documents({}), - 'type': 'mongodb' + "host": host, + "port": port, + "db": self.db.name, + "collection": self.collection.name, + "number of confs": self.collection.count_documents({}), + "type": "mongodb", } def delete(self, query=None): @@ -205,14 +227,18 @@ def push(self, atoms: Union[Atoms, Iterable], extra_info=None, store_calc=True): extra_info = extras.parser.parse(extra_info) if isinstance(atoms, Atoms): - data = AtomsModel.from_atoms(self.collection, atoms, extra_info=extra_info, store_calc=store_calc) + data = AtomsModel.from_atoms( + self.collection, atoms, extra_info=extra_info, store_calc=store_calc + ) data.save() # self.collection.insert_one(data) elif isinstance(atoms, types.GeneratorType) or isinstance(atoms, list): for item in atoms: - data = AtomsModel.from_atoms(self.collection, item, extra_info=extra_info, store_calc=store_calc) + data = AtomsModel.from_atoms( + self.collection, item, extra_info=extra_info, store_calc=store_calc + ) data.save() def upload(self, file: Path, extra_infos=None, store_calc=True): @@ -225,7 +251,7 @@ def upload(self, file: Path, extra_infos=None, store_calc=True): for info in extra_infos: extra_info.update(extras.parser.parse(info)) - extra_info['filename'] = str(file) + extra_info["filename"] = str(file) data = iread(str(file)) self.push(data, extra_info, store_calc=store_calc) @@ -243,7 +269,7 @@ def get_atoms(self, query=None): def count(self, query=None): query = parser(query) - logger.info('query; {}'.format(query)) + logger.info("query; {}".format(query)) if not query: query = {} @@ -254,61 +280,69 @@ def property(self, name, query=None): query = parser(query) pipeline = [ - {'$match': query}, - {'$match': {'{}'.format(name): {"$exists": True}}}, - {'$project': {'_id': False, 'data': '${}'.format(name)}} + {"$match": query}, + {"$match": {"{}".format(name): {"$exists": True}}}, + {"$project": {"_id": False, "data": "${}".format(name)}}, ] - return [val['data'] for val in self.db.atoms.aggregate(pipeline)] + return [val["data"] for val in self.db.atoms.aggregate(pipeline)] def properties(self, query=None): query = parser(query) properties = {} pipeline = [ - {'$match': query}, - {'$unwind': '$derived.info_keys'}, - {'$group': {'_id': '$derived.info_keys'}} + {"$match": query}, + {"$unwind": "$derived.info_keys"}, + {"$group": {"_id": "$derived.info_keys"}}, + ] + properties["info"] = [ + value["_id"] for value in self.db.atoms.aggregate(pipeline) ] - properties['info'] = [value['_id'] for value in self.db.atoms.aggregate(pipeline)] pipeline = [ - {'$match': query}, - {'$unwind': '$derived.arrays_keys'}, - {'$group': {'_id': '$derived.arrays_keys'}} + {"$match": query}, + {"$unwind": "$derived.arrays_keys"}, + {"$group": {"_id": "$derived.arrays_keys"}}, + ] + properties["arrays"] = [ + value["_id"] for value in self.db.atoms.aggregate(pipeline) ] - properties['arrays'] = [value['_id'] for value in self.db.atoms.aggregate(pipeline)] pipeline = [ - {'$match': query}, - {'$unwind': '$derived.derived_keys'}, - {'$group': {'_id': '$derived.derived_keys'}} + {"$match": query}, + {"$unwind": "$derived.derived_keys"}, + {"$group": {"_id": "$derived.derived_keys"}}, + ] + properties["derived"] = [ + value["_id"] for value in self.db.atoms.aggregate(pipeline) ] - properties['derived'] = [value['_id'] for value in self.db.atoms.aggregate(pipeline)] return properties def get_type_of_property(self, prop, category): # TODO: Probably it would be nicer to store the type info in the database from the beginning. - atoms = self.db.atoms.find_one({prop: {'$exists': True}}) + atoms = self.db.atoms.find_one({prop: {"$exists": True}}) data = atoms[prop] - if category == 'arrays': + if category == "arrays": if type(data[0]) == list: - return 'array({}, N x {})'.format(map_types[type(data[0][0])], len(data[0])) + return "array({}, N x {})".format( + map_types[type(data[0][0])], len(data[0]) + ) else: - return 'vector({}, N)'.format(map_types[type(data[0])]) + return "vector({}, N)".format(map_types[type(data[0])]) if type(data) == list: if type(data[0]) == list: if type(data[0][0]) == list: - return 'list(list(...)' + return "list(list(...)" else: - return 'array({})'.format(map_types[type(data[0][0])]) + return "array({})".format(map_types[type(data[0][0])]) else: - return 'vector({})'.format(map_types[type(data[0])]) + return "vector({})".format(map_types[type(data[0])]) else: - return 'scalar({})'.format(map_types[type(data)]) + return "scalar({})".format(map_types[type(data)]) def count_properties(self, query=None): query = parser(query) @@ -316,67 +350,69 @@ def count_properties(self, query=None): properties = {} pipeline = [ - {'$match': query}, - {'$unwind': '$derived.info_keys'}, - {'$group': {'_id': '$derived.info_keys', 'count': {'$sum': 1}}} + {"$match": query}, + {"$unwind": "$derived.info_keys"}, + {"$group": {"_id": "$derived.info_keys", "count": {"$sum": 1}}}, ] info_keys = self.db.atoms.aggregate(pipeline) for val in info_keys: - properties[val['_id']] = { - 'count': val['count'], - 'category': 'info', - 'dtype': self.get_type_of_property(val['_id'], 'info') + properties[val["_id"]] = { + "count": val["count"], + "category": "info", + "dtype": self.get_type_of_property(val["_id"], "info"), } pipeline = [ - {'$match': query}, - {'$unwind': '$derived.arrays_keys'}, - {'$group': {'_id': '$derived.arrays_keys', 'count': {'$sum': 1}}} + {"$match": query}, + {"$unwind": "$derived.arrays_keys"}, + {"$group": {"_id": "$derived.arrays_keys", "count": {"$sum": 1}}}, ] arrays_keys = list(self.db.atoms.aggregate(pipeline)) for val in arrays_keys: - properties[val['_id']] = { - 'count': val['count'], - 'category': 'arrays', - 'dtype': self.get_type_of_property(val['_id'], 'arrays') + properties[val["_id"]] = { + "count": val["count"], + "category": "arrays", + "dtype": self.get_type_of_property(val["_id"], "arrays"), } pipeline = [ - {'$match': query}, - {'$unwind': '$derived.derived_keys'}, - {'$group': {'_id': '$derived.derived_keys', 'count': {'$sum': 1}}} + {"$match": query}, + {"$unwind": "$derived.derived_keys"}, + {"$group": {"_id": "$derived.derived_keys", "count": {"$sum": 1}}}, ] arrays_keys = list(self.db.atoms.aggregate(pipeline)) for val in arrays_keys: - properties[val['_id']] = { - 'count': val['count'], - 'category': 'derived', - 'dtype': self.get_type_of_property(val['_id'], 'derived') + properties[val["_id"]] = { + "count": val["count"], + "category": "derived", + "dtype": self.get_type_of_property(val["_id"], "derived"), } return properties def add_property(self, data, query=None): - logger.info('add: data={}, query={}'.format(data, query)) + logger.info("add: data={}, query={}".format(data, query)) self.collection.update_many( parser(query), - {'$push': {'derived.info_keys': {'$each': list(data.keys())}}, - '$set': data}) + { + "$push": {"derived.info_keys": {"$each": list(data.keys())}}, + "$set": data, + }, + ) def rename_property(self, name, new_name, query=None): - logger.info('rename: query={}, old={}, new={}'.format(query, name, new_name)) + logger.info("rename: query={}, old={}, new={}".format(query, name, new_name)) # TODO name in derived.info_keys OR name in derived.arrays_keys OR name in derived.derived_keys self.collection.update_many( - parser(query), - {'$push': {'derived.info_keys': new_name}}) + parser(query), {"$push": {"derived.info_keys": new_name}} + ) self.collection.update_many( parser(query), - { - '$pull': {'derived.info_keys': name}, - '$rename': {name: new_name}}) + {"$pull": {"derived.info_keys": name}, "$rename": {name: new_name}}, + ) # self.collection.update_many( # parser(query + ['arrays.{}'.format(name)]), @@ -388,13 +424,15 @@ def rename_property(self, name, new_name, query=None): # '$rename': {'arrays.{}'.format(name): 'arrays.{}'.format(new_name)}}) def delete_property(self, name, query=None): - logger.info('delete: query={}, porperty={}'.format(name, query)) + logger.info("delete: query={}, porperty={}".format(name, query)) self.collection.update_many( parser(query), - {'$pull': {'derived.info_keys': name, - 'derived.arrays_keys': name}, - '$unset': {name: ''}}) + { + "$pull": {"derived.info_keys": name, "derived.arrays_keys": name}, + "$unset": {name: ""}, + }, + ) def hist(self, name, query=None, **kwargs): @@ -411,21 +449,27 @@ def exec(self, code, query=None): def __repr__(self): host, port = self.client.address - return '{}('.format(self.__class__.__name__) + \ - 'url={}:{}, '.format(host, port) + \ - 'db={}, '.format(self.db.name) + \ - 'collection={})'.format(self.collection.name) + return ( + "{}(".format(self.__class__.__name__) + + "url={}:{}, ".format(host, port) + + "db={}, ".format(self.db.name) + + "collection={})".format(self.collection.name) + ) def _repr_html_(self): """Jupyter notebook representation""" - return 'ABCD MongoDB database' + return "ABCD MongoDB database" def print_info(self): """shows basic information about the connected database""" - out = linesep.join(['{:=^50}'.format(' ABCD MongoDB '), - '{:>10}: {}'.format('type', 'mongodb'), - linesep.join('{:>10}: {}'.format(k, v) for k, v in self.info().items())]) + out = linesep.join( + [ + "{:=^50}".format(" ABCD MongoDB "), + "{:>10}: {}".format("type", "mongodb"), + linesep.join("{:>10}: {}".format(k, v) for k, v in self.info().items()), + ] + ) print(out) @@ -449,26 +493,36 @@ def histogram(name, data, **kwargs): return None if ptype == float: - bins = kwargs.get('bins', 10) + bins = kwargs.get("bins", 10) return _hist_float(name, data, bins) elif ptype == int: - bins = kwargs.get('bins', 10) + bins = kwargs.get("bins", 10) return _hist_int(name, data, bins) elif ptype == str: return _hist_str(name, data, **kwargs) elif ptype == datetime: - bins = kwargs.get('bins', 10) + bins = kwargs.get("bins", 10) return _hist_date(name, data, bins) else: - print('{}: Histogram for list of {} types are not supported!'.format(name, type(data[0]))) - logger.info('{}: Histogram for list of {} types are not supported!'.format(name, type(data[0]))) + print( + "{}: Histogram for list of {} types are not supported!".format( + name, type(data[0]) + ) + ) + logger.info( + "{}: Histogram for list of {} types are not supported!".format( + name, type(data[0]) + ) + ) else: - logger.info('{}: Histogram for {} types are not supported!'.format(name, type(data))) + logger.info( + "{}: Histogram for {} types are not supported!".format(name, type(data)) + ) return None @@ -477,16 +531,16 @@ def _hist_float(name, data, bins=10): hist, bin_edges = np.histogram(data, bins=bins) return { - 'type': 'hist_float', - 'name': name, - 'bins': bins, - 'edges': bin_edges, - 'counts': hist, - 'min': data.min(), - 'max': data.max(), - 'median': data.mean(), - 'std': data.std(), - 'var': data.var() + "type": "hist_float", + "name": name, + "bins": bins, + "edges": bin_edges, + "counts": hist, + "min": data.min(), + "max": data.max(), + "median": data.mean(), + "std": data.std(), + "var": data.var(), } @@ -497,16 +551,16 @@ def _hist_date(name, data, bins=10): fromtimestamp = datetime.fromtimestamp return { - 'type': 'hist_date', - 'name': name, - 'bins': bins, - 'edges': [fromtimestamp(d) for d in bin_edges], - 'counts': hist, - 'min': fromtimestamp(hist_data.min()), - 'max': fromtimestamp(hist_data.max()), - 'median': fromtimestamp(hist_data.mean()), - 'std': fromtimestamp(hist_data.std()), - 'var': fromtimestamp(hist_data.var()) + "type": "hist_date", + "name": name, + "bins": bins, + "edges": [fromtimestamp(d) for d in bin_edges], + "counts": hist, + "min": fromtimestamp(hist_data.min()), + "max": fromtimestamp(hist_data.max()), + "median": fromtimestamp(hist_data.mean()), + "std": fromtimestamp(hist_data.std()), + "var": fromtimestamp(hist_data.var()), } @@ -520,16 +574,16 @@ def _hist_int(name, data, bins=10): hist, bin_edges = np.histogram(data, bins=bins) return { - 'type': 'hist_int', - 'name': name, - 'bins': bins, - 'edges': bin_edges, - 'counts': hist, - 'min': data.min(), - 'max': data.max(), - 'median': data.mean(), - 'std': data.std(), - 'var': data.var() + "type": "hist_int", + "name": name, + "bins": bins, + "edges": bin_edges, + "counts": hist, + "min": data.min(), + "max": data.max(), + "median": data.mean(), + "std": data.std(), + "var": data.var(), } @@ -538,7 +592,9 @@ def _hist_str(name, data, bins=10, truncate=20): if truncate: # data = (item[:truncate] for item in data) - data = (item[:truncate] + '...' if len(item) > truncate else item for item in data) + data = ( + item[:truncate] + "..." if len(item) > truncate else item for item in data + ) data = Counter(data) @@ -548,27 +604,27 @@ def _hist_str(name, data, bins=10, truncate=20): labels, counts = zip(*data.items()) return { - 'type': 'hist_str', - 'name': name, - 'total': sum(data.values()), - 'unique': n_unique, - 'labels': labels[:bins], - 'counts': counts[:bins] + "type": "hist_str", + "name": name, + "total": sum(data.values()), + "unique": n_unique, + "labels": labels[:bins], + "counts": counts[:bins], } -if __name__ == '__main__': +if __name__ == "__main__": # import json # from ase.io import iread # from pprint import pprint # from server.styles.myjson import JSONEncoderOld, JSONDecoderOld, JSONEncoder - print('hello') - db = MongoDatabase(username='mongoadmin', password='secret') + print("hello") + db = MongoDatabase(username="mongoadmin", password="secret") print(db.info()) print(db.count()) - print(db.hist('uploaded')) + print(db.hist("uploaded")) # for atoms in iread('../../tutorials/data/bcc_bulk_54_expanded_2_high.xyz', index=slice(None)): # # print(at) diff --git a/abcd/frontends/commandline/commands.py b/abcd/frontends/commandline/commands.py index ebe6a213..de158a5c 100644 --- a/abcd/frontends/commandline/commands.py +++ b/abcd/frontends/commandline/commands.py @@ -11,32 +11,37 @@ @init_config def login(*, config, name, url, **kwargs): logger.info( - 'login args: \nconfig:{}, name:{}, url:{}, kwargs:{}'.format(config, name, url, kwargs)) + "login args: \nconfig:{}, name:{}, url:{}, kwargs:{}".format( + config, name, url, kwargs + ) + ) from abcd import ABCD db = ABCD.from_url(url=url) info = db.info() - config['url'] = url + config["url"] = url config.save() - print('Successfully connected to the database!') - print(" type: {type}\n" - " hostname: {host}\n" - " port: {port}\n" - " database: {db}\n" - " # of confs: {number of confs}".format(**info)) + print("Successfully connected to the database!") + print( + " type: {type}\n" + " hostname: {host}\n" + " port: {port}\n" + " database: {db}\n" + " # of confs: {number of confs}".format(**info) + ) @init_config @init_db def download(*, db, query, fileformat, filename, **kwargs): - logger.info('download\n kwargs: {}'.format(kwargs)) + logger.info("download\n kwargs: {}".format(kwargs)) from ase.io import write - if kwargs.pop('remote'): - write('-', list(db.get_atoms(query=query)), format=fileformat) + if kwargs.pop("remote"): + write("-", list(db.get_atoms(query=query)), format=fileformat) return write(filename, list(db.get_atoms(query=query)), format=fileformat) @@ -46,14 +51,18 @@ def download(*, db, query, fileformat, filename, **kwargs): @init_db @check_remote def delete(*, db, query, yes, **kwargs): - logger.info('delete\n kwargs: {}'.format(kwargs)) + logger.info("delete\n kwargs: {}".format(kwargs)) if not yes: - print('Please use --yes for deleting {} configurations'.format(db.count(query=query))) + print( + "Please use --yes for deleting {} configurations".format( + db.count(query=query) + ) + ) exit(1) count = db.delete(query=query) - print('{} configuration has been deleted'.format(count)) + print("{} configuration has been deleted".format(count)) @init_config @@ -69,11 +78,11 @@ def upload(*, db, path, extra_infos, ignore_calc_results, **kwargs): db.upload(path, extra_infos, store_calc=calculator) elif path.is_dir(): - for file in path.glob('.xyz'): - logger.info('Uploaded file: {}'.format(file)) + for file in path.glob(".xyz"): + logger.info("Uploaded file: {}".format(file)) db.upload(file, extra_infos, store_calc=calculator) else: - logger.info('No file found: {}'.format(path)) + logger.info("No file found: {}".format(path)) raise FileNotFoundError() else: @@ -83,8 +92,8 @@ def upload(*, db, path, extra_infos, ignore_calc_results, **kwargs): @init_config @init_db def summary(*, db, query, print_all, bins, truncate, props, **kwargs): - logger.info('summary\n kwargs: {}'.format(kwargs)) - logger.info('query: {}'.format(query)) + logger.info("summary\n kwargs: {}".format(kwargs)) + logger.info("query: {}".format(query)) if print_all: truncate = None @@ -97,15 +106,15 @@ def summary(*, db, query, print_all, bins, truncate, props, **kwargs): props_list = [] for prop in props: # TODO: Check that is this the right place? - props_list.extend(re.split(r';\s*|,\s*|\s+', prop)) + props_list.extend(re.split(r";\s*|,\s*|\s+", prop)) - if '*' in props_list: - props_list = '*' + if "*" in props_list: + props_list = "*" - logging.info('property list: {}'.format(props_list)) + logging.info("property list: {}".format(props_list)) total = db.count(query) - print('Total number of configurations: {}'.format(total)) + print("Total number of configurations: {}".format(total)) if total == 0: return @@ -118,13 +127,13 @@ def summary(*, db, query, print_all, bins, truncate, props, **kwargs): labels, categories, dtypes, counts = [], [], [], [] for k in sorted(props, key=str.lower): labels.append(k) - counts.append(props[k]['count']) - categories.append(props[k]['category']) - dtypes.append(props[k]['dtype']) + counts.append(props[k]["count"]) + categories.append(props[k]["category"]) + dtypes.append(props[k]["dtype"]) f.hist_labels(counts, categories, dtypes, labels) - elif props_list == '*': + elif props_list == "*": props = db.properties(query=query) for ptype in props: @@ -149,8 +158,8 @@ def summary(*, db, query, print_all, bins, truncate, props, **kwargs): @init_config @init_db def show(*, db, query, print_all, props, **kwargs): - logger.info('show\n kwargs: {}'.format(kwargs)) - logger.info('query: {}'.format(query)) + logger.info("show\n kwargs: {}".format(kwargs)) + logger.info("query: {}".format(query)) if not props: print("Please define at least on property by using the -p option!") @@ -162,7 +171,7 @@ def show(*, db, query, print_all, props, **kwargs): for dct in islice(db.get_items(query), 0, limit): print(" | ".join(str(dct.get(prop, None)) for prop in props)) - logging.info('property list: {}'.format(props)) + logging.info("property list: {}".format(props)) @check_remote @@ -171,17 +180,19 @@ def show(*, db, query, print_all, props, **kwargs): def key_add(*, db, query, keys, **kwargs): from abcd.parsers.extras import parser - keys = ' '.join(keys) + keys = " ".join(keys) data = parser.parse(keys) if query: - test = ('AND', query, ("OR", *(('NAME', key) for key in data.keys()))) + test = ("AND", query, ("OR", *(("NAME", key) for key in data.keys()))) else: - test = ("OR", *(('NAME', key) for key in data.keys())) + test = ("OR", *(("NAME", key) for key in data.keys())) if db.count(query=test): - print('The new key already exist for the given query! ' - 'Please make sure that the target key name don\'t exist') + print( + "The new key already exist for the given query! " + "Please make sure that the target key name don't exist" + ) exit(1) db.add_property(data, query=query) @@ -192,8 +203,10 @@ def key_add(*, db, query, keys, **kwargs): @init_db def key_rename(*, db, query, old_keys, new_keys, **kwargs): if db.count(query=query + [old_keys, new_keys]): - print('The new key already exist for the given query! ' - 'Please make sure that the target key name don\'t exist') + print( + "The new key already exist for the given query! " + "Please make sure that the target key name don't exist" + ) exit(1) db.rename_property(old_keys, new_keys, query=query) @@ -205,14 +218,17 @@ def key_rename(*, db, query, old_keys, new_keys, **kwargs): def key_delete(*, db, query, yes, keys, **kwargs): from abcd.parsers.extras import parser - keys = ' '.join(keys) + keys = " ".join(keys) data = parser.parse(keys) - query = ('AND', query, ('OR', *(('NAME', key) for key in data.keys()))) + query = ("AND", query, ("OR", *(("NAME", key) for key in data.keys()))) if not yes: - print('Please use --yes for deleting keys from {} configurations'.format( - db.count(query=query))) + print( + "Please use --yes for deleting keys from {} configurations".format( + db.count(query=query) + ) + ) exit(1) for k in keys: @@ -224,8 +240,11 @@ def key_delete(*, db, query, yes, keys, **kwargs): @init_db def execute(*, db, query, yes, python_code, **kwargs): if not yes: - print('Please use --yes for executing code on {} configurations'.format( - db.count(query=query))) + print( + "Please use --yes for executing code on {} configurations".format( + db.count(query=query) + ) + ) exit(1) db.exec(python_code, query) @@ -236,7 +255,9 @@ def server(*, abcd_url, url, api_only, **kwargs): from urllib.parse import urlparse from abcd.server.app import create_app - logger.info("SERVER - abcd: {}, url: {}, api_only:{}".format(abcd_url, url, api_only)) + logger.info( + "SERVER - abcd: {}, url: {}, api_only:{}".format(abcd_url, url, api_only) + ) if api_only: print("Not implemented yet!") @@ -252,26 +273,37 @@ class Formater(object): partialBlocks = ["▏", "▎", "▍", "▌", "▋", "▊", "▉", "█"] # char=pb def title(self, title): - print('', title, '=' * len(title), sep=os.linesep) + print("", title, "=" * len(title), sep=os.linesep) def describe(self, data): - if data['type'] == 'hist_float': + if data["type"] == "hist_float": print( - '{} count: {} min: {:11.4e} med: {:11.4e} max: {:11.4e} std: {:11.4e} var:{' - ':11.4e}'.format( - data["name"], sum(data["counts"]), - data["min"], data["median"], data["max"], - data["std"], data["var"]) + "{} count: {} min: {:11.4e} med: {:11.4e} max: {:11.4e} std: {:11.4e} var:{" + ":11.4e}".format( + data["name"], + sum(data["counts"]), + data["min"], + data["median"], + data["max"], + data["std"], + data["var"], + ) ) - elif data['type'] == 'hist_int': - print('{} count: {} '.format(data["name"], sum(data["counts"])), - 'min: {:d} med: {:d} max: {:d} '.format(int(data["min"]), int(data["median"]), - int(data["max"])) - ) + elif data["type"] == "hist_int": + print( + "{} count: {} ".format(data["name"], sum(data["counts"])), + "min: {:d} med: {:d} max: {:d} ".format( + int(data["min"]), int(data["median"]), int(data["max"]) + ), + ) - elif data['type'] == 'hist_str': - print('{} count: {} unique: {}'.format(data["name"], data["total"], data["unique"])) + elif data["type"] == "hist_str": + print( + "{} count: {} unique: {}".format( + data["name"], data["total"], data["unique"] + ) + ) else: pass @@ -282,10 +314,11 @@ def hist_float(self, bin_edges, counts, width_hist=40): for count, lower, upper in zip(counts, bin_edges[:-1], bin_edges[1:]): scale = int(ratio * count) - self.print('{:<{}} {:>{}d} [{: >11.4e}, {: >11.4f})'.format( - "▉" * scale, width_hist, - count, width_count, - lower, upper)) + self.print( + "{:<{}} {:>{}d} [{: >11.4e}, {: >11.4f})".format( + "▉" * scale, width_hist, count, width_count, lower, upper + ) + ) def hist_int(self, bin_edges, counts, width_hist=40): @@ -294,35 +327,50 @@ def hist_int(self, bin_edges, counts, width_hist=40): for count, lower, upper in zip(counts, bin_edges[:-1], bin_edges[1:]): scale = int(ratio * count) - self.print('{:<{}} {:>{}d} [{:d}, {:d})'.format( - "▉" * scale, width_hist, - count, width_count, - np.ceil(lower).astype(int), np.floor(upper).astype(int))) + self.print( + "{:<{}} {:>{}d} [{:d}, {:d})".format( + "▉" * scale, + width_hist, + count, + width_count, + np.ceil(lower).astype(int), + np.floor(upper).astype(int), + ) + ) def hist_date(self, bin_edges, counts, width_hist=40): - dateformat = '%y-%m-%d %H:%M' + dateformat = "%y-%m-%d %H:%M" ratio = width_hist / max(counts) width_count = len(str(max(counts))) for count, lower, upper in zip(counts, bin_edges[:-1], bin_edges[1:]): scale = int(ratio * count) - self.print('{:<{}} {:>{}d} [{}, {})'.format( - "▉" * scale, width_hist, - count, width_count, - lower.strftime(dateformat), upper.strftime(dateformat))) + self.print( + "{:<{}} {:>{}d} [{}, {})".format( + "▉" * scale, + width_hist, + count, + width_count, + lower.strftime(dateformat), + upper.strftime(dateformat), + ) + ) def hist_str(self, total, counts, labels, width_hist=40): remain = total - sum(counts) if remain > 0: counts = (*counts, remain) - labels = (*labels, '...') + labels = (*labels, "...") width_count = len(str(max(counts))) ratio = width_hist / max(counts) for label, count in zip(labels, counts): scale = int(ratio * count) self.print( - '{:<{}} {:>{}d} {}'.format("▉" * scale, width_hist, count, width_count, label)) + "{:<{}} {:>{}d} {}".format( + "▉" * scale, width_hist, count, width_count, label + ) + ) def hist_labels(self, counts, categories, dtypes, labels, width_hist=40): @@ -330,18 +378,21 @@ def hist_labels(self, counts, categories, dtypes, labels, width_hist=40): ratio = width_hist / max(counts) for label, count, dtype in zip(labels, counts, dtypes): scale = int(ratio * count) - self.print('{:<{}} {:<21} {:>{}d} {}'.format( - "▉" * scale, width_hist, dtype, count, width_count, label)) + self.print( + "{:<{}} {:<21} {:>{}d} {}".format( + "▉" * scale, width_hist, dtype, count, width_count, label + ) + ) def hist(self, data: dict, width_hist=40): - if data['type'] == 'hist_float': - self.hist_float(data['edges'], data['counts']) - elif data['type'] == 'hist_int': - self.hist_int(data['edges'], data['counts']) - elif data['type'] == 'hist_date': - self.hist_date(data['edges'], data['counts']) - elif data['type'] == 'hist_str': - self.hist_str(data['total'], data['counts'], data['labels']) + if data["type"] == "hist_float": + self.hist_float(data["edges"], data["counts"]) + elif data["type"] == "hist_int": + self.hist_int(data["edges"], data["counts"]) + elif data["type"] == "hist_date": + self.hist_date(data["edges"], data["counts"]) + elif data["type"] == "hist_str": + self.hist_str(data["total"], data["counts"], data["labels"]) else: pass @@ -350,4 +401,4 @@ def print(*args, **kwargs): print(*args, **kwargs) def _trunc(self, text, width=80): - return text if len(text) < width else text[:width - 3] + '...' + return text if len(text) < width else text[: width - 3] + "..." diff --git a/abcd/frontends/commandline/config.py b/abcd/frontends/commandline/config.py index 96f1380c..3aa21bea 100644 --- a/abcd/frontends/commandline/config.py +++ b/abcd/frontends/commandline/config.py @@ -21,27 +21,33 @@ def from_json(cls, filename): @classmethod def load(cls): - if os.environ.get('ABCD_CONFIG') and Path(os.environ.get('ABCD_CONFIG')).is_file(): - file = Path(os.environ.get('ABCD_CONFIG')) - elif (Path.home() / '.abcd').is_file(): - file = Path.home() / '.abcd' + if ( + os.environ.get("ABCD_CONFIG") + and Path(os.environ.get("ABCD_CONFIG")).is_file() + ): + file = Path(os.environ.get("ABCD_CONFIG")) + elif (Path.home() / ".abcd").is_file(): + file = Path.home() / ".abcd" else: return cls() - logger.info('Using config file: {}'.format(file)) + logger.info("Using config file: {}".format(file)) config = cls.from_json(file) return config def save(self): - file = Path(os.environ.get('ABCD_CONFIG')) if os.environ.get('ABCD_CONFIG') \ - else Path.home() / '.abcd' + file = ( + Path(os.environ.get("ABCD_CONFIG")) + if os.environ.get("ABCD_CONFIG") + else Path.home() / ".abcd" + ) - logger.info('The saved config\'s file: {}'.format(file)) + logger.info("The saved config's file: {}".format(file)) - with open(str(file), 'w') as file: + with open(str(file), "w") as file: json.dump(self, file) def __repr__(self): - return '<{} {}>'.format(self.__class__.__name__, dict.__repr__(self)) + return "<{} {}>".format(self.__class__.__name__, dict.__repr__(self)) diff --git a/abcd/frontends/commandline/decorators.py b/abcd/frontends/commandline/decorators.py index d46e9ced..c2439be7 100644 --- a/abcd/frontends/commandline/decorators.py +++ b/abcd/frontends/commandline/decorators.py @@ -19,10 +19,10 @@ def wrapper(*args, **kwargs): def init_db(func): def wrapper(*args, config, **kwargs): - url = config.get('url', None) + url = config.get("url", None) if url is None: - print('Please use abcd login first!') + print("Please use abcd login first!") exit(1) db = ABCD.from_url(url=url) @@ -32,10 +32,10 @@ def wrapper(*args, config, **kwargs): # TODO: better ast optimisation query_list = [] - for q in kwargs.pop('default_query', []): + for q in kwargs.pop("default_query", []): query_list.append(parser(q)) - for q in kwargs.pop('query', []): + for q in kwargs.pop("query", []): query_list.append(parser(q)) if not query_list: @@ -43,7 +43,7 @@ def wrapper(*args, config, **kwargs): elif len(query_list) == 1: query = query_list[0] else: - query = ('AND', *query_list) + query = ("AND", *query_list) func(*args, db=db, query=query, **kwargs) @@ -52,8 +52,8 @@ def wrapper(*args, config, **kwargs): def check_remote(func): def wrapper(*args, **kwargs): - if kwargs.pop('remote'): - print('In read only mode, you can\'t modify the data in the database') + if kwargs.pop("remote"): + print("In read only mode, you can't modify the data in the database") exit(1) func(*args, **kwargs) diff --git a/abcd/frontends/commandline/parser.py b/abcd/frontends/commandline/parser.py index 6856bd9e..9b2c1af2 100644 --- a/abcd/frontends/commandline/parser.py +++ b/abcd/frontends/commandline/parser.py @@ -5,135 +5,234 @@ logger = logging.getLogger(__name__) -parser = ArgumentParser(description='Command line interface for ABCD database') -parser.add_argument('-v', '--verbose', help='Enable verbose mode', action='store_true') -parser.add_argument('-q', '--query', dest='default_query', action='append', help='Filtering extra quantities', - default=[]) - -parser.add_argument('--remote', help='Disables all the functions which would modify the database', - action='store_true') - -subparsers = parser.add_subparsers(title='Commands', dest='command', parser_class=ArgumentParser) - -login_parser = subparsers.add_parser('login', help='login to the database') +parser = ArgumentParser(description="Command line interface for ABCD database") +parser.add_argument("-v", "--verbose", help="Enable verbose mode", action="store_true") +parser.add_argument( + "-q", + "--query", + dest="default_query", + action="append", + help="Filtering extra quantities", + default=[], +) + +parser.add_argument( + "--remote", + help="Disables all the functions which would modify the database", + action="store_true", +) + +subparsers = parser.add_subparsers( + title="Commands", dest="command", parser_class=ArgumentParser +) + +login_parser = subparsers.add_parser("login", help="login to the database") login_parser.set_defaults(callback_func=commands.login) -login_parser.add_argument('-n', '--name', help='name of the database', default='default') -login_parser.add_argument(dest='url', - help='url of abcd api (default: http://localhost)', - default='http://localhost') - -download_parser = subparsers.add_parser('download', help='download data from the database') +login_parser.add_argument( + "-n", "--name", help="name of the database", default="default" +) +login_parser.add_argument( + dest="url", + help="url of abcd api (default: http://localhost)", + default="http://localhost", +) + +download_parser = subparsers.add_parser( + "download", help="download data from the database" +) download_parser.set_defaults(callback_func=commands.download) -download_parser.add_argument('-q', '--query', action='append', help='Filtering extra quantities', default=[]) -download_parser.add_argument('-f', '--format', help='Valid ASE file format (optional)', dest='fileformat', default='extxyz') -download_parser.add_argument(dest='filename', help='name of the file to store the configurations', nargs='?') - -upload_parser = subparsers.add_parser('upload', help='upload any ase supported files to the database') -upload_parser.add_argument('-e', '--extra_infos', action='append', help='Adding extra quantities') -upload_parser.add_argument('-i', '--ignore_calc_results', action='store_true', - help='Ignore calculators results/parameters') -upload_parser.add_argument('--upload-duplicates', action='store_true', - help='Upload but still report all of the duplicates') -upload_parser.add_argument('--upload-structure-duplicates', action='store_true', - help='Ignore all the "exact" duplicates but store all of the structural duplicates') -upload_parser.add_argument('--upload-duplicates-replace', action='store_true', - help='Upload everything and duplicates overwrite previously existing data') -upload_parser.add_argument(dest='path', help='Path to the file or folder.') +download_parser.add_argument( + "-q", "--query", action="append", help="Filtering extra quantities", default=[] +) +download_parser.add_argument( + "-f", + "--format", + help="Valid ASE file format (optional)", + dest="fileformat", + default="extxyz", +) +download_parser.add_argument( + dest="filename", help="name of the file to store the configurations", nargs="?" +) + +upload_parser = subparsers.add_parser( + "upload", help="upload any ase supported files to the database" +) +upload_parser.add_argument( + "-e", "--extra_infos", action="append", help="Adding extra quantities" +) +upload_parser.add_argument( + "-i", + "--ignore_calc_results", + action="store_true", + help="Ignore calculators results/parameters", +) +upload_parser.add_argument( + "--upload-duplicates", + action="store_true", + help="Upload but still report all of the duplicates", +) +upload_parser.add_argument( + "--upload-structure-duplicates", + action="store_true", + help='Ignore all the "exact" duplicates but store all of the structural duplicates', +) +upload_parser.add_argument( + "--upload-duplicates-replace", + action="store_true", + help="Upload everything and duplicates overwrite previously existing data", +) +upload_parser.add_argument(dest="path", help="Path to the file or folder.") upload_parser.set_defaults(callback_func=commands.upload) -summary_parser = subparsers.add_parser('summary', help='Discovery mode') +summary_parser = subparsers.add_parser("summary", help="Discovery mode") summary_parser.set_defaults(callback_func=commands.summary) -summary_parser.add_argument('-q', '--query', action='append', help='Filtering extra quantities', default=[]) -summary_parser.add_argument('-p', '--props', action='append', - help='Selecting properties for detailed description') -summary_parser.add_argument('-a', '--all', - help='Show everything without truncation of strings and limits of lines', - action='store_true', dest='print_all') -summary_parser.add_argument('-n', '--bins', help='The number of bins of the histogram', default=10, type=int) -summary_parser.add_argument('-t', '--trunc', - help='Length of string before truncation', - default=20, type=int, dest='truncate') - -show_parser = subparsers.add_parser('show', help='shows the first 10 items') +summary_parser.add_argument( + "-q", "--query", action="append", help="Filtering extra quantities", default=[] +) +summary_parser.add_argument( + "-p", + "--props", + action="append", + help="Selecting properties for detailed description", +) +summary_parser.add_argument( + "-a", + "--all", + help="Show everything without truncation of strings and limits of lines", + action="store_true", + dest="print_all", +) +summary_parser.add_argument( + "-n", "--bins", help="The number of bins of the histogram", default=10, type=int +) +summary_parser.add_argument( + "-t", + "--trunc", + help="Length of string before truncation", + default=20, + type=int, + dest="truncate", +) + +show_parser = subparsers.add_parser("show", help="shows the first 10 items") show_parser.set_defaults(callback_func=commands.show) -show_parser.add_argument('-q', '--query', action='append', help='Filtering extra quantities', default=[]) -show_parser.add_argument('-p', '--props', action='append', - help='Selecting properties for detailed description') -show_parser.add_argument('-a', '--all', - help='Show everything without truncation of strings and limits of lines', - action='store_true', dest='print_all') - -delete_parser = subparsers.add_parser('delete', help='Delete configurations from the database') +show_parser.add_argument( + "-q", "--query", action="append", help="Filtering extra quantities", default=[] +) +show_parser.add_argument( + "-p", + "--props", + action="append", + help="Selecting properties for detailed description", +) +show_parser.add_argument( + "-a", + "--all", + help="Show everything without truncation of strings and limits of lines", + action="store_true", + dest="print_all", +) + +delete_parser = subparsers.add_parser( + "delete", help="Delete configurations from the database" +) delete_parser.set_defaults(callback_func=commands.delete) -delete_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[]) -delete_parser.add_argument('-y', '--yes', action='store_true', help='Do the actual deletion.') - -key_add_parser = subparsers.add_parser('add-key', help='Adding new key value pairs for a given query') +delete_parser.add_argument( + "-q", "--query", action="append", help="Filtering by a query", default=[] +) +delete_parser.add_argument( + "-y", "--yes", action="store_true", help="Do the actual deletion." +) + +key_add_parser = subparsers.add_parser( + "add-key", help="Adding new key value pairs for a given query" +) key_add_parser.set_defaults(callback_func=commands.key_add) -key_add_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[]) -key_add_parser.add_argument('-y', '--yes', action='store_true', help='Overwrite?') -key_add_parser.add_argument('keys', help='keys(=value) pairs', nargs='+') - -key_rename_parser = subparsers.add_parser('rename-key', help='Rename a specific keys for a given query') +key_add_parser.add_argument( + "-q", "--query", action="append", help="Filtering by a query", default=[] +) +key_add_parser.add_argument("-y", "--yes", action="store_true", help="Overwrite?") +key_add_parser.add_argument("keys", help="keys(=value) pairs", nargs="+") + +key_rename_parser = subparsers.add_parser( + "rename-key", help="Rename a specific keys for a given query" +) key_rename_parser.set_defaults(callback_func=commands.key_rename) -key_rename_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[]) -key_rename_parser.add_argument('-y', '--yes', action='store_true', help='Overwrite?') -key_rename_parser.add_argument('old_keys', help='name of the old key') -key_rename_parser.add_argument('new_keys', help='new name of the key') - -key_delete_parser = subparsers.add_parser('delete-key', help='Delete all the keys for a given query') +key_rename_parser.add_argument( + "-q", "--query", action="append", help="Filtering by a query", default=[] +) +key_rename_parser.add_argument("-y", "--yes", action="store_true", help="Overwrite?") +key_rename_parser.add_argument("old_keys", help="name of the old key") +key_rename_parser.add_argument("new_keys", help="new name of the key") + +key_delete_parser = subparsers.add_parser( + "delete-key", help="Delete all the keys for a given query" +) key_delete_parser.set_defaults(callback_func=commands.key_delete) -key_delete_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[]) -key_delete_parser.add_argument('-y', '--yes', action='store_true', help='Do the actual deletion.') -key_delete_parser.add_argument('keys', help='keys(=value) data', nargs='+') - -exec_parser = subparsers.add_parser('exec', help='Running custom python code') +key_delete_parser.add_argument( + "-q", "--query", action="append", help="Filtering by a query", default=[] +) +key_delete_parser.add_argument( + "-y", "--yes", action="store_true", help="Do the actual deletion." +) +key_delete_parser.add_argument("keys", help="keys(=value) data", nargs="+") + +exec_parser = subparsers.add_parser("exec", help="Running custom python code") exec_parser.set_defaults(callback_func=commands.execute) -exec_parser.add_argument('-q', '--query', action='append', help='Filtering by a query', default=[]) -exec_parser.add_argument('-y', '--yes', action='store_true', help='Do the actual execution.') -exec_parser.add_argument('python_code', help='Selecting properties for detailed description') - -server = subparsers.add_parser('server', help='Running custom python code') +exec_parser.add_argument( + "-q", "--query", action="append", help="Filtering by a query", default=[] +) +exec_parser.add_argument( + "-y", "--yes", action="store_true", help="Do the actual execution." +) +exec_parser.add_argument( + "python_code", help="Selecting properties for detailed description" +) + +server = subparsers.add_parser("server", help="Running custom python code") server.set_defaults(callback_func=commands.server) -server.add_argument('abcd_url', help='Url for abcd database.') -server.add_argument('--api-only', action='store_true', help='Running only the API.') -server.add_argument('-u', '--url', help='Url to run the server.', default='http://localhost:5000') +server.add_argument("abcd_url", help="Url for abcd database.") +server.add_argument("--api-only", action="store_true", help="Running only the API.") +server.add_argument( + "-u", "--url", help="Url to run the server.", default="http://localhost:5000" +) def main(args=None): kwargs = parser.parse_args(args).__dict__ - if kwargs.pop('verbose'): + if kwargs.pop("verbose"): # Remove all handlers associated with the root logger object. # https://stackoverflow.com/questions/12158048/changing-loggings-basicconfig-which-is-already-set for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) logging.basicConfig(level=logging.INFO) - logger.info('Verbose mode is active') + logger.info("Verbose mode is active") - if not kwargs.pop('command'): + if not kwargs.pop("command"): print(parser.format_help()) return try: - callback_func = kwargs.pop('callback_func') + callback_func = kwargs.pop("callback_func") callback_func(**kwargs) except URLError: - print('Wrong connection: Please check the parameters of the url!') + print("Wrong connection: Please check the parameters of the url!") exit(1) except AuthenticationError: - print('Authentication failed: Please check the parameters of the connection!') + print("Authentication failed: Please check the parameters of the connection!") exit(1) except TimeoutError: print("Timeout: Please check the parameters of the connection!") exit(1) -if __name__ == '__main__': - main(['summary']) - main('delete-key -q pbc pbc'.split()) +if __name__ == "__main__": + main(["summary"]) + main("delete-key -q pbc pbc".split()) # main('upload -e cas -i ../../../tutorials/GB_alphaFe_001/tilt/00110391110_v6bxv2_tv0.4bxv0.2_d1.6z_traj.xyz'.split()) # main('summary -q formula~"Si2"'.split()) # main('upload -e cas -i ../../../tutorials/GB_alphaFe_001/tilt/00110391110_v6bxv2_tv0.4bxv0.2_d1.6z_traj.xyz'.split()) @@ -142,7 +241,7 @@ def main(args=None): # main('-v summary -p energy'.split()) # main('-v summary -p *'.split()) # main('add-key -q cas selected user="cas"'.split()) - main('delete-key user'.split()) + main("delete-key user".split()) # main(['summary', '-p', '*']) # main(['summary', '-p', 'info.config_name, info.energy']) # main(['summary', '-p', 'info.config_name, info.energy,info.energy;info.energy info.energy']) diff --git a/abcd/model.py b/abcd/model.py index c62d14ec..f4c87b61 100644 --- a/abcd/model.py +++ b/abcd/model.py @@ -18,13 +18,13 @@ def __init__(self, method=md5()): def update(self, value): if isinstance(value, int): - self.update(str(value).encode('ascii')) + self.update(str(value).encode("ascii")) elif isinstance(value, str): - self.update(value.encode('utf-8')) + self.update(value.encode("utf-8")) elif isinstance(value, float): - self.update('{:.8e}'.format(value).encode('ascii')) + self.update("{:.8e}".format(value).encode("ascii")) elif isinstance(value, (tuple, list)): for e in value: @@ -33,7 +33,7 @@ def update(self, value): elif isinstance(value, (dict, UserDict)): keys = value.keys() for k in sorted(keys): - self.update(k.encode('utf-8')) + self.update(k.encode("utf-8")) self.update(value[k]) elif isinstance(value, datetime.datetime): @@ -43,7 +43,9 @@ def update(self, value): self.method.update(value) else: - raise ValueError("The {} type cannot be hashed! (Value: {})", format(type(value), value)) + raise ValueError( + "The {} type cannot be hashed! (Value: {})", format(type(value), value) + ) def __call__(self): """Retrieve the digest of the hash.""" @@ -51,7 +53,14 @@ def __call__(self): class AbstractModel(UserDict): - reserved_keys = {'n_atoms', 'cell', 'pbc', 'calculator_name', 'calculator_parameters', 'derived'} + reserved_keys = { + "n_atoms", + "cell", + "pbc", + "calculator_name", + "calculator_parameters", + "derived", + } def __init__(self, dict=None, **kwargs): self.arrays_keys = [] @@ -64,22 +73,22 @@ def __init__(self, dict=None, **kwargs): @property def derived(self): return { - 'arrays_keys': self.arrays_keys, - 'info_keys': self.info_keys, - 'results_keys': self.results_keys, - 'derived_keys': self.derived_keys + "arrays_keys": self.arrays_keys, + "info_keys": self.info_keys, + "results_keys": self.results_keys, + "derived_keys": self.derived_keys, } def __getitem__(self, key): - if key == 'derived': + if key == "derived": return self.derived return super().__getitem__(key) def __setitem__(self, key, value): - if key == 'derived': + if key == "derived": # raise KeyError('Please do not use "derived" as key because it is protected!') # Silent return to avoid raising error in pymongo package return @@ -99,31 +108,31 @@ def convert(self, value): def update_key_category(self, key, value): - if key == '_id': + if key == "_id": # raise KeyError('Please do not use "derived" as key because it is protected!') return - for category in ('arrays_keys', 'info_keys', 'results_keys', 'derived_keys'): + for category in ("arrays_keys", "info_keys", "results_keys", "derived_keys"): if key in self.derived[category]: return - if key in ('positions', 'forces'): - self.derived['arrays_keys'].append(key) + if key in ("positions", "forces"): + self.derived["arrays_keys"].append(key) return - if key in ('n_atoms', 'cell', 'pbc'): - self.derived['info_keys'].append(key) + if key in ("n_atoms", "cell", "pbc"): + self.derived["info_keys"].append(key) return # Guess the category based in the shape of the value - n_atoms = self['n_atoms'] + n_atoms = self["n_atoms"] if isinstance(value, (np.ndarray, list)) and len(value) == n_atoms: - self.derived['arrays_keys'].append(key) + self.derived["arrays_keys"].append(key) else: - self.derived['info_keys'].append(key) + self.derived["info_keys"].append(key) def __delitem__(self, key): - for category in ('arrays_keys', 'info_keys', 'results_keys', 'derived_keys'): + for category in ("arrays_keys", "info_keys", "results_keys", "derived_keys"): if key in self.derived[category]: self.derived[category].remove(key) break @@ -133,34 +142,44 @@ def __delitem__(self, key): def __iter__(self): for item in super().__iter__(): yield item - yield 'derived' + yield "derived" @classmethod def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): """ASE's original implementation""" - reserved_keys = {'n_atoms', 'cell', 'pbc', 'calculator_name', 'calculator_parameters', 'derived', 'formula'} + reserved_keys = { + "n_atoms", + "cell", + "pbc", + "calculator_name", + "calculator_parameters", + "derived", + "formula", + } arrays_keys = set(atoms.arrays.keys()) info_keys = set(atoms.info.keys()) - results_keys = set(atoms.calc.results.keys()) if store_calc and atoms.calc else {} + results_keys = ( + set(atoms.calc.results.keys()) if store_calc and atoms.calc else {} + ) all_keys = (reserved_keys, arrays_keys, info_keys, results_keys) if len(set.union(*all_keys)) != sum(map(len, all_keys)): print(all_keys) - raise ValueError('All the keys must be unique!') + raise ValueError("All the keys must be unique!") item = cls() n_atoms = len(atoms) dct = { - 'n_atoms': n_atoms, - 'cell': atoms.cell.tolist(), - 'pbc': atoms.pbc.tolist(), - 'formula': atoms.get_chemical_formula() + "n_atoms": n_atoms, + "cell": atoms.cell.tolist(), + "pbc": atoms.pbc.tolist(), + "formula": atoms.get_chemical_formula(), } - info_keys.update({'n_atoms', 'cell', 'pbc', 'formula'}) + info_keys.update({"n_atoms", "cell", "pbc", "formula"}) for key, value in atoms.arrays.items(): if isinstance(value, np.ndarray): @@ -175,9 +194,9 @@ def from_atoms(cls, atoms: Atoms, extra_info=None, store_calc=True): dct[key] = value if store_calc and atoms.calc: - dct['calculator_name'] = atoms.calc.__class__.__name__ - dct['calculator_parameters'] = atoms.calc.todict() - info_keys.update({'calculator_name', 'calculator_parameters'}) + dct["calculator_name"] = atoms.calc.__class__.__name__ + dct["calculator_parameters"] = atoms.calc.todict() + info_keys.update({"calculator_name", "calculator_parameters"}) for key, value in atoms.calc.results.items(): @@ -205,26 +224,22 @@ def to_ase(self): arrays_keys = set(self.arrays_keys) info_keys = set(self.info_keys) - cell = self.pop('cell', None) - pbc = self.pop('pbc', None) - numbers = self.pop('numbers', None) - positions = self.pop('positions', None) - results_keys = self.derived['results_keys'] + cell = self.pop("cell", None) + pbc = self.pop("pbc", None) + numbers = self.pop("numbers", None) + positions = self.pop("positions", None) + results_keys = self.derived["results_keys"] - info_keys -= {'cell', 'pbc'} - arrays_keys -= {'numbers', 'positions'} + info_keys -= {"cell", "pbc"} + arrays_keys -= {"numbers", "positions"} - atoms = Atoms( - cell=cell, - pbc=pbc, - numbers=numbers, - positions=positions) + atoms = Atoms(cell=cell, pbc=pbc, numbers=numbers, positions=positions) - if 'calculator_name' in self: + if "calculator_name" in self: # calculator_name = self['info'].pop('calculator_name') # atoms.calc = get_calculator(data['results']['calculator_name'])(**params) - params = self.pop('calculator_parameters', {}) + params = self.pop("calculator_parameters", {}) atoms.calc = SinglePointCalculator(atoms, **params) atoms.calc.results.update((key, self[key]) for key in results_keys) @@ -235,38 +250,38 @@ def to_ase(self): return atoms def pre_save(self): - self.derived_keys = ['elements', 'username', 'uploaded', 'modified'] + self.derived_keys = ["elements", "username", "uploaded", "modified"] - cell = self['cell'] + cell = self["cell"] if cell: volume = abs(np.linalg.det(cell)) # atoms.get_volume() - self['volume'] = volume - self.derived_keys.append('volume') + self["volume"] = volume + self.derived_keys.append("volume") - virial = self.get('virial') + virial = self.get("virial") if virial: # pressure P = -1/3 Tr(stress) = -1/3 Tr(virials/volume) - self['pressure'] = -1 / 3 * np.trace(virial / volume) - self.derived_keys.append('pressure') + self["pressure"] = -1 / 3 * np.trace(virial / volume) + self.derived_keys.append("pressure") # 'elements': Counter(atoms.get_chemical_symbols()), - self['elements'] = Counter(str(element) for element in self['numbers']) + self["elements"] = Counter(str(element) for element in self["numbers"]) - self['username'] = getpass.getuser() + self["username"] = getpass.getuser() - if not self.get('uploaded'): - self['uploaded'] = datetime.datetime.utcnow() + if not self.get("uploaded"): + self["uploaded"] = datetime.datetime.utcnow() - self['modified'] = datetime.datetime.utcnow() + self["modified"] = datetime.datetime.utcnow() m = Hasher() - for key in ('numbers', 'positions', 'cell', 'pbc'): + for key in ("numbers", "positions", "cell", "pbc"): m.update(self[key]) - self.derived_keys.append('hash_structure') - self['hash_structure'] = m() + self.derived_keys.append("hash_structure") + self["hash_structure"] = m() m = Hasher() for key in self.arrays_keys: @@ -274,11 +289,11 @@ def pre_save(self): for key in self.info_keys: m.update(self[key]) - self.derived_keys.append('hash') - self['hash'] = m() + self.derived_keys.append("hash") + self["hash"] = m() -if __name__ == '__main__': +if __name__ == "__main__": import io from pprint import pprint from ase.io import read @@ -286,7 +301,7 @@ def pre_save(self): logging.basicConfig(level=logging.INFO) # from ase.io import jsonio - atoms = read('test.xyz', format='xyz', index=0) + atoms = read("test.xyz", format="xyz", index=0) atoms.set_cell([1, 1, 1]) print(atoms) diff --git a/abcd/parsers/extras.py b/abcd/parsers/extras.py index 939210bb..c007acf6 100644 --- a/abcd/parsers/extras.py +++ b/abcd/parsers/extras.py @@ -90,40 +90,43 @@ def string(self, s): # parser = Lark(grammar, parser='lalr', lexer='contextual', debug=False) -parser = Lark(grammar, parser='lalr', lexer='contextual', transformer=TreeToDict(), debug=False) - -if __name__ == '__main__': - test_string = ' '.join([ - ' ' # start with a separator - 'flag', - 'quotedd_string="quoteddd value"', - r'quotedddd_string_escaped="esc\"aped"', - 'false_value = F', - 'integer=22', - 'floating=1.1', - 'int_array={1 2 3}', - 'scientific_float=1.2e7', - 'scientific_float_2=5e-6', - 'scientific_float_array="1.2 2.2e3 4e1 3.3e-1 2e-2"', - 'not_array="1.2 3.4 text"', - 'array_nested=[[1,2],[3,4]] ' # gets flattented if not 3x3 - 'array_many_other_quotes=({[4 8 12]})', - 'array_boolean={T F T F}', - 'array_boolean_2=" T, F, T " ' # leading spaces - # 'not_bool_array=[T F S]', - # # read and write - # u'\xfcnicode_key=val\xfce', - # # u'unquoted_special_value=a_to_Z_$%%^&*\xfc\u2615', - # '2body=33.3', - 'hyphen-ated', - # # # parse only - 'comma_separated="7, 4, -1"', - 'array_bool_commas=[T, T, F, T]', - # # 'Properties=species:S:1:pos:R:3', - # 'double_equals=abc=xyz', - 'multiple_separators ', - 'trailing' - ]) +parser = Lark( + grammar, parser="lalr", lexer="contextual", transformer=TreeToDict(), debug=False +) + +if __name__ == "__main__": + test_string = " ".join( + [ + " " "flag", # start with a separator + 'quotedd_string="quoteddd value"', + r'quotedddd_string_escaped="esc\"aped"', + "false_value = F", + "integer=22", + "floating=1.1", + "int_array={1 2 3}", + "scientific_float=1.2e7", + "scientific_float_2=5e-6", + 'scientific_float_array="1.2 2.2e3 4e1 3.3e-1 2e-2"', + 'not_array="1.2 3.4 text"', + "array_nested=[[1,2],[3,4]] " # gets flattented if not 3x3 + "array_many_other_quotes=({[4 8 12]})", + "array_boolean={T F T F}", + 'array_boolean_2=" T, F, T " ' # leading spaces + # 'not_bool_array=[T F S]', + # # read and write + # u'\xfcnicode_key=val\xfce', + # # u'unquoted_special_value=a_to_Z_$%%^&*\xfc\u2615', + # '2body=33.3', + "hyphen-ated", + # # # parse only + 'comma_separated="7, 4, -1"', + "array_bool_commas=[T, T, F, T]", + # # 'Properties=species:S:1:pos:R:3', + # 'double_equals=abc=xyz', + "multiple_separators ", + "trailing", + ] + ) j = parser.parse(test_string) diff --git a/abcd/parsers/queries.py b/abcd/parsers/queries.py index 707fd4bc..41ef1c33 100644 --- a/abcd/parsers/queries.py +++ b/abcd/parsers/queries.py @@ -90,17 +90,17 @@ class TreeTransformer(Transformer): def array(self, *items): return list(items) - true = lambda _: ('VALUE', True) - false = lambda _: ('VALUE', False) + true = lambda _: ("VALUE", True) + false = lambda _: ("VALUE", False) def float(self, number): - return 'NUMBER', float(number) + return "NUMBER", float(number) def int(self, number): - return 'NUMBER', int(number) + return "NUMBER", int(number) def string(self, s): - return 'STRING', s[1:-1].replace('\\"', '"') + return "STRING", s[1:-1].replace('\\"', '"') def single_statement(self, expression): return expression @@ -109,24 +109,24 @@ def multi_statement(self, statement, operator, expression): return operator.type, statement, expression def single_expression(self, name): - return 'NAME', str(name) + return "NAME", str(name) def grouped_expression(self, statement): # return statement - return 'GROUP', statement + return "GROUP", statement def operator_expression(self, name, operator, value): - return operator.type, ('NAME', str(name)), value + return operator.type, ("NAME", str(name)), value def reversed_expression(self, value, operator, name): - return operator.type, ('NAME', str(name)), value + return operator.type, ("NAME", str(name)), value def negation_expression(self, operator, expression): return operator.type, expression class Parser: - parser = Lark(grammar, start='statement') + parser = Lark(grammar, start="statement") transformer = TreeTransformer() def parse(self, string): @@ -138,29 +138,29 @@ def __call__(self, string): parser = Parser() -if __name__ == '__main__': +if __name__ == "__main__": # logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.INFO) queries = ( - ' ', - 'single', - 'not single', - 'operator_gt > 23 ', - 'operator_gt > -2.31e-5 ', + " ", + "single", + "not single", + "operator_gt > 23 ", + "operator_gt > -2.31e-5 ", 'string = "some string"', 'regexp ~ ".*H"', - 'aa & not bb', - 'aa & bb > 23.54 | cc & dd', + "aa & not bb", + "aa & bb > 23.54 | cc & dd", # 'aa bb > 22 cc > 33 dd > 44 ', - 'aa and bb > 22 and cc > 33 and dd > 44 ', - '((aa and bb > 22) and cc > 33) and dd > 44 ', - '(aa and bb > 22) and (cc > 33 and dd > 44) ', - '(aa and bb > 22 and cc > 33 and dd > 44) ', - 'aa and bb > 23.54 or 22 in cc and dd', - 'aa & bb > 23.54 | (22 in cc & dd)', - 'aa and bb > 23.54 or (22 in cc and dd)', - 'aa and not (bb > 23.54 or (22 in cc and dd))', + "aa and bb > 22 and cc > 33 and dd > 44 ", + "((aa and bb > 22) and cc > 33) and dd > 44 ", + "(aa and bb > 22) and (cc > 33 and dd > 44) ", + "(aa and bb > 22 and cc > 33 and dd > 44) ", + "aa and bb > 23.54 or 22 in cc and dd", + "aa & bb > 23.54 | (22 in cc & dd)", + "aa and bb > 23.54 or (22 in cc and dd)", + "aa and not (bb > 23.54 or (22 in cc and dd))", # 'expression = (bb/3-1)*cc', # 'energy/n_atoms > 3', # '1=3', @@ -176,7 +176,7 @@ def __call__(self, string): # print(parser.parse(query).pretty()) try: tree = parser.parse(query) - logger.info('=> tree: {}'.format(tree)) - logger.info('==> ast: {}'.format(parser(query))) + logger.info("=> tree: {}".format(tree)) + logger.info("==> ast: {}".format(parser(query))) except LarkError: raise NotImplementedError diff --git a/abcd/parsers/queries_new.py b/abcd/parsers/queries_new.py index 8f176d0c..22004d11 100644 --- a/abcd/parsers/queries_new.py +++ b/abcd/parsers/queries_new.py @@ -8,21 +8,21 @@ # TODO: Reversed operator in the grammar (value op prop VS prop op value VS IN) -class DebugTransformer(Transformer): # pragma: no cover +class DebugTransformer(Transformer): # pragma: no cover def __init__(self): super().__init__() def __default__(self, data, children, meta): - print('Node: ', data, children) + print("Node: ", data, children) return data -if __name__ == '__main__': +if __name__ == "__main__": logging.basicConfig(level=logging.INFO) - with open('query_new.lark') as file: - parser = Lark(file.read(), start='expression') + with open("query_new.lark") as file: + parser = Lark(file.read(), start="expression") transformer = DebugTransformer() @@ -72,30 +72,30 @@ def __default__(self, data, children, meta): # ) queries = ( - ' ', - 'single', - 'not single', - 'operator_gt > 23 ', - 'operator_gt > -2.31e-5 ', + " ", + "single", + "not single", + "operator_gt > 23 ", + "operator_gt > -2.31e-5 ", 'string = "some string"', 'regexp ~ ".*H"', - 'aa & not bb', - 'aa & bb > 23.54 | cc & dd', - 'aa and bb > 22 and cc > 33 and dd > 44 ', - '((aa and bb > 22) and cc > 33) and dd > 44 ', - '(aa and bb > 22) and (cc > 33 and dd > 44) ', - '(aa and bb > 22 and cc > 33 and dd > 44) ', - 'aa and bb > 23.54 or 22 in cc and dd', - 'aa & bb > 23.54 | (22 in cc & dd)', - 'aa and bb > 23.54 or (22 in cc and dd)', - 'aa and not (bb > 23.54 or (22 in cc and dd))', - 'expression = (bb/3-1)*cc', - 'energy/n_atoms > 3', - '1=3', - 'all(aa) > 3', - 'any(aa) > 3', - 'aa = False', - 'aa = [True, True, True]', + "aa & not bb", + "aa & bb > 23.54 | cc & dd", + "aa and bb > 22 and cc > 33 and dd > 44 ", + "((aa and bb > 22) and cc > 33) and dd > 44 ", + "(aa and bb > 22) and (cc > 33 and dd > 44) ", + "(aa and bb > 22 and cc > 33 and dd > 44) ", + "aa and bb > 23.54 or 22 in cc and dd", + "aa & bb > 23.54 | (22 in cc & dd)", + "aa and bb > 23.54 or (22 in cc and dd)", + "aa and not (bb > 23.54 or (22 in cc and dd))", + "expression = (bb/3-1)*cc", + "energy/n_atoms > 3", + "1=3", + "all(aa) > 3", + "any(aa) > 3", + "aa = False", + "aa = [True, True, True]", ) for query in queries: diff --git a/abcd/server/__init__.py b/abcd/server/__init__.py index 86601b85..3b21e3bb 100644 --- a/abcd/server/__init__.py +++ b/abcd/server/__init__.py @@ -2,5 +2,5 @@ app = create_app() -if __name__ == '__main__': - app.run(host='0.0.0.0') +if __name__ == "__main__": + app.run(host="0.0.0.0") diff --git a/abcd/server/app/__init__.py b/abcd/server/app/__init__.py index bc4614f6..d7176a42 100644 --- a/abcd/server/app/__init__.py +++ b/abcd/server/app/__init__.py @@ -11,34 +11,34 @@ def create_app(abcd_url=None): # Define the WSGI application object app = Flask(__name__) - app.logger.info('Creating an application') + app.logger.info("Creating an application") - app.config['SECRET_KEY'] = os.getenv('SECRET_KEY', os.urandom(12).hex()) - app.config['ABCD_URL'] = os.getenv('ABCD_URL', 'mongodb://localhost:27017/abcd') + app.config["SECRET_KEY"] = os.getenv("SECRET_KEY", os.urandom(12).hex()) + app.config["ABCD_URL"] = os.getenv("ABCD_URL", "mongodb://localhost:27017/abcd") # Initialize extensions/add-ons/plugins. nav.init_app(app) - register_renderer(app, 'BootstrapRenderer', BootstrapRenderer) - register_renderer(app, 'DatabaseNav', DatabaseNav) + register_renderer(app, "BootstrapRenderer", BootstrapRenderer) + register_renderer(app, "DatabaseNav", DatabaseNav) db.init_app(app) # Setup redirects and register blueprints. app.register_blueprint(index.bp) - app.register_blueprint(database.bp, url_prefix='/db') - app.register_blueprint(api.bp, url_prefix='/api') + app.register_blueprint(database.bp, url_prefix="/db") + app.register_blueprint(api.bp, url_prefix="/api") @app.errorhandler(404) def not_found(error): - return render_template('404.html'), 404 + return render_template("404.html"), 404 return app -if __name__ == '__main__': +if __name__ == "__main__": import logging logging.basicConfig(level=logging.DEBUG) app = create_app() - app.run(host='0.0.0.0', debug=True) + app.run(host="0.0.0.0", debug=True) diff --git a/abcd/server/app/nav.py b/abcd/server/app/nav.py index 09af4385..ec8ea2d1 100644 --- a/abcd/server/app/nav.py +++ b/abcd/server/app/nav.py @@ -16,29 +16,29 @@ def __init__(self, title, *args, **kwargs): @nav.navigation() def main_navbar(): return TopNavbar( - 'ABCD', - View('Home', 'index.index'), - View('API', 'api.index'), - View('Databases', 'database.database', database_name='default'), - Link('Docs', 'https://libatoms.github.io/abcd/'), - Link('Github', 'https://github.com/libatoms/abcd'), + "ABCD", + View("Home", "index.index"), + View("API", "api.index"), + View("Databases", "database.database", database_name="default"), + Link("Docs", "https://libatoms.github.io/abcd/"), + Link("Github", "https://github.com/libatoms/abcd"), ) @nav.navigation() def database_navbar(): return Navbar( - '', - View('Database', 'database.database'), - Link('Collections', '#'), - Link('History', '#'), - Link('Statistics', '#'), - View('Settings', 'database.settings'), + "", + View("Database", "database.database"), + Link("Collections", "#"), + Link("History", "#"), + Link("Statistics", "#"), + View("Settings", "database.settings"), ) class DatabaseNav(Renderer): - def __init__(self, database_name='atoms'): + def __init__(self, database_name="atoms"): self.database_name = database_name def visit_Navbar(self, node): @@ -50,22 +50,24 @@ def visit_Navbar(self, node): return root def visit_Text(self, node): - return tags.li(tags.a(node.text, _class='nav-link disabled'), _class="nav-item") + return tags.li(tags.a(node.text, _class="nav-link disabled"), _class="nav-item") def visit_Link(self, node): - item = tags.li(_class='nav-item') - item.add(tags.a(node.text, href=node.get_url(), _class='nav-link')) + item = tags.li(_class="nav-item") + item.add(tags.a(node.text, href=node.get_url(), _class="nav-link")) return item def visit_View(self, node): # Dinamically modify the url - node.url_for_kwargs.update({'database_name': self.database_name}) + node.url_for_kwargs.update({"database_name": self.database_name}) - item = tags.li(_class='nav-item') - item.add(tags.a(node.text, href=node.get_url(), title=node.text, _class='nav-link')) + item = tags.li(_class="nav-item") + item.add( + tags.a(node.text, href=node.get_url(), title=node.text, _class="nav-link") + ) if node.active: - item['class'] = 'nav-item active' + item["class"] = "nav-item active" return item @@ -78,34 +80,41 @@ def __init__(self, nav_id=None): def visit_Navbar(self, node): node_id = self.id or sha1(str(id(node)).encode()).hexdigest() - root = tags.nav(_class='navbar navbar-expand-md navbar-dark bg-dark fixed-top') + root = tags.nav(_class="navbar navbar-expand-md navbar-dark bg-dark fixed-top") # title may also have a 'get_url()' method, in which case we render # a brand-link if node.title is not None: - if hasattr(node.title, 'get_url'): - root.add(tags.a(node.title.text, _class='navbar-brand', - href=node.title.get_url())) + if hasattr(node.title, "get_url"): + root.add( + tags.a( + node.title.text, + _class="navbar-brand", + href=node.title.get_url(), + ) + ) else: - root.add(tags.span(node.title, _class='navbar-brand')) + root.add(tags.span(node.title, _class="navbar-brand")) btn = root.add(tags.button()) - btn['type'] = 'button' - btn['class'] = 'navbar-toggler' - btn['data-toggle'] = 'collapse' - btn['data-target'] = '#' + node_id - btn['aria-controls'] = 'navbarCollapse' - btn['aria-expanded'] = 'false' - btn['aria-label'] = "Toggle navigation" + btn["type"] = "button" + btn["class"] = "navbar-toggler" + btn["data-toggle"] = "collapse" + btn["data-target"] = "#" + node_id + btn["aria-controls"] = "navbarCollapse" + btn["aria-expanded"] = "false" + btn["aria-label"] = "Toggle navigation" - btn.add(tags.span('', _class='navbar-toggler-icon')) + btn.add(tags.span("", _class="navbar-toggler-icon")) - bar = root.add(tags.div( - _class='navbar-collapse collapse', - id=node_id, - )) + bar = root.add( + tags.div( + _class="navbar-collapse collapse", + id=node_id, + ) + ) - bar_list = bar.add(tags.ul(_class='navbar-nav mr-auto')) + bar_list = bar.add(tags.ul(_class="navbar-nav mr-auto")) for item in node.items: bar_list.add(self.visit(item)) @@ -117,56 +126,58 @@ def visit_Navbar(self, node): # search_input['aria-label'] = "Search" search_btn = search_form.add(tags.button(_class="btn btn-success my-2 my-sm-0")) - search_btn['type'] = "submit" - search_btn.add_raw_string('+') + search_btn["type"] = "submit" + search_btn.add_raw_string("+") search_btn = search_form.add(tags.button(_class="btn btn-success my-2 my-sm-0")) - search_btn['type'] = "submit" - search_btn.add_raw_string('Login') + search_btn["type"] = "submit" + search_btn.add_raw_string("Login") return root def visit_Text(self, node): if self._in_dropdown: - return tags.a(node.text, _class='dropdown-item disabled') + return tags.a(node.text, _class="dropdown-item disabled") - return tags.li(tags.a(node.text, _class='nav-link disabled'), _class="nav-item") + return tags.li(tags.a(node.text, _class="nav-link disabled"), _class="nav-item") def visit_Link(self, node): if self._in_dropdown: - return tags.a(node.text, href=node.get_url(), _class='dropdown-item') + return tags.a(node.text, href=node.get_url(), _class="dropdown-item") - item = tags.li(_class='nav-item') - item.add(tags.a(node.text, href=node.get_url(), _class='nav-link')) + item = tags.li(_class="nav-item") + item.add(tags.a(node.text, href=node.get_url(), _class="nav-link")) return item def visit_View(self, node): if self._in_dropdown: - return tags.a(node.text, href=node.get_url(), _class='dropdown-item') + return tags.a(node.text, href=node.get_url(), _class="dropdown-item") - item = tags.li(_class='nav-item') - item.add(tags.a(node.text, href=node.get_url(), title=node.text, _class='nav-link')) + item = tags.li(_class="nav-item") + item.add( + tags.a(node.text, href=node.get_url(), title=node.text, _class="nav-link") + ) if node.active: - item['class'] = 'nav-item active' + item["class"] = "nav-item active" return item def visit_Subgroup(self, node): if self._in_dropdown: - raise RuntimeError('Cannot render nested Subgroups') + raise RuntimeError("Cannot render nested Subgroups") - li = tags.li(_class='nav-item dropdown') + li = tags.li(_class="nav-item dropdown") if node.active: - li['class'] = 'nav-item dropdown active' + li["class"] = "nav-item dropdown active" - a = li.add(tags.a(node.title, href='#', _class='nav-link dropdown-toggle')) - a['data-toggle'] = 'dropdown' - a['aria-haspopup'] = 'true' - a['aria-expanded'] = 'false' + a = li.add(tags.a(node.title, href="#", _class="nav-link dropdown-toggle")) + a["data-toggle"] = "dropdown" + a["aria-haspopup"] = "true" + a["aria-expanded"] = "false" - dropdown_div = li.add(tags.div(_class='dropdown-menu')) + dropdown_div = li.add(tags.div(_class="dropdown-menu")) self._in_dropdown = True for item in node.items: @@ -177,6 +188,6 @@ def visit_Subgroup(self, node): def visit_Separator(self, node): if self._in_dropdown: - return tags.div(_class='dropdown-divider') + return tags.div(_class="dropdown-divider") - raise RuntimeError('Cannot render separator outside Subgroup.') + raise RuntimeError("Cannot render separator outside Subgroup.") diff --git a/abcd/server/app/views/api.py b/abcd/server/app/views/api.py index 7a02d307..b3aa6d25 100644 --- a/abcd/server/app/views/api.py +++ b/abcd/server/app/views/api.py @@ -1,21 +1,20 @@ from flask import Blueprint, Response, make_response, jsonify, request -bp = Blueprint('api', __name__) +bp = Blueprint("api", __name__) -@bp.route('/') +@bp.route("/") def index(): - return Response('ok', status=200) + return Response("ok", status=200) + # endpoint to create new user @bp.route("/calculation", methods=["POST"]) def query_calculation(): - response = { - 'query': request.json, - 'results': [] - } + response = {"query": request.json, "results": []} return jsonify(response) + # # endpoint to show all users # @bp.route("/calculation", methods=["GET"]) # def get_calculation(): diff --git a/abcd/server/app/views/database.py b/abcd/server/app/views/database.py index 52dc0588..8569dcd2 100644 --- a/abcd/server/app/views/database.py +++ b/abcd/server/app/views/database.py @@ -1,31 +1,31 @@ from flask import Blueprint, render_template, request from flask import Response -bp = Blueprint('database', __name__) +bp = Blueprint("database", __name__) -@bp.route('/') +@bp.route("/") def index(): return Response(status=200) # Our index-page just shows a quick explanation. Check out the template # "templates/index.html" documentation for more details. -@bp.route('//', methods=['GET']) +@bp.route("//", methods=["GET"]) def database(database_name): # data = Atoms.objects() # list(Atoms.objects.aggregate({'$unwind': '$derived.arrays_keys'}, {'$group': {'_id': '$derived.arrays_keys', 'count': {'$sum': 1}}})) - if request.method == 'POST': - print('POST') + if request.method == "POST": + print("POST") info = { - 'name': database_name, - 'description': 'Vivamus sagittis lacus vel augue laoreet rutrum faucibus dolor auctor. Duis mollis, est non commodo luctus.', - 'columns': [ - {'slug': 'formula', 'name': 'Formula'}, - {'slug': 'energy', 'name': 'Energy'}, - {'slug': 'derived.n_atoms', 'name': "# of atoms"} + "name": database_name, + "description": "Vivamus sagittis lacus vel augue laoreet rutrum faucibus dolor auctor. Duis mollis, est non commodo luctus.", + "columns": [ + {"slug": "formula", "name": "Formula"}, + {"slug": "energy", "name": "Energy"}, + {"slug": "derived.n_atoms", "name": "# of atoms"}, ], } @@ -34,23 +34,23 @@ def database(database_name): # page = request.args.get('page', 1, type=int) # paginated_atoms = atoms.paginate(page, per_page=10) - return render_template("database/database.html", - info=info, - atoms=paginated_atoms) + return render_template("database/database.html", info=info, atoms=paginated_atoms) # Our index-page just shows a quick explanation. Check out the template # "templates/index.html" documentation for more details. -@bp.route('//settings') +@bp.route("//settings") def settings(database_name): info = { - 'name': database_name, - 'description': 'Vivamus sagittis lacus vel augue laoreet rutrum faucibus dolor auctor. Duis mollis, est non commodo luctus.', - 'columns': [ - {'slug': 'formula', 'name': 'Formula'}, - {'slug': 'energy', 'name': 'Energy'}, - {'slug': 'derived.n_atoms', 'name': "# of atoms"} + "name": database_name, + "description": "Vivamus sagittis lacus vel augue laoreet rutrum faucibus dolor auctor. Duis mollis, est non commodo luctus.", + "columns": [ + {"slug": "formula", "name": "Formula"}, + {"slug": "energy", "name": "Energy"}, + {"slug": "derived.n_atoms", "name": "# of atoms"}, ], - 'public': True + "public": True, } - return render_template("database/settings.html", database_name=database_name, info=info) + return render_template( + "database/settings.html", database_name=database_name, info=info + ) diff --git a/abcd/server/app/views/index.py b/abcd/server/app/views/index.py index b283d999..37faa52a 100644 --- a/abcd/server/app/views/index.py +++ b/abcd/server/app/views/index.py @@ -4,24 +4,24 @@ # from flask import jsonify # import requests -bp = Blueprint('index', __name__) +bp = Blueprint("index", __name__) -@bp.route('/') +@bp.route("/") def index(): return render_template("index.html") -@bp.route('/login/') +@bp.route("/login/") def login(): return render_template("login.html") -@bp.route('/new/') +@bp.route("/new/") def new(): return render_template("new.html") -@bp.route('/graphql') +@bp.route("/graphql") def graphql(): - return render_template("graphql.html", url=url_for('api.graphql')) + return render_template("graphql.html", url=url_for("api.graphql")) diff --git a/tests/test_database.py b/tests/test_database.py index 09340849..82cddfab 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -8,9 +8,9 @@ @pytest.fixture -@mongomock.patch(servers=(('localhost', 27017),)) +@mongomock.patch(servers=(("localhost", 27017),)) def abcd_mongodb(): - url = 'mongodb://localhost' + url = "mongodb://localhost" abcd = ABCD.from_url(url) abcd.print_info() @@ -22,13 +22,15 @@ def test_thing(abcd_mongodb): def test_push(abcd_mongodb): - xyz = StringIO("""2 + xyz = StringIO( + """2 Properties=species:S:1:pos:R:3 s="sadf" _vtk_test="t e s t _ s t r" pbc="F F F" -Si 0.00000000 0.00000000 0.00000000 -Si 0.00000000 0.00000000 0.00000000 -""") +Si 0.00000000 0.00000000 0.00000000 +Si 0.00000000 0.00000000 0.00000000 +""" + ) - atoms = read(xyz, format='extxyz') + atoms = read(xyz, format="extxyz") atoms.set_cell([1, 1, 1]) abcd_mongodb.destroy() diff --git a/tests/test_parsers.py b/tests/test_parsers.py index 02a6d98a..5a2211e0 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -4,7 +4,6 @@ class TestParsingExtras: - @pytest.fixture def parser(self): return extras_parser @@ -140,16 +139,15 @@ def test_composite(self, parser): 'a3x3_array="1 4 7 2 5 8 3 6 9" ' # fortran ordering 'Lattice=" 4.3 0.0 0.0 0.0 3.3 0.0 0.0 0.0 7.0 " ' # spaces in array 'comma_separated="7, 4, -1"', - 'array_boolean_2=" T, F, T " ' # leading spaces - 'not_array="1.2 3.4 text"', + 'array_boolean_2=" T, F, T " ' 'not_array="1.2 3.4 text"', # leading spaces "not_bool_array=[T F S]", ], ) - def test_missing(self, string): ... + def test_missing(self, string): + ... class TestParsingQueries: - @pytest.fixture def parser(self): return queries_parser @@ -190,7 +188,8 @@ def test_combination(self, parser, string, expected): ("any(aa) > 3", {}), ], ) - def test_expressions(self, case): ... + def test_expressions(self, case): + ... @pytest.mark.skip("known issues / future features") @pytest.mark.parametrize( @@ -203,4 +202,5 @@ def test_expressions(self, case): ... ("aa and (bb > 23.54 or (22 in cc and dd))", {}), ], ) - def test_expressions(self, case): ... + def test_expressions(self, case): + ... diff --git a/tutorials/gb_upload.py b/tutorials/gb_upload.py index 1b0680d8..3d276de2 100644 --- a/tutorials/gb_upload.py +++ b/tutorials/gb_upload.py @@ -1,23 +1,18 @@ import sys from pathlib import Path -sys.path.append('..') +sys.path.append("..") from abcd import ABCD from utils.ext_xyz import XYZReader -if __name__ == '__main__': +if __name__ == "__main__": - url = 'mongodb://localhost:27017' + url = "mongodb://localhost:27017" abcd = ABCD(url) - for file in Path('GB_alphaFe_001/tilt/').glob('*.xyz'): + for file in Path("GB_alphaFe_001/tilt/").glob("*.xyz"): print(file) - gb_params = { - 'name': 'alphaFe', - 'type': 'tilt', - 'params': file.name[:-4] - - } + gb_params = {"name": "alphaFe", "type": "tilt", "params": file.name[:-4]} with abcd as db, XYZReader(file) as reader: - db.push(reader.read_atoms(), extra_info={'GB_params': gb_params}) + db.push(reader.read_atoms(), extra_info={"GB_params": gb_params}) diff --git a/tutorials/scripts/Preprocess.py b/tutorials/scripts/Preprocess.py index 6b29c474..49a315d8 100644 --- a/tutorials/scripts/Preprocess.py +++ b/tutorials/scripts/Preprocess.py @@ -4,6 +4,7 @@ from ase.io import read, write from ase.geometry import crystal_structure_from_cell import numpy as np + # import numpy.linalg as la import matplotlib.pyplot as plt @@ -12,7 +13,7 @@ class Calculation(object): def __init__(self, *args, **kwargs): - self.filepath = kwargs.pop('filepath', None) + self.filepath = kwargs.pop("filepath", None) self.parameters = kwargs def get_data(self, index=-1): @@ -20,20 +21,19 @@ def get_data(self, index=-1): @classmethod def from_path(cls, path: Path): - with (path / 'gb.json').open() as data_file: + with (path / "gb.json").open() as data_file: gb_data = json.load(data_file) - with (path / 'subgb.json').open() as data_file: + with (path / "subgb.json").open() as data_file: subgb_data = json.load(data_file) # print(gb_data['angle']) - filename = subgb_data['name'] + "_traj.xyz" + filename = subgb_data["name"] + "_traj.xyz" filepath = (path / filename).resolve() # configuration = read(str((path / filename).resolve()), index=index) # # gb = read(str((path / filename).resolve()), index=-1) - # print('{:=^60}'.format(' '+str(path)+' ')) # # print('{:-^40}'.format(' gb.json ')) @@ -60,27 +60,34 @@ def from_path(cls, path: Path): # print('Force mean: {:f}, std: {:f}'.format(force_mean, force_std)) # pprint(gb_final.calc.results) - return cls(**{**gb_data, **subgb_data}, filepath=filepath) -if __name__ == '__main__': +if __name__ == "__main__": # Read grain boundary database - dirpath = Path('../GB_alphaFe_001') + dirpath = Path("../GB_alphaFe_001") calculations = { - 'tilt': [Calculation.from_path(calc_dir) for calc_dir in (dirpath / 'tilt').iterdir() if calc_dir.is_dir()], - 'twist': [Calculation.from_path(calc_dir) for calc_dir in (dirpath / 'twist').iterdir() if calc_dir.is_dir()] + "tilt": [ + Calculation.from_path(calc_dir) + for calc_dir in (dirpath / "tilt").iterdir() + if calc_dir.is_dir() + ], + "twist": [ + Calculation.from_path(calc_dir) + for calc_dir in (dirpath / "twist").iterdir() + if calc_dir.is_dir() + ], } # potential energy of the perfect crystal according to a specific potential - potential_energy_per_atom = -4.01298214176 # alpha-Fe PotBH + potential_energy_per_atom = -4.01298214176 # alpha-Fe PotBH eV = 1.6021766208e-19 - Angstrom = 1.e-10 + Angstrom = 1.0e-10 angles, energies = [], [] - for calc in sorted(calculations['tilt'], key=lambda item: item.parameters['angle']): + for calc in sorted(calculations["tilt"], key=lambda item: item.parameters["angle"]): # E_gb = calc.parameters.get('E_gb', None) # @@ -92,7 +99,7 @@ def from_path(cls, path: Path): # energy = 16.02 / (2 * calc.parameters['A'] ) * \ # (E_gb - potential_energy_per_atom * calc.parameters['n_at']) - if calc.parameters.get('converged', None): + if calc.parameters.get("converged", None): # energy = 16.02 / (2 * calc.parameters['A'] ) * \ # (calc.parameters.get('E_gb') - potential_energy_per_atom * calc.parameters['n_at']) # @@ -101,24 +108,26 @@ def from_path(cls, path: Path): A = cell[0, 0] * cell[1, 1] energy = ( - eV / Angstrom**2 / - (2 * A) * - (atoms.get_potential_energy() - potential_energy_per_atom * len(atoms)) + eV + / Angstrom**2 + / (2 * A) + * ( + atoms.get_potential_energy() + - potential_energy_per_atom * len(atoms) + ) ) write(calc.filepath.name, atoms) - # print(energy) # print(calc.parameters['converged']) # print(data.get_potential_energy()) # data.get_total_energy() == data.get_potential_energy() # energies.append(calc.parameters['E_gb'] - data.get_total_energy()) energies.append(energy) - angles.append(calc.parameters['angle'] * 180.0 / np.pi) + angles.append(calc.parameters["angle"] * 180.0 / np.pi) else: print("not converged: ", calc.filepath) - plt.bar(angles, energies) # x_smooth = np.linspace(min(angles), max(angles), 1000, endpoint=True) @@ -126,5 +135,3 @@ def from_path(cls, path: Path): # plt.plot(x_smooth, f(x_smooth), '-') plt.show() - - diff --git a/tutorials/scripts/Reader.py b/tutorials/scripts/Reader.py index 3bd0ff3f..83d6be52 100644 --- a/tutorials/scripts/Reader.py +++ b/tutorials/scripts/Reader.py @@ -4,6 +4,7 @@ from ase.io import read, write from ase.geometry import crystal_structure_from_cell import numpy as np + # import numpy.linalg as la import matplotlib.pyplot as plt @@ -12,7 +13,7 @@ class Calculation(object): def __init__(self, *args, **kwargs): - self.filepath = kwargs.pop('filepath', None) + self.filepath = kwargs.pop("filepath", None) self.parameters = kwargs def get_data(self, index=-1): @@ -20,20 +21,19 @@ def get_data(self, index=-1): @classmethod def from_path(cls, path: Path, index=-1): - with (path / 'gb.json').open() as data_file: + with (path / "gb.json").open() as data_file: gb_data = json.load(data_file) - with (path / 'subgb.json').open() as data_file: + with (path / "subgb.json").open() as data_file: subgb_data = json.load(data_file) # print(gb_data['angle']) - filename = subgb_data['name'] + "_traj.xyz" + filename = subgb_data["name"] + "_traj.xyz" filepath = (path / filename).resolve() # configuration = read(str((path / filename).resolve()), index=index) # # gb = read(str((path / filename).resolve()), index=-1) - # print('{:=^60}'.format(' '+str(path)+' ')) # # print('{:-^40}'.format(' gb.json ')) @@ -60,41 +60,54 @@ def from_path(cls, path: Path, index=-1): # print('Force mean: {:f}, std: {:f}'.format(force_mean, force_std)) # pprint(gb_final.calc.results) - return cls(**{**gb_data, **subgb_data}, filepath=filepath) -if __name__ == '__main__': +if __name__ == "__main__": # Read grain boundary database - dirpath = Path('../GB_alphaFe_001') + dirpath = Path("../GB_alphaFe_001") calculations = { - 'tilt': [Calculation.from_path(calc_dir) for calc_dir in (dirpath / 'tilt').iterdir() if calc_dir.is_dir()], - 'twist': [Calculation.from_path(calc_dir) for calc_dir in (dirpath / 'twist').iterdir() if calc_dir.is_dir()] + "tilt": [ + Calculation.from_path(calc_dir) + for calc_dir in (dirpath / "tilt").iterdir() + if calc_dir.is_dir() + ], + "twist": [ + Calculation.from_path(calc_dir) + for calc_dir in (dirpath / "twist").iterdir() + if calc_dir.is_dir() + ], } # potential energy of the perfect crystal according to a specific potential - potential_energy_per_atom = -4.01298214176 # alpha-Fe PotBH + potential_energy_per_atom = -4.01298214176 # alpha-Fe PotBH eV = 1.6021766208e-19 - Angstrom = 1.e-10 + Angstrom = 1.0e-10 angles, energies = [], [] - for calc in sorted(calculations['tilt'], key=lambda item: item.parameters['angle']): - angles.append(calc.parameters['angle'] * 180.0 / np.pi) - + for calc in sorted(calculations["tilt"], key=lambda item: item.parameters["angle"]): + angles.append(calc.parameters["angle"] * 180.0 / np.pi) - energy = 16.02 / (2 * calc.parameters['A'] ) * \ - (calc.parameters['E_gb'] - potential_energy_per_atom * calc.parameters['n_at']) + energy = ( + 16.02 + / (2 * calc.parameters["A"]) + * ( + calc.parameters["E_gb"] + - potential_energy_per_atom * calc.parameters["n_at"] + ) + ) atoms = calc.get_data() cell = atoms.get_cell() A = cell[0, 0] * cell[1, 1] energy = ( - eV / Angstrom**2 / - (2 * A) * - (atoms.get_total_energy() - potential_energy_per_atom * len(atoms)) + eV + / Angstrom**2 + / (2 * A) + * (atoms.get_total_energy() - potential_energy_per_atom * len(atoms)) ) print(energy) @@ -102,7 +115,6 @@ def from_path(cls, path: Path, index=-1): # energies.append(calc.parameters['E_gb'] - data.get_total_energy()) energies.append(energy) - plt.bar(angles, energies) # x_smooth = np.linspace(min(angles), max(angles), 1000, endpoint=True) @@ -186,7 +198,7 @@ def from_path(cls, path: Path, index=-1): # print at.get_potential_energy() # print E_gb, 'eV/A^2' # E_gb = 16.02*(at.get_potential_energy()-(at.n*(E_bulk)))/(2.*A) - # print E_gb, 'J/m^2' +# print E_gb, 'J/m^2' # return E_gb # # relax.py @@ -194,4 +206,4 @@ def from_path(cls, path: Path, index=-1): # E_gb = grain.get_potential_energy() # # -# ener_per_atom = -4.01298214176 \ No newline at end of file +# ener_per_atom = -4.01298214176 diff --git a/tutorials/scripts/Visualise.py b/tutorials/scripts/Visualise.py index 1ed65c2d..308e76fb 100644 --- a/tutorials/scripts/Visualise.py +++ b/tutorials/scripts/Visualise.py @@ -7,13 +7,12 @@ import matplotlib.pyplot as plt import numpy as np - -@register_backend('ase') +@register_backend("ase") class MyASEStructure(Structure): def __init__(self, atoms, bfactor=[], occupancy=[]): # super(MyASEStructure, self).__init__() - self.ext = 'pdb' + self.ext = "pdb" self.params = {} self._atoms = atoms self.bfactor = bfactor # [min, max] @@ -21,7 +20,7 @@ def __init__(self, atoms, bfactor=[], occupancy=[]): self.id = str(uuid.uuid4()) def get_structure_string(self): - """ PDB file format: + """PDB file format: CRYST1 16.980 62.517 124.864 90.00 90.00 90.00 P 1 MODEL 1 ATOM 0 Fe MOL 1 15.431 60.277 6.801 1.00 0.00 FE @@ -36,12 +35,12 @@ def get_structure_string(self): if self._atoms.get_pbc().any(): cellpar = self._atoms.get_cell_lengths_and_angles() - str_format = 'CRYST1' + '{:9.3f}' * 3 + '{:7.2f}' * 3 + ' P 1\n' + str_format = "CRYST1" + "{:9.3f}" * 3 + "{:7.2f}" * 3 + " P 1\n" data += str_format.format(*cellpar.tolist()) - data += 'MODEL 1\n' + data += "MODEL 1\n" - str_format = 'ATOM {:5d} {:>4s} MOL 1 {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} {:2s}\n' + str_format = "ATOM {:5d} {:>4s} MOL 1 {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} {:2s}\n" for index, atom in enumerate(self._atoms): data += str_format.format( index, @@ -51,10 +50,10 @@ def get_structure_string(self): atom.position[2].tolist(), self.occupancy[index] if index <= len(self.occupancy) - 1 else 1.0, self.bfactor[index] if index <= len(self.bfactor) - 1 else 1.0, - atom.symbol.upper() + atom.symbol.upper(), ) - data += 'ENDMDL\n' + data += "ENDMDL\n" return data @@ -63,12 +62,11 @@ def ViewStructure(atoms): import nglview view = nglview.NGLWidget() - + structure = MyASEStructure(atoms) view.add_structure(structure) - - return view + return view class AtomViewer(object): @@ -76,38 +74,35 @@ def __init__(self, atoms, data=[], xsize=1000, ysize=500): self.view = self._init_nglview(atoms, data, xsize, ysize) self.widgets = { - 'radius': FloatSlider( - value=0.8, min=0.0, max=1.5, step=0.01, - description='Ball size' + "radius": FloatSlider( + value=0.8, min=0.0, max=1.5, step=0.01, description="Ball size" ), - 'color_scheme': Dropdown(description='Solor scheme:'), - 'colorbar': Output() + "color_scheme": Dropdown(description="Solor scheme:"), + "colorbar": Output(), } self.show_colorbar(data) - self.widgets['radius'].observe(self._update_repr) + self.widgets["radius"].observe(self._update_repr) - self.gui = VBox([ - self.view, - self.widgets['colorbar'], - self.widgets['radius']]) + self.gui = VBox([self.view, self.widgets["colorbar"], self.widgets["radius"]]) def _update_repr(self, chg=None): self.view.update_spacefill( - radiusType='radius', - radius=self.widgets['radius'].value + radiusType="radius", radius=self.widgets["radius"].value ) def show_colorbar(self, data): - with self.widgets['colorbar']: + with self.widgets["colorbar"]: # Have colormaps separated into categories: # http://matplotlib.org/examples/color/colormaps_reference.html - cmap = 'rainbow' + cmap = "rainbow" fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 2)) - img = ax1.imshow([[min(data), max(data)]], aspect='auto', cmap=plt.get_cmap(cmap)) + img = ax1.imshow( + [[min(data), max(data)]], aspect="auto", cmap=plt.get_cmap(cmap) + ) ax1.remove() - cbar = fig.colorbar(img, cax=ax2, orientation='horizontal') + cbar = fig.colorbar(img, cax=ax2, orientation="horizontal") plt.show() @@ -117,15 +112,12 @@ def _init_nglview(atoms, data, xsize, ysize): view = nglview.NGLWidget(gui=False) view._remote_call( - 'setSize', - target='Widget', - args=[ - '{:d}px'.format(xsize), - '{:d}px'.format(ysize) - ] + "setSize", + target="Widget", + args=["{:d}px".format(xsize), "{:d}px".format(ysize)], ) - data = np.max(data)-data + data = np.max(data) - data structure = MyASEStructure(atoms, bfactor=data) view.add_structure(structure) @@ -136,19 +128,15 @@ def _init_nglview(atoms, data, xsize, ysize): view.add_spacefill( # radiusType='radius', # radius=1.0, - color_scheme='bfactor', - color_scale='rainbow' - ) - view.update_spacefill( - radiusType='radius', - radius=1.0 + color_scheme="bfactor", + color_scale="rainbow", ) - + view.update_spacefill(radiusType="radius", radius=1.0) # update camera type view.control.spin([1, 0, 0], np.pi / 2) view.control.spin([0, 0, 1], np.pi / 2) - view.camera = 'orthographic' + view.camera = "orthographic" view.center() return view diff --git a/tutorials/scripts/Visualise_quip.py b/tutorials/scripts/Visualise_quip.py index 1ed65c2d..308e76fb 100644 --- a/tutorials/scripts/Visualise_quip.py +++ b/tutorials/scripts/Visualise_quip.py @@ -7,13 +7,12 @@ import matplotlib.pyplot as plt import numpy as np - -@register_backend('ase') +@register_backend("ase") class MyASEStructure(Structure): def __init__(self, atoms, bfactor=[], occupancy=[]): # super(MyASEStructure, self).__init__() - self.ext = 'pdb' + self.ext = "pdb" self.params = {} self._atoms = atoms self.bfactor = bfactor # [min, max] @@ -21,7 +20,7 @@ def __init__(self, atoms, bfactor=[], occupancy=[]): self.id = str(uuid.uuid4()) def get_structure_string(self): - """ PDB file format: + """PDB file format: CRYST1 16.980 62.517 124.864 90.00 90.00 90.00 P 1 MODEL 1 ATOM 0 Fe MOL 1 15.431 60.277 6.801 1.00 0.00 FE @@ -36,12 +35,12 @@ def get_structure_string(self): if self._atoms.get_pbc().any(): cellpar = self._atoms.get_cell_lengths_and_angles() - str_format = 'CRYST1' + '{:9.3f}' * 3 + '{:7.2f}' * 3 + ' P 1\n' + str_format = "CRYST1" + "{:9.3f}" * 3 + "{:7.2f}" * 3 + " P 1\n" data += str_format.format(*cellpar.tolist()) - data += 'MODEL 1\n' + data += "MODEL 1\n" - str_format = 'ATOM {:5d} {:>4s} MOL 1 {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} {:2s}\n' + str_format = "ATOM {:5d} {:>4s} MOL 1 {:8.3f}{:8.3f}{:8.3f}{:6.2f}{:6.2f} {:2s}\n" for index, atom in enumerate(self._atoms): data += str_format.format( index, @@ -51,10 +50,10 @@ def get_structure_string(self): atom.position[2].tolist(), self.occupancy[index] if index <= len(self.occupancy) - 1 else 1.0, self.bfactor[index] if index <= len(self.bfactor) - 1 else 1.0, - atom.symbol.upper() + atom.symbol.upper(), ) - data += 'ENDMDL\n' + data += "ENDMDL\n" return data @@ -63,12 +62,11 @@ def ViewStructure(atoms): import nglview view = nglview.NGLWidget() - + structure = MyASEStructure(atoms) view.add_structure(structure) - - return view + return view class AtomViewer(object): @@ -76,38 +74,35 @@ def __init__(self, atoms, data=[], xsize=1000, ysize=500): self.view = self._init_nglview(atoms, data, xsize, ysize) self.widgets = { - 'radius': FloatSlider( - value=0.8, min=0.0, max=1.5, step=0.01, - description='Ball size' + "radius": FloatSlider( + value=0.8, min=0.0, max=1.5, step=0.01, description="Ball size" ), - 'color_scheme': Dropdown(description='Solor scheme:'), - 'colorbar': Output() + "color_scheme": Dropdown(description="Solor scheme:"), + "colorbar": Output(), } self.show_colorbar(data) - self.widgets['radius'].observe(self._update_repr) + self.widgets["radius"].observe(self._update_repr) - self.gui = VBox([ - self.view, - self.widgets['colorbar'], - self.widgets['radius']]) + self.gui = VBox([self.view, self.widgets["colorbar"], self.widgets["radius"]]) def _update_repr(self, chg=None): self.view.update_spacefill( - radiusType='radius', - radius=self.widgets['radius'].value + radiusType="radius", radius=self.widgets["radius"].value ) def show_colorbar(self, data): - with self.widgets['colorbar']: + with self.widgets["colorbar"]: # Have colormaps separated into categories: # http://matplotlib.org/examples/color/colormaps_reference.html - cmap = 'rainbow' + cmap = "rainbow" fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 2)) - img = ax1.imshow([[min(data), max(data)]], aspect='auto', cmap=plt.get_cmap(cmap)) + img = ax1.imshow( + [[min(data), max(data)]], aspect="auto", cmap=plt.get_cmap(cmap) + ) ax1.remove() - cbar = fig.colorbar(img, cax=ax2, orientation='horizontal') + cbar = fig.colorbar(img, cax=ax2, orientation="horizontal") plt.show() @@ -117,15 +112,12 @@ def _init_nglview(atoms, data, xsize, ysize): view = nglview.NGLWidget(gui=False) view._remote_call( - 'setSize', - target='Widget', - args=[ - '{:d}px'.format(xsize), - '{:d}px'.format(ysize) - ] + "setSize", + target="Widget", + args=["{:d}px".format(xsize), "{:d}px".format(ysize)], ) - data = np.max(data)-data + data = np.max(data) - data structure = MyASEStructure(atoms, bfactor=data) view.add_structure(structure) @@ -136,19 +128,15 @@ def _init_nglview(atoms, data, xsize, ysize): view.add_spacefill( # radiusType='radius', # radius=1.0, - color_scheme='bfactor', - color_scale='rainbow' - ) - view.update_spacefill( - radiusType='radius', - radius=1.0 + color_scheme="bfactor", + color_scale="rainbow", ) - + view.update_spacefill(radiusType="radius", radius=1.0) # update camera type view.control.spin([1, 0, 0], np.pi / 2) view.control.spin([0, 0, 1], np.pi / 2) - view.camera = 'orthographic' + view.camera = "orthographic" view.center() return view diff --git a/tutorials/test_db.py b/tutorials/test_db.py index 8f1fb542..8a8f34f8 100644 --- a/tutorials/test_db.py +++ b/tutorials/test_db.py @@ -3,7 +3,7 @@ from abcd import ABCD -if __name__ == '__main__': +if __name__ == "__main__": # http requests # url = 'http://localhost:5000/api' @@ -12,9 +12,9 @@ # Mongoengine # https://stackoverflow.com/questions/36200288/mongolab-pymongo-connection-error - url = 'mongodb://root:example@localhost:27018/?authSource=admin' + url = "mongodb://root:example@localhost:27018/?authSource=admin" - abcd = ABCD(url, db='abcd', collection='default') + abcd = ABCD(url, db="abcd", collection="default") print(abcd) abcd.print_info() @@ -24,8 +24,8 @@ abcd.destroy() - direcotry = Path('../tutorials/data/') - file = direcotry / 'bcc_bulk_54_expanded_2_high.xyz' + direcotry = Path("../tutorials/data/") + file = direcotry / "bcc_bulk_54_expanded_2_high.xyz" # file = direcotry / 'GAP_6.xyz' traj = read(file.as_posix(), index=slice(None)) diff --git a/tutorials/test_upload.py b/tutorials/test_upload.py index 0f3a34bd..c5fe8e91 100644 --- a/tutorials/test_upload.py +++ b/tutorials/test_upload.py @@ -3,7 +3,7 @@ from abcd import ABCD -if __name__ == '__main__': - abcd = ABCD(url='mongodb://localhost:27017') +if __name__ == "__main__": + abcd = ABCD(url="mongodb://localhost:27017") print(abcd) abcd.print_info()