Skip to content

Commit e91546e

Browse files
Pyright Improvements (#932)
# Pull Request ## Title Pyright improvements ______________________________________________________________________ ## Description Pyright is a type checker that ships with VSCode's Pylance by default. It is billed as a faster, though less complete, version of mypy. As such it gets a few things a little differently that mypy and alerts in VSCode. This PR fixes those ("standard") alerts and removes the mypy extension from VSCode's default extensions for MLOS in favor of just using pyright (there's no sense in running both interactively). We do not enable pyright's "strict" mode. Additionally, it enables pyright in pre-commit rules to ensure those fixes remain. We leave the rest of the mypy checks as well since they are still useful. A list of some of the types of fixes: - TypeDict initialization checks for Tunables - Check that json.loads() returns a dict and not a list (e.g.) - Replace ConcreteOptimizer TypeVar with a TypeAlias - Add BoundMethod protocol for checking __self__ attribute - Ensure correct type inference in a number of places - Add `...` to Protocol methods to make pyright aware of the lack of method body. - Fix a few type annotations ______________________________________________________________________ ## Type of Change - 🛠️ Bug fix - 🔄 Refactor ______________________________________________________________________ ## Testing - Additional CI checks as described above. ______________________________________________________________________ --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c66f793 commit e91546e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+323
-106
lines changed

.devcontainer/devcontainer.json

-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@
6868
"huntertran.auto-markdown-toc",
6969
"ibm.output-colorizer",
7070
"lextudio.restructuredtext",
71-
"matangover.mypy",
7271
"ms-azuretools.vscode-docker",
7372
"ms-python.black-formatter",
7473
"ms-python.pylint",

.pre-commit-config.yaml

+16-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ ci:
44
# Let pre-commit.ci automatically update PRs with formatting fixes.
55
autofix_prs: true
66
# skip local hooks - they should be managed manually via conda-envs/*.yml
7-
skip: [mypy, pylint, pycodestyle]
7+
skip: [mypy, pylint, pycodestyle, pyright]
88
autoupdate_schedule: monthly
99
autoupdate_commit_msg: |
1010
[pre-commit.ci] pre-commit autoupdate
@@ -15,6 +15,7 @@ ci:
1515
See Also:
1616
- https://github.com/microsoft/MLOS/blob/main/conda-envs/mlos.yml
1717
- https://pypi.org/project/mypy/
18+
- https://pypi.org/project/pyright/
1819
- https://pypi.org/project/pylint/
1920
- https://pypi.org/project/pycodestyle/
2021
@@ -140,6 +141,20 @@ repos:
140141
(?x)^(
141142
doc/source/conf.py
142143
)$
144+
- id: pyright
145+
name: pyright
146+
entry: pyright
147+
language: system
148+
types: [python]
149+
require_serial: true
150+
exclude: |
151+
(?x)^(
152+
doc/source/conf.py|
153+
mlos_core/setup.py|
154+
mlos_bench/setup.py|
155+
mlos_viz/setup.py|
156+
conftest.py
157+
)$
143158
- id: mypy
144159
name: mypy
145160
entry: mypy

.vscode/extensions.json

-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
"huntertran.auto-markdown-toc",
1313
"ibm.output-colorizer",
1414
"lextudio.restructuredtext",
15-
"matangover.mypy",
1615
"ms-azuretools.vscode-docker",
1716
"ms-python.black-formatter",
1817
"ms-python.pylint",

.vscode/settings.json

+1-2
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,5 @@
170170
"python.testing.unittestEnabled": false,
171171
"debugpy.debugJustMyCode": false,
172172
"python.analysis.autoImportCompletions": true,
173-
"python.analysis.supportRestructuredText": true,
174-
"python.analysis.typeCheckingMode": "standard"
173+
"python.analysis.supportRestructuredText": true
175174
}

conda-envs/mlos.yml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies:
2828
- pylint==3.3.3
2929
- tomlkit
3030
- mypy==1.14.1
31+
- pyright==1.1.392.post0
3132
- pandas-stubs
3233
- types-beautifulsoup4
3334
- types-colorama

doc/source/conf.py

+4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
from sphinx.application import Sphinx as SphinxApp
3030
from sphinx.environment import BuildEnvironment
3131

32+
# Note: doc requirements aren't installed by default.
33+
# To install them, run `pip install -r doc/requirements.txt`
34+
35+
3236
sys.path.insert(0, os.path.abspath("../../mlos_core/mlos_core"))
3337
sys.path.insert(1, os.path.abspath("../../mlos_bench/mlos_bench"))
3438
sys.path.insert(1, os.path.abspath("../../mlos_viz/mlos_viz"))

mlos_bench/mlos_bench/environments/mock_env.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ def __init__( # pylint: disable=too-many-arguments
6464
seed = int(self.config.get("mock_env_seed", -1))
6565
self._run_random = random.Random(seed or None) if seed >= 0 else None
6666
self._status_random = random.Random(seed or None) if seed >= 0 else None
67-
self._range = self.config.get("mock_env_range")
68-
self._metrics = self.config.get("mock_env_metrics", ["score"])
67+
self._range: tuple[int, int] | None = self.config.get("mock_env_range")
68+
self._metrics: list[str] | None = self.config.get("mock_env_metrics", ["score"])
6969
self._is_ready = True
7070

7171
def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue]:
@@ -80,7 +80,7 @@ def _produce_metrics(self, rand: random.Random | None) -> dict[str, TunableValue
8080
if self._range:
8181
score = self._range[0] + score * (self._range[1] - self._range[0])
8282

83-
return {metric: score for metric in self._metrics}
83+
return {metric: float(score) for metric in self._metrics or []}
8484

8585
def run(self) -> tuple[Status, datetime, dict[str, TunableValue] | None]:
8686
"""

mlos_bench/mlos_bench/optimizers/mock_optimizer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
self._random: dict[str, Callable[[Tunable], TunableValue]] = {
3333
"categorical": lambda tunable: rnd.choice(tunable.categories),
3434
"float": lambda tunable: rnd.uniform(*tunable.range),
35-
"int": lambda tunable: rnd.randint(*tunable.range),
35+
"int": lambda tunable: rnd.randint(*(int(x) for x in tunable.range)),
3636
}
3737

3838
def bulk_register(

mlos_bench/mlos_bench/services/base_service.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Any, Literal
1515

1616
from mlos_bench.config.schemas import ConfigSchema
17+
from mlos_bench.services.types.bound_method import BoundMethod
1718
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
1819
from mlos_bench.util import instantiate_from_config
1920

@@ -278,7 +279,7 @@ def register(self, services: dict[str, Callable] | list[Callable]) -> None:
278279
for _, svc_method in self._service_methods.items()
279280
# Note: some methods are actually stand alone functions, so we need
280281
# to filter them out.
281-
if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service)
282+
if isinstance(svc_method, BoundMethod) and isinstance(svc_method.__self__, Service)
282283
}
283284

284285
def export(self) -> dict[str, Callable]:

mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class CopyMode(Enum):
2626
class SshFileShareService(FileShareService, SshService):
2727
"""A collection of functions for interacting with SSH servers as file shares."""
2828

29+
# pylint: disable=too-many-ancestors
30+
2931
async def _start_file_copy(
3032
self,
3133
params: dict,

mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec):
2525
"""Helper methods to manage machines via SSH."""
2626

27+
# pylint: disable=too-many-ancestors
2728
# pylint: disable=too-many-instance-attributes
2829

2930
def __init__(

mlos_bench/mlos_bench/services/types/authenticator_type.py

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
class SupportsAuth(Protocol[T_co]):
1515
"""Protocol interface for authentication for the cloud services."""
1616

17+
# Needed by pyright
18+
# pylint: disable=unnecessary-ellipsis,redundant-returns-doc
19+
1720
def get_access_token(self) -> str:
1821
"""
1922
Get the access token for cloud services.
@@ -23,6 +26,7 @@ def get_access_token(self) -> str:
2326
access_token : str
2427
Access token.
2528
"""
29+
...
2630

2731
def get_auth_headers(self) -> dict:
2832
"""
@@ -33,6 +37,7 @@ def get_auth_headers(self) -> dict:
3337
access_header : dict
3438
HTTP header containing the access token.
3539
"""
40+
...
3641

3742
def get_credential(self) -> T_co:
3843
"""
@@ -43,3 +48,4 @@ def get_credential(self) -> T_co:
4348
credential : T_co
4449
Cloud-specific credential object.
4550
"""
51+
...
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
#
5+
"""Protocol representing a bound method."""
6+
7+
from typing import Any, Protocol, runtime_checkable
8+
9+
10+
@runtime_checkable
11+
class BoundMethod(Protocol):
12+
"""A callable method bound to an object."""
13+
14+
# pylint: disable=too-few-public-methods
15+
# pylint: disable=unnecessary-ellipsis
16+
17+
@property
18+
def __self__(self) -> Any:
19+
"""The self object of the bound method."""
20+
...
21+
22+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
23+
"""Call the bound method."""
24+
...

mlos_bench/mlos_bench/services/types/config_loader_type.py

+10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
class SupportsConfigLoading(Protocol):
2424
"""Protocol interface for helper functions to lookup and load configs."""
2525

26+
# Needed by pyright
27+
# pylint: disable=unnecessary-ellipsis,redundant-returns-doc
28+
2629
def get_config_paths(self) -> list[str]:
2730
"""
2831
Gets the list of config paths this service will search for config files.
@@ -31,6 +34,7 @@ def get_config_paths(self) -> list[str]:
3134
-------
3235
list[str]
3336
"""
37+
...
3438

3539
def resolve_path(self, file_path: str, extra_paths: Iterable[str] | None = None) -> str:
3640
"""
@@ -49,6 +53,7 @@ def resolve_path(self, file_path: str, extra_paths: Iterable[str] | None = None)
4953
path : str
5054
An actual path to the config or script.
5155
"""
56+
...
5257

5358
def load_config(
5459
self,
@@ -71,6 +76,7 @@ def load_config(
7176
config : Union[dict, list[dict]]
7277
Free-format dictionary that contains the configuration.
7378
"""
79+
...
7480

7581
def build_environment( # pylint: disable=too-many-arguments
7682
self,
@@ -108,6 +114,7 @@ def build_environment( # pylint: disable=too-many-arguments
108114
env : Environment
109115
An instance of the `Environment` class initialized with `config`.
110116
"""
117+
...
111118

112119
def load_environment(
113120
self,
@@ -140,6 +147,7 @@ def load_environment(
140147
env : Environment
141148
A new benchmarking environment.
142149
"""
150+
...
143151

144152
def load_environment_list(
145153
self,
@@ -173,6 +181,7 @@ def load_environment_list(
173181
env : list[Environment]
174182
A list of new benchmarking environments.
175183
"""
184+
...
176185

177186
def load_services(
178187
self,
@@ -198,3 +207,4 @@ def load_services(
198207
service : Service
199208
A collection of service methods.
200209
"""
210+
...

mlos_bench/mlos_bench/services/types/host_ops_type.py

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
class SupportsHostOps(Protocol):
1515
"""Protocol interface for Host/VM boot operations."""
1616

17+
# pylint: disable=unnecessary-ellipsis
18+
1719
def start_host(self, params: dict) -> tuple["Status", dict]:
1820
"""
1921
Start a Host/VM.
@@ -29,6 +31,7 @@ def start_host(self, params: dict) -> tuple["Status", dict]:
2931
A pair of Status and result. The result is always {}.
3032
Status is one of {PENDING, SUCCEEDED, FAILED}
3133
"""
34+
...
3235

3336
def stop_host(self, params: dict, force: bool = False) -> tuple["Status", dict]:
3437
"""
@@ -47,6 +50,7 @@ def stop_host(self, params: dict, force: bool = False) -> tuple["Status", dict]:
4750
A pair of Status and result. The result is always {}.
4851
Status is one of {PENDING, SUCCEEDED, FAILED}
4952
"""
53+
...
5054

5155
def restart_host(self, params: dict, force: bool = False) -> tuple["Status", dict]:
5256
"""
@@ -65,6 +69,7 @@ def restart_host(self, params: dict, force: bool = False) -> tuple["Status", dic
6569
A pair of Status and result. The result is always {}.
6670
Status is one of {PENDING, SUCCEEDED, FAILED}
6771
"""
72+
...
6873

6974
def wait_host_operation(self, params: dict) -> tuple["Status", dict]:
7075
"""
@@ -85,3 +90,4 @@ def wait_host_operation(self, params: dict) -> tuple["Status", dict]:
8590
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
8691
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
8792
"""
93+
...

mlos_bench/mlos_bench/services/types/host_provisioner_type.py

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
class SupportsHostProvisioning(Protocol):
1515
"""Protocol interface for Host/VM provisioning operations."""
1616

17+
# pylint: disable=unnecessary-ellipsis
18+
1719
def provision_host(self, params: dict) -> tuple["Status", dict]:
1820
"""
1921
Check if Host/VM is ready. Deploy a new Host/VM, if necessary.
@@ -31,6 +33,7 @@ def provision_host(self, params: dict) -> tuple["Status", dict]:
3133
A pair of Status and result. The result is always {}.
3234
Status is one of {PENDING, SUCCEEDED, FAILED}
3335
"""
36+
...
3437

3538
def wait_host_deployment(self, params: dict, *, is_setup: bool) -> tuple["Status", dict]:
3639
"""
@@ -52,6 +55,7 @@ def wait_host_deployment(self, params: dict, *, is_setup: bool) -> tuple["Status
5255
Status is one of {PENDING, SUCCEEDED, FAILED, TIMED_OUT}
5356
Result is info on the operation runtime if SUCCEEDED, otherwise {}.
5457
"""
58+
...
5559

5660
def deprovision_host(self, params: dict) -> tuple["Status", dict]:
5761
"""
@@ -68,6 +72,7 @@ def deprovision_host(self, params: dict) -> tuple["Status", dict]:
6872
A pair of Status and result. The result is always {}.
6973
Status is one of {PENDING, SUCCEEDED, FAILED}
7074
"""
75+
...
7176

7277
def deallocate_host(self, params: dict) -> tuple["Status", dict]:
7378
"""
@@ -88,3 +93,4 @@ def deallocate_host(self, params: dict) -> tuple["Status", dict]:
8893
A pair of Status and result. The result is always {}.
8994
Status is one of {PENDING, SUCCEEDED, FAILED}
9095
"""
96+
...

mlos_bench/mlos_bench/services/types/local_exec_type.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ class SupportsLocalExec(Protocol):
2424
vs the target environment. Used in LocalEnv and provided by LocalExecService.
2525
"""
2626

27+
# Needed by pyright
28+
# pylint: disable=unnecessary-ellipsis,redundant-returns-doc
29+
2730
def local_exec(
2831
self,
2932
script_lines: Iterable[str],
@@ -49,6 +52,7 @@ def local_exec(
4952
(return_code, stdout, stderr) : (int, str, str)
5053
A 3-tuple of return code, stdout, and stderr of the script process.
5154
"""
55+
...
5256

5357
def temp_dir_context(
5458
self,
@@ -59,11 +63,12 @@ def temp_dir_context(
5963
6064
Parameters
6165
----------
62-
path : str
66+
path : str | None
6367
A path to the temporary directory. Create a new one if None.
6468
6569
Returns
6670
-------
6771
temp_dir_context : tempfile.TemporaryDirectory
6872
Temporary directory context to use in the `with` clause.
6973
"""
74+
...

0 commit comments

Comments
 (0)