Skip to content

Commit e274cef

Browse files
committed
type: Annotate bst.schema
1 parent a35001e commit e274cef

File tree

2 files changed

+31
-24
lines changed

2 files changed

+31
-24
lines changed

tools/schemacode/src/bidsschematools/schema.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
from __future__ import annotations
44

55
import json
6-
import os
76
import re
8-
from collections.abc import Iterable, Mapping
7+
from collections.abc import Iterable, Mapping, MutableMapping
98
from copy import deepcopy
109
from functools import cache, lru_cache
1110
from pathlib import Path
@@ -15,6 +14,9 @@
1514

1615
TYPE_CHECKING = False
1716
if TYPE_CHECKING:
17+
from typing import Any, Callable
18+
19+
from acres import typ as at
1820
from jsonschema.protocols import Validator as JsonschemaValidator
1921

2022
lgr = utils.get_logger()
@@ -24,39 +26,38 @@ class BIDSSchemaError(Exception):
2426
"""Errors indicating invalid values in the schema itself"""
2527

2628

27-
def _get_schema_version(schema_dir):
29+
def _get_schema_version(schema_dir: str | at.Traversable) -> str:
2830
"""
2931
Determine schema version for given schema directory, based on file specification.
3032
"""
33+
if isinstance(schema_dir, str):
34+
schema_dir = Path(schema_dir)
3135

32-
schema_version_path = os.path.join(schema_dir, "SCHEMA_VERSION")
33-
with open(schema_version_path) as f:
34-
schema_version = f.readline().rstrip()
35-
return schema_version
36+
schema_version_path = schema_dir / "SCHEMA_VERSION"
37+
return schema_version_path.read_text().strip()
3638

3739

38-
def _get_bids_version(schema_dir):
40+
def _get_bids_version(schema_dir: str | at.Traversable) -> str:
3941
"""
4042
Determine BIDS version for given schema directory, with directory name, file specification,
4143
and string fallback.
4244
"""
45+
if isinstance(schema_dir, str):
46+
schema_dir = Path(schema_dir)
4347

44-
bids_version_path = os.path.join(schema_dir, "BIDS_VERSION")
48+
bids_version_path = schema_dir / "BIDS_VERSION"
4549
try:
46-
with open(bids_version_path) as f:
47-
bids_version = f.readline().rstrip()
50+
return bids_version_path.read_text().strip()
4851
# If this file is not in the schema, fall back to placeholder heuristics:
4952
except FileNotFoundError:
5053
# Maybe the directory encodes the version, as in:
5154
# https://github.com/bids-standard/bids-schema
52-
_, bids_version = os.path.split(schema_dir)
53-
if not re.match(r"^.*?[0-9]*?\.[0-9]*?\.[0-9]*?.*?$", bids_version):
54-
# Then we don't know, really.
55-
bids_version = schema_dir
56-
return bids_version
55+
if re.match(r"^.*?[0-9]*?\.[0-9]*?\.[0-9]*?.*?$", schema_dir.name):
56+
return schema_dir.name
57+
return str(schema_dir)
5758

5859

59-
def _find(obj, predicate):
60+
def _find(obj: object, predicate: Callable[[Any], bool]) -> Iterable[object]:
6061
"""Find objects in an arbitrary object that satisfy a predicate.
6162
6263
Note that this does not cut branches, so every iterable sub-object
@@ -78,7 +79,7 @@ def _find(obj, predicate):
7879
except Exception:
7980
pass
8081

81-
iterable = ()
82+
iterable: Iterable[object] = ()
8283
if isinstance(obj, Mapping):
8384
iterable = obj.values()
8485
elif not isinstance(obj, str) and isinstance(obj, Iterable):
@@ -88,16 +89,17 @@ def _find(obj, predicate):
8889
yield from _find(item, predicate)
8990

9091

91-
def _dereference(namespace, base_schema):
92+
def _dereference(namespace: MutableMapping, base_schema: Namespace) -> None:
9293
# In-place, recursively dereference objects
9394
# This allows a referenced object to itself contain a reference
9495
# A dependency graph could be constructed, but would likely be slower
9596
# to build than to duplicate a couple dereferences
9697
for struct in _find(namespace, lambda obj: "$ref" in obj):
98+
assert isinstance(struct, MutableMapping)
9799
target = base_schema.get(struct["$ref"])
98100
if target is None:
99101
raise ValueError(f"Reference {struct['$ref']} not found in schema.")
100-
if isinstance(target, Mapping):
102+
if isinstance(target, MutableMapping):
101103
struct.pop("$ref")
102104
_dereference(target, base_schema)
103105
struct.update({**target, **struct})
@@ -110,7 +112,7 @@ def get_schema_validator() -> JsonschemaValidator:
110112
return utils.jsonschema_validator(metaschema, check_format=True)
111113

112114

113-
def dereference(namespace, inplace=True):
115+
def dereference(namespace: Namespace, inplace: bool = True) -> Namespace:
114116
"""Replace references in namespace with the contents of the referred object.
115117
116118
Parameters
@@ -133,6 +135,8 @@ def dereference(namespace, inplace=True):
133135

134136
# At this point, any remaining refs are one-off objects in lists
135137
for struct in _find(namespace, lambda obj: any("$ref" in sub for sub in obj)):
138+
assert isinstance(struct, list)
139+
item: MutableMapping[str, object]
136140
for i, item in enumerate(struct):
137141
try:
138142
target = item.pop("$ref")
@@ -144,7 +148,7 @@ def dereference(namespace, inplace=True):
144148
return namespace
145149

146150

147-
def flatten_enums(namespace, inplace=True):
151+
def flatten_enums(namespace: Namespace, inplace=True) -> Namespace:
148152
"""Replace enum collections with a single enum, merging enums contents.
149153
150154
The function helps reducing the complexity of the schema by assuming
@@ -175,6 +179,7 @@ def flatten_enums(namespace, inplace=True):
175179
if not inplace:
176180
namespace = deepcopy(namespace)
177181
for struct in _find(namespace, lambda obj: "anyOf" in obj):
182+
assert isinstance(struct, MutableMapping)
178183
try:
179184
# Deduplicate because JSON schema validators may not like duplicates
180185
# Long run, we should get rid of this function and have the rendering
@@ -189,7 +194,7 @@ def flatten_enums(namespace, inplace=True):
189194

190195

191196
@lru_cache
192-
def load_schema(schema_path=None):
197+
def load_schema(schema_path: at.Traversable | str | None = None) -> Namespace:
193198
"""Load the schema into a dict-like structure.
194199
195200
This function allows the schema, like BIDS itself, to be specified in
@@ -221,6 +226,7 @@ def load_schema(schema_path=None):
221226

222227
# Probably a Windows checkout with a git link. Resolve first.
223228
if schema_path.is_file() and (content := schema_path.read_text()).startswith("../"):
229+
assert isinstance(schema_path, Path)
224230
schema_path = Path.resolve(schema_path.parent / content)
225231
lgr.info("No schema path specified, defaulting to the bundled schema, `%s`.", schema_path)
226232
elif isinstance(schema_path, str):

tools/schemacode/src/bidsschematools/types/_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
TYPE_CHECKING = False
2020
if TYPE_CHECKING:
21+
from collections.abc import Mapping
2122
from typing import Any, Callable, Protocol
2223

2324
class Spec(Protocol):
@@ -241,7 +242,7 @@ def generate_protocols(
241242
return protocols
242243

243244

244-
def generate_module(schema: dict[str, Any], class_type: str) -> str:
245+
def generate_module(schema: Mapping[str, Any], class_type: str) -> str:
245246
"""Generate a context module source code from a BIDS schema.
246247
247248
Returns a tuple containing the module source code and a list of protocol names.

0 commit comments

Comments
 (0)