Skip to content

Commit dd40ae7

Browse files
authored
fix: Add type annotations, fix a bug in loading schema path (#2264)
* type: Annotate bst.utils
* type: Annotate types.namespace
* type: Annotate bst.schema
* chore: Update pre-commit hooks
1 parent 953e550 commit dd40ae7

File tree

6 files changed, with 87 additions and 66 deletions.

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: check-added-large-files
1818
- id: check-case-conflict
1919
- repo: https://github.com/python-jsonschema/check-jsonschema
20-
rev: 0.33.3
20+
rev: 0.35.0
2121
hooks:
2222
- id: check-dependabot
2323
- id: check-github-workflows
@@ -49,7 +49,7 @@ repos:
4949
- id: codespell
5050
args: ["--config=.codespellrc", "--dictionary=-", "--dictionary=.codespell_dict"]
5151
- repo: https://github.com/pre-commit/mirrors-mypy
52-
rev: v1.17.1
52+
rev: v1.18.2
5353
hooks:
5454
- id: mypy
5555
# Sync with project.optional-dependencies.typing

tools/schemacode/src/bidsschematools/render/tsv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
@propagate_fence_exception
15-
@in_context(WarningsFilter(["error"]))
15+
@in_context(WarningsFilter(("error",)))
1616
def fence(
1717
source: str,
1818
language: str,

tools/schemacode/src/bidsschematools/schema.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
from __future__ import annotations
44

55
import json
6-
import os
76
import re
8-
from collections.abc import Iterable, Mapping
7+
from collections.abc import Iterable, Mapping, MutableMapping
98
from copy import deepcopy
109
from functools import cache, lru_cache
1110
from pathlib import Path
@@ -15,6 +14,9 @@
1514

1615
TYPE_CHECKING = False
1716
if TYPE_CHECKING:
17+
from typing import Any, Callable
18+
19+
from acres import typ as at
1820
from jsonschema.protocols import Validator as JsonschemaValidator
1921

2022
lgr = utils.get_logger()
@@ -24,39 +26,38 @@ class BIDSSchemaError(Exception):
2426
"""Errors indicating invalid values in the schema itself"""
2527

2628

27-
def _get_schema_version(schema_dir):
29+
def _get_schema_version(schema_dir: str | at.Traversable) -> str:
2830
"""
2931
Determine schema version for given schema directory, based on file specification.
3032
"""
33+
if isinstance(schema_dir, str):
34+
schema_dir = Path(schema_dir)
3135

32-
schema_version_path = os.path.join(schema_dir, "SCHEMA_VERSION")
33-
with open(schema_version_path) as f:
34-
schema_version = f.readline().rstrip()
35-
return schema_version
36+
schema_version_path = schema_dir / "SCHEMA_VERSION"
37+
return schema_version_path.read_text().strip()
3638

3739

38-
def _get_bids_version(schema_dir):
40+
def _get_bids_version(schema_dir: str | at.Traversable) -> str:
3941
"""
4042
Determine BIDS version for given schema directory, with directory name, file specification,
4143
and string fallback.
4244
"""
45+
if isinstance(schema_dir, str):
46+
schema_dir = Path(schema_dir)
4347

44-
bids_version_path = os.path.join(schema_dir, "BIDS_VERSION")
48+
bids_version_path = schema_dir / "BIDS_VERSION"
4549
try:
46-
with open(bids_version_path) as f:
47-
bids_version = f.readline().rstrip()
50+
return bids_version_path.read_text().strip()
4851
# If this file is not in the schema, fall back to placeholder heuristics:
4952
except FileNotFoundError:
5053
# Maybe the directory encodes the version, as in:
5154
# https://github.com/bids-standard/bids-schema
52-
_, bids_version = os.path.split(schema_dir)
53-
if not re.match(r"^.*?[0-9]*?\.[0-9]*?\.[0-9]*?.*?$", bids_version):
54-
# Then we don't know, really.
55-
bids_version = schema_dir
56-
return bids_version
55+
if re.match(r"^.*?[0-9]*?\.[0-9]*?\.[0-9]*?.*?$", schema_dir.name):
56+
return schema_dir.name
57+
return str(schema_dir)
5758

5859

59-
def _find(obj, predicate):
60+
def _find(obj: object, predicate: Callable[[Any], bool]) -> Iterable[object]:
6061
"""Find objects in an arbitrary object that satisfy a predicate.
6162
6263
Note that this does not cut branches, so every iterable sub-object
@@ -78,7 +79,7 @@ def _find(obj, predicate):
7879
except Exception:
7980
pass
8081

81-
iterable = ()
82+
iterable: Iterable[object] = ()
8283
if isinstance(obj, Mapping):
8384
iterable = obj.values()
8485
elif not isinstance(obj, str) and isinstance(obj, Iterable):
@@ -88,16 +89,17 @@ def _find(obj, predicate):
8889
yield from _find(item, predicate)
8990

9091

91-
def _dereference(namespace, base_schema):
92+
def _dereference(namespace: MutableMapping, base_schema: Namespace) -> None:
9293
# In-place, recursively dereference objects
9394
# This allows a referenced object to itself contain a reference
9495
# A dependency graph could be constructed, but would likely be slower
9596
# to build than to duplicate a couple dereferences
9697
for struct in _find(namespace, lambda obj: "$ref" in obj):
98+
assert isinstance(struct, MutableMapping)
9799
target = base_schema.get(struct["$ref"])
98100
if target is None:
99101
raise ValueError(f"Reference {struct['$ref']} not found in schema.")
100-
if isinstance(target, Mapping):
102+
if isinstance(target, MutableMapping):
101103
struct.pop("$ref")
102104
_dereference(target, base_schema)
103105
struct.update({**target, **struct})
@@ -110,7 +112,7 @@ def get_schema_validator() -> JsonschemaValidator:
110112
return utils.jsonschema_validator(metaschema, check_format=True)
111113

112114

113-
def dereference(namespace, inplace=True):
115+
def dereference(namespace: Namespace, inplace: bool = True) -> Namespace:
114116
"""Replace references in namespace with the contents of the referred object.
115117
116118
Parameters
@@ -133,6 +135,8 @@ def dereference(namespace, inplace=True):
133135

134136
# At this point, any remaining refs are one-off objects in lists
135137
for struct in _find(namespace, lambda obj: any("$ref" in sub for sub in obj)):
138+
assert isinstance(struct, list)
139+
item: MutableMapping[str, object]
136140
for i, item in enumerate(struct):
137141
try:
138142
target = item.pop("$ref")
@@ -144,7 +148,7 @@ def dereference(namespace, inplace=True):
144148
return namespace
145149

146150

147-
def flatten_enums(namespace, inplace=True):
151+
def flatten_enums(namespace: Namespace, inplace=True) -> Namespace:
148152
"""Replace enum collections with a single enum, merging enums contents.
149153
150154
The function helps reducing the complexity of the schema by assuming
@@ -175,6 +179,7 @@ def flatten_enums(namespace, inplace=True):
175179
if not inplace:
176180
namespace = deepcopy(namespace)
177181
for struct in _find(namespace, lambda obj: "anyOf" in obj):
182+
assert isinstance(struct, MutableMapping)
178183
try:
179184
# Deduplicate because JSON schema validators may not like duplicates
180185
# Long run, we should get rid of this function and have the rendering
@@ -189,7 +194,7 @@ def flatten_enums(namespace, inplace=True):
189194

190195

191196
@lru_cache
192-
def load_schema(schema_path=None):
197+
def load_schema(schema_path: at.Traversable | str | None = None) -> Namespace:
193198
"""Load the schema into a dict-like structure.
194199
195200
This function allows the schema, like BIDS itself, to be specified in
@@ -221,6 +226,7 @@ def load_schema(schema_path=None):
221226

222227
# Probably a Windows checkout with a git link. Resolve first.
223228
if schema_path.is_file() and (content := schema_path.read_text()).startswith("../"):
229+
assert isinstance(schema_path, Path)
224230
schema_path = Path.resolve(schema_path.parent / content)
225231
lgr.info("No schema path specified, defaulting to the bundled schema, `%s`.", schema_path)
226232
elif isinstance(schema_path, str):

tools/schemacode/src/bidsschematools/types/_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
TYPE_CHECKING = False
2020
if TYPE_CHECKING:
21+
from collections.abc import Mapping
2122
from typing import Any, Callable, Protocol
2223

2324
class Spec(Protocol):
@@ -241,7 +242,7 @@ def generate_protocols(
241242
return protocols
242243

243244

244-
def generate_module(schema: dict[str, Any], class_type: str) -> str:
245+
def generate_module(schema: Mapping[str, Any], class_type: str) -> str:
245246
"""Generate a context module source code from a BIDS schema.
246247
247248
Returns a tuple containing the module source code and a list of protocol names.

tools/schemacode/src/bidsschematools/types/namespace.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,22 @@
44
YAML files available as a single dictionary and allow attribute (``.``)
55
lookups.
66
"""
7+
from __future__ import annotations
78

89
import json
9-
import typing as ty
10+
import os.path
1011
from collections.abc import ItemsView, KeysView, Mapping, MutableMapping, ValuesView
1112
from pathlib import Path
1213

14+
TYPE_CHECKING = False
15+
if TYPE_CHECKING:
16+
from collections.abc import Iterator
17+
from typing import Any, Self
1318

14-
def _expand_dots(entry: tuple[str, ty.Any]) -> tuple[str, ty.Any]:
19+
from acres import typ as at
20+
21+
22+
def _expand_dots(entry: tuple[str, Any]) -> tuple[str, Any]:
1523
# Helper function for expand
1624
key, val = entry
1725
if "." in key:
@@ -20,7 +28,7 @@ def _expand_dots(entry: tuple[str, ty.Any]) -> tuple[str, ty.Any]:
2028
return key, expand(val)
2129

2230

23-
def expand(element):
31+
def expand(element: dict[str, Any]) -> dict[str, Any]:
2432
"""Expand a dict, recursively, to replace dots in keys with recursive dictionaries
2533
2634
Parameters
@@ -46,18 +54,18 @@ def expand(element):
4654

4755

4856
class NsItemsView(ItemsView):
49-
def __init__(self, namespace, level):
57+
def __init__(self, namespace: Mapping, level: int):
5058
self._mapping = namespace
5159
self._level = level
5260

53-
def __contains__(self, item):
61+
def __contains__(self, item: Any) -> bool:
5462
key, val = item
5563
keys = key.split(".", self._level - 1)
5664
if "." in keys[-1]:
5765
return False
5866
return self._mapping[key] == val
5967

60-
def __iter__(self):
68+
def __iter__(self) -> Iterator[tuple[str, Any]]:
6169
l1 = ItemsView(self._mapping)
6270
if self._level == 1:
6371
yield from l1
@@ -73,30 +81,30 @@ def __iter__(self):
7381

7482

7583
class NsKeysView(KeysView):
76-
def __init__(self, namespace, level):
84+
def __init__(self, namespace: Mapping, level: int):
7785
self._mapping = namespace
7886
self._level = level
7987

80-
def __contains__(self, key):
88+
def __contains__(self, key: Any) -> bool:
8189
keys = key.split(".", self._level - 1)
8290
if "." in keys[-1]:
8391
return False
8492
return key in self._mapping
8593

86-
def __iter__(self):
94+
def __iter__(self) -> Iterator[str]:
8795
yield from (key for key, val in NsItemsView(self._mapping, self._level))
8896

8997

9098
class NsValuesView(ValuesView):
91-
def __init__(self, namespace, level):
99+
def __init__(self, namespace: Mapping, level: int):
92100
self._mapping = namespace
93101
self._level = level
94102
self._items = NsItemsView(namespace, level)
95103

96-
def __contains__(self, val):
104+
def __contains__(self, val: object) -> bool:
97105
return any(val == item[1] for item in self._items)
98106

99-
def __iter__(self):
107+
def __iter__(self) -> Iterator[Any]:
100108
yield from (val for key, val in self._items)
101109

102110

@@ -162,7 +170,7 @@ class Namespace(MutableMapping):
162170
>>> del ns['d']
163171
"""
164172

165-
def __init__(self, *args, **kwargs):
173+
def __init__(self, *args, **kwargs) -> None:
166174
self._properties = dict(*args, **kwargs)
167175

168176
def to_dict(self) -> dict:
@@ -177,7 +185,7 @@ def _to_dict(obj):
177185

178186
return _to_dict(self)
179187

180-
def __deepcopy__(self, memo):
188+
def __deepcopy__(self, memo) -> Self:
181189
return self.build(self.to_dict())
182190

183191
@classmethod
@@ -218,42 +226,43 @@ def __getattribute__(self, key):
218226
except KeyError:
219227
raise err
220228

221-
def _get_mapping(self, key: str) -> tuple[Mapping, str]:
229+
def _get_mapping(self, key: str) -> tuple[MutableMapping, str]:
222230
subkeys = key.split(".")
223231
mapping = self._properties
224232
for subkey in subkeys[:-1]:
225233
mapping = mapping.setdefault(subkey, {})
226-
mapping = getattr(mapping, "_properties", mapping)
234+
if isinstance(mapping, Namespace):
235+
mapping = mapping._properties
227236
if not isinstance(mapping, Mapping):
228237
raise KeyError(f"{key} (subkey={subkey})")
229238
return mapping, subkeys[-1]
230239

231-
def __getitem__(self, key):
240+
def __getitem__(self, key: str) -> Any:
232241
mapping, subkey = self._get_mapping(key)
233242
val = mapping[subkey]
234243
if isinstance(val, dict):
235244
val = self.view(val)
236245
return val
237246

238-
def __setitem__(self, key, val):
247+
def __setitem__(self, key: str, val: Any):
239248
mapping, subkey = self._get_mapping(key)
240249
mapping[subkey] = val
241250

242-
def __delitem__(self, key):
251+
def __delitem__(self, key: str):
243252
mapping, subkey = self._get_mapping(key)
244253
del mapping[subkey]
245254

246-
def __repr__(self):
255+
def __repr__(self) -> str:
247256
return f"<Namespace {self._properties}>"
248257

249-
def __len__(self):
258+
def __len__(self) -> int:
250259
return len(self._properties)
251260

252-
def __iter__(self):
261+
def __iter__(self) -> Iterator[str]:
253262
return iter(self._properties)
254263

255264
@classmethod
256-
def from_directory(cls, path, fmt="yaml"):
265+
def from_directory(cls, path: at.Traversable | str, fmt: str = "yaml") -> Self:
257266
if fmt == "yaml":
258267
if isinstance(path, str):
259268
path = Path(path)
@@ -264,27 +273,27 @@ def to_json(self, **kwargs) -> str:
264273
return json.dumps(self, cls=MappingEncoder, **kwargs)
265274

266275
@classmethod
267-
def from_json(cls, jsonstr: str):
276+
def from_json(cls, jsonstr: str) -> Self:
268277
return cls.build(json.loads(jsonstr))
269278

270279

271-
def _read_yaml_dir(path: Path) -> dict:
280+
def _read_yaml_dir(path: at.Traversable) -> dict:
272281
mapping = {}
273-
for subpath in sorted(path.iterdir()):
282+
for subpath in sorted(path.iterdir(), key=lambda p: p.name):
274283
if subpath.is_dir():
275284
mapping[subpath.name] = _read_yaml_dir(subpath)
276285
elif subpath.name.endswith("yaml"):
277286
import yaml
278287

279288
try:
280-
mapping[subpath.stem] = yaml.safe_load(subpath.read_text())
289+
mapping[os.path.splitext(subpath.name)[0]] = yaml.safe_load(subpath.read_text())
281290
except Exception as e:
282291
raise ValueError(f"There was an error reading the file: {subpath}") from e
283292
return mapping
284293

285294

286295
class MappingEncoder(json.JSONEncoder):
287-
def default(self, o):
296+
def default(self, o: object) -> object:
288297
try:
289298
return super().default(o)
290299
except TypeError as e:

0 commit comments

Comments
 (0)