Skip to content

Commit aa9bdbf

Browse files
authored
Migrate ext.COD from mysql to REST API (#4117)
* tweak comments * reduce default timeout as 10 minutes is unrealistic * remove mysql test in test * finish rewrite * use tighter timeout in test * capture timeout errors in ci * make the timeout skip a wrapper * use a conditional timeout * better deprecation handle without breaking * use a formula with only one match
1 parent a28e1da commit aa9bdbf

File tree

2 files changed

+106
-74
lines changed

2 files changed

+106
-74
lines changed

src/pymatgen/ext/cod.py

Lines changed: 69 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -27,97 +27,109 @@
2727

2828
from __future__ import annotations
2929

30-
import re
31-
import subprocess
3230
import warnings
33-
from shutil import which
31+
from typing import TYPE_CHECKING
3432

3533
import requests
36-
from monty.dev import requires
3734

3835
from pymatgen.core.composition import Composition
3936
from pymatgen.core.structure import Structure
4037

38+
if TYPE_CHECKING:
39+
from typing import Literal
40+
4141

4242
class COD:
43-
"""An interface to the Crystallography Open Database."""
43+
"""An interface to the Crystallography Open Database.
4444
45-
url = "www.crystallography.net"
45+
Reference:
46+
https://wiki.crystallography.net/RESTful_API/
47+
"""
4648

47-
def query(self, sql: str) -> str:
48-
"""Perform a query.
49+
def __init__(self, timeout: int = 60):
50+
"""Initialize the COD class.
4951
5052
Args:
51-
sql: SQL string
52-
53-
Returns:
54-
Response from SQL query.
53+
timeout (int): request timeout in seconds.
5554
"""
56-
response = subprocess.check_output(["mysql", "-u", "cod_reader", "-h", self.url, "-e", sql, "cod"])
57-
return response.decode("utf-8")
55+
self.timeout = timeout
56+
self.url = "https://www.crystallography.net"
57+
self.api_url = f"{self.url}/cod/result"
5858

59-
@requires(which("mysql"), "mysql must be installed to use this query.")
60-
def get_cod_ids(self, formula) -> list[int]:
61-
"""Query the COD for all cod ids associated with a formula. Requires
62-
mysql executable to be in the path.
59+
def get_cod_ids(self, formula: str) -> list[int]:
60+
"""Query the COD for all COD IDs associated with a formula.
6361
6462
Args:
65-
formula (str): Formula.
66-
67-
Returns:
68-
List of cod ids.
63+
formula (str): The formula to request
6964
"""
70-
# TODO: Remove dependency on external mysql call. MySQL-python package does not support Py3!
71-
72-
# Standardize formula to the version used by COD
65+
# Use hill_formula format as per COD request
7366
cod_formula = Composition(formula).hill_formula
74-
sql = f'select file from data where formula="- {cod_formula} -"' # noqa: S608
75-
text = self.query(sql).split("\n")
76-
cod_ids = []
77-
for line in text:
78-
if match := re.search(r"(\d+)", line):
79-
cod_ids.append(int(match[1]))
80-
return cod_ids
8167

82-
def get_structure_by_id(self, cod_id: int, timeout: int = 600, **kwargs) -> Structure:
83-
"""Query the COD for a structure by id.
68+
# Set up query parameters
69+
params = {"formula": cod_formula, "format": "json"}
70+
71+
response = requests.get(self.api_url, params=params, timeout=self.timeout)
72+
73+
# Raise an exception if the request fails
74+
response.raise_for_status()
75+
76+
return [int(entry["file"]) for entry in response.json()]
77+
78+
def get_structure_by_id(self, cod_id: int, timeout: int | None = None, **kwargs) -> Structure:
79+
"""Query the COD for a structure by ID.
8480
8581
Args:
86-
cod_id (int): COD id.
87-
timeout (int): Timeout for the request in seconds. Default = 600.
88-
kwargs: All kwargs supported by Structure.from_str.
82+
cod_id (int): COD ID.
83+
timeout (int): DEPRECATED. request timeout in seconds.
84+
kwargs: kwargs passed to Structure.from_str.
8985
9086
Returns:
9187
A Structure.
9288
"""
93-
response = requests.get(f"https://{self.url}/cod/{cod_id}.cif", timeout=timeout)
89+
# TODO: remove timeout arg and use class level timeout after 2025-10-17
90+
if timeout is not None:
91+
warnings.warn("separate timeout arg is deprecated, please use class level timeout", DeprecationWarning)
92+
timeout = timeout or self.timeout
93+
94+
response = requests.get(f"{self.url}/cod/{cod_id}.cif", timeout=timeout)
9495
return Structure.from_str(response.text, fmt="cif", **kwargs)
9596

96-
@requires(which("mysql"), "mysql must be installed to use this query.")
97-
def get_structure_by_formula(self, formula: str, **kwargs) -> list[dict[str, str | int | Structure]]:
98-
"""Query the COD for structures by formula. Requires mysql executable to
99-
be in the path.
97+
def get_structure_by_formula(
98+
self,
99+
formula: str,
100+
**kwargs,
101+
) -> list[dict[Literal["structure", "cod_id", "sg"], str | int | Structure]]:
102+
"""Query the COD for structures by formula.
100103
101104
Args:
102105
formula (str): Chemical formula.
103106
kwargs: All kwargs supported by Structure.from_str.
104107
105108
Returns:
106-
A list of dict of the format [{"structure": Structure, "cod_id": int, "sg": "P n m a"}]
109+
A list of dict of: {"structure": Structure, "cod_id": int, "sg": "P n m a"}
107110
"""
108-
structures: list[dict[str, str | int | Structure]] = []
109-
sql = f'select file, sg from data where formula="- {Composition(formula).hill_formula} -"' # noqa: S608
110-
text = self.query(sql).split("\n")
111-
text.pop(0)
112-
for line in text:
113-
if line.strip():
114-
cod_id, sg = line.split("\t")
115-
response = requests.get(f"https://{self.url}/cod/{cod_id.strip()}.cif", timeout=60)
116-
try:
117-
struct = Structure.from_str(response.text, fmt="cif", **kwargs)
118-
structures.append({"structure": struct, "cod_id": int(cod_id), "sg": sg})
119-
except Exception:
120-
warnings.warn(f"\nStructure.from_str failed while parsing CIF file:\n{response.text}")
121-
raise
111+
# Prepare the query parameters
112+
params = {
113+
"formula": Composition(formula).hill_formula,
114+
"format": "json",
115+
}
116+
117+
response = requests.get(self.api_url, params=params, timeout=self.timeout)
118+
response.raise_for_status()
119+
120+
structures: list[dict[Literal["structure", "cod_id", "sg"], str | int | Structure]] = []
121+
122+
# Parse the JSON response
123+
for entry in response.json():
124+
cod_id = entry["file"]
125+
sg = entry.get("sg")
126+
127+
try:
128+
struct = self.get_structure_by_id(cod_id, **kwargs)
129+
structures.append({"structure": struct, "cod_id": int(cod_id), "sg": sg})
130+
131+
except Exception:
132+
warnings.warn(f"Structure.from_str failed while parsing CIF file for COD ID {cod_id}", stacklevel=2)
133+
raise
122134

123135
return structures

tests/ext/test_cod.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,58 @@
11
from __future__ import annotations
22

33
import os
4-
from shutil import which
5-
from unittest import TestCase
4+
from functools import wraps
65

76
import pytest
87
import requests
9-
import urllib3
108

119
from pymatgen.ext.cod import COD
1210

13-
if "CI" in os.environ: # test is slow and flaky, skip in CI. see
14-
# https://github.com/materialsproject/pymatgen/pull/3777#issuecomment-2071217785
15-
pytest.skip(allow_module_level=True, reason="Skip COD test in CI")
11+
# Set a tighter timeout in CI
12+
TIMEOUT = 10 if os.getenv("CI") else 60
13+
1614

1715
try:
18-
WEBSITE_DOWN = requests.get("https://www.crystallography.net", timeout=60).status_code != 200
19-
except (requests.exceptions.ConnectionError, urllib3.exceptions.ConnectTimeoutError):
16+
WEBSITE_DOWN = requests.get("https://www.crystallography.net", timeout=TIMEOUT).status_code != 200
17+
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.ReadTimeout):
2018
WEBSITE_DOWN = True
2119

20+
if WEBSITE_DOWN:
21+
pytest.skip(reason="www.crystallography.net is down", allow_module_level=True)
22+
23+
24+
def skip_on_timeout(func):
25+
"""Skip test in CI when time out."""
26+
27+
@wraps(func)
28+
def wrapper(*args, **kwargs):
29+
try:
30+
return func(*args, **kwargs)
31+
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.ReadTimeout):
32+
if os.getenv("CI"):
33+
pytest.skip("Request timeout in CI environment")
34+
else:
35+
raise
36+
37+
return wrapper
38+
2239

23-
@pytest.mark.skipif(WEBSITE_DOWN, reason="www.crystallography.net is down")
24-
class TestCOD(TestCase):
25-
@pytest.mark.skipif(not which("mysql"), reason="No mysql")
40+
class TestCOD:
41+
@skip_on_timeout
2642
def test_get_cod_ids(self):
27-
ids = COD().get_cod_ids("Li2O")
43+
ids = COD(timeout=TIMEOUT).get_cod_ids("Li2O")
2844
assert len(ids) > 15
45+
assert set(ids).issuperset({1010064, 1011372})
2946

30-
@pytest.mark.skipif(not which("mysql"), reason="No mysql")
47+
@skip_on_timeout
3148
def test_get_structure_by_formula(self):
32-
data = COD().get_structure_by_formula("Li2O")
33-
assert len(data) > 15
34-
assert data[0]["structure"].reduced_formula == "Li2O"
49+
# This formula has only one match (as of 2024-10-17) therefore
50+
# the runtime is shorter (~ 2s for each match)
51+
data = COD(timeout=TIMEOUT).get_structure_by_formula("C3 H18 F6 Fe N9")
52+
assert len(data) >= 1
53+
assert data[0]["structure"].reduced_formula == "FeH18C3(N3F2)3"
3554

55+
@skip_on_timeout
3656
def test_get_structure_by_id(self):
37-
struct = COD().get_structure_by_id(2_002_926)
57+
struct = COD(timeout=TIMEOUT).get_structure_by_id(2_002_926)
3858
assert struct.formula == "Be8 H64 N16 F32"

0 commit comments

Comments
 (0)