Skip to content

Commit 316d6b5

Browse files
committed
ruff fixes
1 parent 55e536a commit 316d6b5

File tree

8 files changed

+50
-29
lines changed

8 files changed

+50
-29
lines changed

dictdatabase/io_safe.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
from __future__ import annotations
2+
13
import os
24

35
from . import config, io_unsafe, locking, utils
46

57

6-
def read(file_name: str) -> dict:
8+
def read(file_name: str) -> dict | None:
79
"""
810
Read the content of a file as a dict.
911
@@ -20,7 +22,7 @@ def read(file_name: str) -> dict:
2022
return io_unsafe.read(file_name)
2123

2224

23-
def partial_read(file_name: str, key: str) -> dict:
25+
def partial_read(file_name: str, key: str) -> dict | None:
2426
"""
2527
Read only the value of a key-value pair from a file.
2628

dictdatabase/locking.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class FileLocksSnapshot:
7979
On init, orphaned locks are removed.
8080
"""
8181

82-
__slots__ = ("any_has_locks", "any_write_locks", "any_has_write_locks", "locks")
82+
__slots__ = ("any_has_locks", "any_has_write_locks", "any_write_locks", "locks")
8383

8484
locks: list[LockFileMeta]
8585
any_has_locks: bool
@@ -142,15 +142,15 @@ class AbstractLock:
142142
provides a blueprint for derived classes to implement.
143143
"""
144144

145-
__slots__ = ("db_name", "need_lock", "has_lock", "snapshot", "mode", "is_alivekeep_alive_thread")
145+
__slots__ = ("db_name", "has_lock", "is_alive", "keep_alive_thread", "mode", "need_lock", "snapshot")
146146

147147
db_name: str
148148
need_lock: LockFileMeta
149149
has_lock: LockFileMeta
150150
snapshot: FileLocksSnapshot
151151
mode: str
152152
is_alive: bool
153-
keep_alive_thread: threading.Thread
153+
keep_alive_thread: threading.Thread | None
154154

155155
def __init__(self, db_name: str) -> None:
156156
# Normalize db_name to avoid file naming conflicts
@@ -197,7 +197,8 @@ def _start_keep_alive_thread(self) -> None:
197197
"""
198198

199199
if self.keep_alive_thread is not None:
200-
raise RuntimeError("Keep alive thread already exists.")
200+
msg = "Keep alive thread already exists."
201+
raise RuntimeError(msg)
201202

202203
self.is_alive = True
203204
self.keep_alive_thread = threading.Thread(target=self._keep_alive_thread, daemon=False)
@@ -227,7 +228,7 @@ def _unlock(self) -> None:
227228
def __enter__(self) -> None:
228229
self._lock()
229230

230-
def __exit__(self, exc_type, exc_val, exc_tb) -> None: # noqa: ANN001
231+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
231232
self._unlock()
232233

233234

@@ -248,7 +249,8 @@ def _lock(self) -> None:
248249
# If this thread already holds a read lock, raise an exception.
249250
if self.snapshot.exists(self.has_lock):
250251
os.unlink(self.need_lock.path)
251-
raise RuntimeError("Thread already has a read lock. Do not try to obtain a read lock twice.")
252+
msg = "Thread already has a read lock. Do not try to obtain a read lock twice."
253+
raise RuntimeError(msg)
252254

253255
start_time = time.time()
254256

@@ -264,7 +266,8 @@ def _lock(self) -> None:
264266
return
265267
time.sleep(SLEEP_TIMEOUT)
266268
if time.time() - start_time > AQUIRE_LOCK_TIMEOUT:
267-
raise RuntimeError("Timeout while waiting for read lock.")
269+
msg = "Timeout while waiting for read lock."
270+
raise RuntimeError(msg)
268271
self.snapshot = FileLocksSnapshot(self.need_lock)
269272

270273

@@ -285,7 +288,8 @@ def _lock(self) -> None:
285288
# If this thread already holds a write lock, raise an exception.
286289
if self.snapshot.exists(self.has_lock):
287290
os.unlink(self.need_lock.path)
288-
raise RuntimeError("Thread already has a write lock. Do not try to obtain a write lock twice.")
291+
msg = "Thread already has a write lock. Do not try to obtain a write lock twice."
292+
raise RuntimeError(msg)
289293

290294
start_time = time.time()
291295

@@ -299,5 +303,6 @@ def _lock(self) -> None:
299303
return
300304
time.sleep(SLEEP_TIMEOUT)
301305
if time.time() - start_time > AQUIRE_LOCK_TIMEOUT:
302-
raise RuntimeError("Timeout while waiting for write lock.")
306+
msg = "Timeout while waiting for write lock."
307+
raise RuntimeError(msg)
303308
self.snapshot = FileLocksSnapshot(self.need_lock)

dictdatabase/models.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ def __init__(self, path: str, key: str, where: Callable) -> None:
3636
self.key = key is not None
3737

3838
if self.key and self.where:
39-
raise TypeError("Cannot specify both key and where")
39+
msg = "Cannot specify both key and where"
40+
raise TypeError(msg)
4041
if self.key and self.dir:
41-
raise TypeError("Cannot specify sub-key when selecting a folder. Specify the key in the path instead.")
42+
msg = "Cannot specify sub-key when selecting a folder. Specify the key in the path instead."
43+
raise TypeError(msg)
4244

4345
@property
4446
def file_normal(self) -> bool:
@@ -88,7 +90,7 @@ def at(*path, key: str | None = None, where: Callable[[Any, Any], bool] | None =
8890

8991

9092
class DDBMethodChooser:
91-
__slots__ = ("path", "key", "where", "op_type")
93+
__slots__ = ("key", "op_type", "path", "where")
9294

9395
path: str
9496
key: str
@@ -124,7 +126,8 @@ def exists(self) -> bool:
124126
As long it exists as a key in any dict, it will be found.
125127
"""
126128
if self.where is not None:
127-
raise RuntimeError("DDB.at(where=...).exists() cannot be used with the where parameter")
129+
msg = "DDB.at(where=...).exists() cannot be used with the where parameter"
130+
raise RuntimeError(msg)
128131

129132
if not utils.file_exists(self.path):
130133
return False
@@ -163,7 +166,8 @@ def delete(self) -> None:
163166
Delete the file at the selected path.
164167
"""
165168
if self.where is not None or self.key is not None:
166-
raise RuntimeError("DDB.at().delete() cannot be used with the where or key parameters")
169+
msg = "DDB.at().delete() cannot be used with the where or key parameters"
170+
raise RuntimeError(msg)
167171
io_safe.delete(self.path)
168172

169173
def read(self, as_type: Type[T] | None = None) -> dict | T | None:
@@ -180,7 +184,7 @@ def type_cast(value):
180184
return value
181185
return as_type(value)
182186

183-
data = {}
187+
data: dict = {}
184188

185189
if self.op_type.file_normal:
186190
data = io_safe.read(self.path)

dictdatabase/sessions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ def type_cast(obj, as_type):
1313
return obj if as_type is None else as_type(obj)
1414

1515

16-
class SessionBase:
16+
class SessionBase[T]:
1717
in_session: bool
1818
db_name: str
1919
as_type: T
2020

21-
def __init__(self, db_name: str, as_type):
21+
def __init__(self, db_name: str, as_type: T) -> None:
2222
self.in_session = False
2323
self.db_name = db_name
2424
self.as_type = as_type
@@ -27,7 +27,7 @@ def __enter__(self):
2727
self.in_session = True
2828
self.data_handle = {}
2929

30-
def __exit__(self, type, value, tb):
30+
def __exit__(self, type, value, tb) -> None:
3131
write_lock = getattr(self, "write_lock", None)
3232
if write_lock is not None:
3333
if isinstance(write_lock, list):

dictdatabase/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def seek_index_through_value_bytes(json_bytes: bytes, index: int) -> int:
8383
while True:
8484
i = json_bytes.find(byte_codes.QUOTE, i + 1)
8585
if i == -1:
86-
raise TypeError("Invalid JSON")
86+
msg = "Invalid JSON"
87+
raise TypeError(msg)
8788

8889
j = i - 1
8990
backslash_count = 0
@@ -173,9 +174,9 @@ def find_outermost_key_in_json_bytes(json_bytes: bytes, key: str) -> Tuple[int,
173174
# TODO: Very strict. the key must have a colon directly after it
174175
# For example {"a": 1} will work, but {"a" : 1} will not work!
175176

176-
key = f'"{key}":'.encode()
177+
key_bytes = f'"{key}":'.encode()
177178

178-
if (curr_i := json_bytes.find(key, 0)) == -1:
179+
if (curr_i := json_bytes.find(key_bytes, 0)) == -1:
179180
return (-1, -1)
180181

181182
# Assert: Key was found and curr_i is the index of the first character of the key
@@ -184,8 +185,8 @@ def find_outermost_key_in_json_bytes(json_bytes: bytes, key: str) -> Tuple[int,
184185
key_nest = [(curr_i, count_nesting_in_bytes(json_bytes, 0, curr_i))]
185186

186187
# As long as more keys are found, keep track of them and their nesting level
187-
while (next_i := json_bytes.find(key, curr_i + len(key))) != -1:
188-
nesting = count_nesting_in_bytes(json_bytes, curr_i + len(key), next_i)
188+
while (next_i := json_bytes.find(key_bytes, curr_i + len(key_bytes))) != -1:
189+
nesting = count_nesting_in_bytes(json_bytes, curr_i + len(key_bytes), next_i)
189190
key_nest.append((next_i, nesting))
190191
curr_i = next_i
191192

@@ -195,7 +196,7 @@ def find_outermost_key_in_json_bytes(json_bytes: bytes, key: str) -> Tuple[int,
195196
# Early exit if there is only one key
196197
if len(key_nest) == 1:
197198
index, level = key_nest[0]
198-
return (index, index + len(key)) if level == 1 else (-1, -1)
199+
return (index, index + len(key_bytes)) if level == 1 else (-1, -1)
199200

200201
# Relative to total nesting
201202
for i in range(1, len(key_nest)):
@@ -205,7 +206,7 @@ def find_outermost_key_in_json_bytes(json_bytes: bytes, key: str) -> Tuple[int,
205206
indices_at_index_one = [i for i, level in key_nest if level == 1]
206207
if len(indices_at_index_one) != 1:
207208
return (-1, -1)
208-
return (indices_at_index_one[0], indices_at_index_one[0] + len(key))
209+
return (indices_at_index_one[0], indices_at_index_one[0] + len(key_bytes))
209210

210211

211212
def detect_indentation_in_json_bytes(json_bytes: bytes, index: int) -> Tuple[int, str]:

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,15 @@ ignore = [
5757
"ANN201", # Missing return type annotation for public function
5858
"ARG001", # Unused function argument
5959
"T201", # `print` found
60+
"ERA001", # Found commented-out code
6061
]
6162

6263

6364
[tool.ruff.format]
6465
indent-style = "tab"
6566
quote-style = "double"
67+
68+
69+
70+
[tool.mypy]
71+
ignore_missing_imports = true

tests/benchmark/run_parallel.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import json
24
import os
35
import random
@@ -6,6 +8,7 @@
68
from calendar import c
79
from dataclasses import dataclass
810
from multiprocessing import Pool
11+
from typing import Callable
912

1013
from path_dict import PathDict
1114

@@ -14,7 +17,7 @@
1417
DDB.config.storage_directory = ".ddb_bench_multi"
1518

1619

17-
def benchmark(iterations, setup: callable = None):
20+
def benchmark(iterations, setup: Callable | None = None):
1821
def decorator(function):
1922
def wrapper(*args, **kwargs):
2023
f_name = function.__name__

tests/system_checks/test_monotonic_over_threads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
}
1818

1919
# Queue to store timestamps in order
20-
timestamps = queue.Queue()
20+
timestamps: queue.Queue = queue.Queue()
2121

2222

2323
def capture_time(i, clock_func: Callable) -> None:

0 commit comments

Comments
 (0)