Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions aiorwlock/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import asyncio
import threading
from collections import deque
from typing import Any, Deque, List, Tuple
from typing import Any, Deque, List, Tuple, Optional

__all__ = ('RWLock', '__version__')
__all__ = ("RWLock", "__version__")


def __getattr__(name: str) -> object:
Expand All @@ -28,16 +28,31 @@ def __getattr__(name: str) -> object:
class _RWLockCore:
_RL = 1
_WL = 2
_loop = None

def __init__(self, fast: bool):


__slots__ = (
"_do_yield",
"_read_waiters",
"_write_waiters",
"_r_state",
"_w_state",
"_owning",
"_loop"
)

def __init__(self, fast: bool) -> None:
self._do_yield = not fast
self._read_waiters: Deque[asyncio.Future[None]] = deque()
self._write_waiters: Deque[asyncio.Future[None]] = deque()
self._r_state: int = 0
self._w_state: int = 0
# tasks will be few, so a list is not inefficient
self._owning: List[Tuple[asyncio.Task[Any], int]] = []
self._loop: Optional[asyncio.AbstractEventLoop] = None

# TODO: There is a Bug when different Loops are in use with using RWLocks
# However this might have to do with version differences
# SEE: https://github.com/aio-libs/aiorwlock/issues/468

def _get_loop(self) -> asyncio.AbstractEventLoop:
"""
Expand All @@ -51,7 +66,7 @@ def _get_loop(self) -> asyncio.AbstractEventLoop:
if self._loop is None:
self._loop = loop
if loop is not self._loop:
raise RuntimeError(f'{self!r} is bound to a different event loop')
raise RuntimeError(f"{self!r} is bound to a different event loop")
return loop

@property
Expand Down Expand Up @@ -82,11 +97,7 @@ async def acquire_read(self) -> bool:
await self._yield_after_acquire(self._RL)
return True

if (
not self._write_waiters
and self._r_state >= 0
and self._w_state == 0
):
if not self._write_waiters and self._r_state >= 0 and self._w_state == 0:
self._r_state += 1
self._owning.append((me, self._RL))
await self._yield_after_acquire(self._RL)
Expand Down Expand Up @@ -120,7 +131,7 @@ async def acquire_write(self) -> bool:
return True
elif (me, self._RL) in self._owning:
if self._r_state > 0:
raise RuntimeError('Cannot upgrade RWLock from read to write')
raise RuntimeError("Cannot upgrade RWLock from read to write")

if self._r_state == 0 and self._w_state == 0:
self._w_state += 1
Expand Down Expand Up @@ -157,7 +168,7 @@ def _release(self, lock_type: int) -> None:
try:
self._owning.remove((me, lock_type))
except ValueError as exc:
raise RuntimeError('Cannot release an un-acquired lock') from exc
raise RuntimeError("Cannot release an un-acquired lock") from exc
if lock_type == self._RL:
self._r_state -= 1
else:
Expand Down Expand Up @@ -186,10 +197,11 @@ def _wake_up(self) -> None:


class _ContextManagerMixin:
__slots__ = ()

def __enter__(self) -> None:
raise RuntimeError(
'"await" should be used as context manager expression'
)
# TODO: This Error Should really hint at either "async with" or "await" instead of just "await"
raise RuntimeError('"await" should be used as context manager expression')

def __exit__(self, *args: Any) -> None:
# This must exist because __enter__ exists, even though that
Expand All @@ -214,41 +226,51 @@ def release(self) -> None:

# Lock objects to access the _RWLockCore in reader or writer mode
class _ReaderLock(_ContextManagerMixin):
__slots__ = ("_lock",)

def __init__(self, lock: _RWLockCore) -> None:
self._lock = lock

@property
def locked(self) -> bool:
"""Determines wether or not the reader lock is owned"""
return self._lock.read_locked

async def acquire(self) -> None:
"""acquires the reading lock"""
await self._lock.acquire_read()

def release(self) -> None:
"""releases the reading lock"""
self._lock.release_read()

def __repr__(self) -> str:
status = 'locked' if self._lock._r_state > 0 else 'unlocked'
return f'<ReaderLock: [{status}]>'
status = "locked" if self._lock._r_state > 0 else "unlocked"
return f"<ReaderLock: [{status}]>"


class _WriterLock(_ContextManagerMixin):
def __init__(self, lock: _RWLockCore):
__slots__ = ("_lock",)

def __init__(self, lock: _RWLockCore) -> None:
self._lock = lock

@property
def locked(self) -> bool:
"""determines wether or not writing lock is owned"""
return self._lock.write_locked

async def acquire(self) -> None:
"""acquires the writing lock"""
await self._lock.acquire_write()

def release(self) -> None:
"""releases the writing lock"""
self._lock.release_write()

def __repr__(self) -> str:
status = 'locked' if self._lock._w_state > 0 else 'unlocked'
return f'<WriterLock: [{status}]>'
status = "locked" if self._lock._w_state > 0 else "unlocked"
return f"<WriterLock: [{status}]>"


class RWLock:
Expand All @@ -260,6 +282,8 @@ class RWLock:

core = _RWLockCore

__slots__ = ("_reader_lock", "_writer_lock",)

def __init__(self, *, fast: bool = False) -> None:
core = self.core(fast)
self._reader_lock = _ReaderLock(core)
Expand All @@ -282,4 +306,4 @@ def writer(self) -> _WriterLock:
def __repr__(self) -> str:
rl = self.reader_lock.__repr__()
wl = self.writer_lock.__repr__()
return f'<RWLock: {rl} {wl}>'
return f"<RWLock: {rl} {wl}>"
6 changes: 3 additions & 3 deletions examples/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
import aiorwlock


async def go():
async def go() -> None:
rwlock = aiorwlock.RWLock()

# acquire reader lock
async with rwlock.reader_lock:
print('inside reader lock')
print("inside reader lock")
await asyncio.sleep(0.1)

# acquire writer lock
async with rwlock.writer_lock:
print('inside writer lock')
print("inside writer lock")
await asyncio.sleep(0.1)


Expand Down
6 changes: 3 additions & 3 deletions examples/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import aiorwlock


async def go():
async def go() -> None:
rwlock = aiorwlock.RWLock()

# acquire reader lock
await rwlock.reader_lock.acquire()
try:
print('inside reader lock')
print("inside reader lock")

await asyncio.sleep(0.1)
finally:
Expand All @@ -18,7 +18,7 @@ async def go():
# acquire writer lock
await rwlock.writer_lock.acquire()
try:
print('inside writer lock')
print("inside writer lock")

await asyncio.sleep(0.1)
finally:
Expand Down
12 changes: 5 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import asyncio

import pytest
import pytest_asyncio


@pytest.fixture(scope='module', params=[True, False], ids=['fast', 'slow'])
def fast_track(request):
@pytest.fixture(scope="module", params=[True, False], ids=["fast", "slow"])
def fast_track(request: pytest.FixtureRequest) -> bool:
return request.param


@pytest.fixture
def loop(event_loop):
return event_loop
46 changes: 25 additions & 21 deletions tests/test_corner_cases.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import contextlib
from typing import Any, Generator

import pytest

Expand All @@ -9,7 +10,8 @@


@contextlib.contextmanager
def should_fail(timeout, loop):
def should_fail(timeout: float) -> Generator[None, Any, None]:
loop = asyncio.get_running_loop()
task = asyncio.current_task(loop)

handle = loop.call_later(timeout, task.cancel)
Expand All @@ -19,12 +21,12 @@ def should_fail(timeout, loop):
handle.cancel()
return
else:
msg = f'Inner task expected to be cancelled: {task}'
msg = f"Inner task expected to be cancelled: {task}"
pytest.fail(msg)


@pytest.mark.asyncio
async def test_get_write_then_read(loop):
async def test_get_write_then_read() -> None:
rwlock = RWLock()

rl = rwlock.reader
Expand All @@ -39,11 +41,11 @@ async def test_get_write_then_read(loop):


@pytest.mark.asyncio
async def test_get_write_then_read_and_write_again(loop):
async def test_get_write_then_read_and_write_again() -> None:
loop = asyncio.get_event_loop()
rwlock = RWLock()
rl = rwlock.reader
wl = rwlock.writer

f = loop.create_future()
writes = []

Expand All @@ -52,7 +54,7 @@ async def get_write_lock():
with should_fail(0.1, loop):
async with wl:
assert wl.locked
writes.append('should not be here')
writes.append("should not be here")

ensure_future(get_write_lock())

Expand All @@ -68,7 +70,8 @@ async def get_write_lock():


@pytest.mark.asyncio
async def test_writers_deadlock(loop):
async def test_writers_deadlock() -> None:
loop = asyncio.get_event_loop()
rwlock = RWLock()
rl = rwlock.reader
wl = rwlock.writer
Expand All @@ -84,7 +87,7 @@ async def test_writers_deadlock(loop):
# See asyncio.Lock deadlock issue:
# https://github.com/python/cpython/pull/1031

async def coro():
async def coro() -> None:
async with wl:
assert wl.locked
await asyncio.sleep(0.2, loop)
Expand All @@ -106,12 +109,13 @@ async def coro():


@pytest.mark.asyncio
async def test_readers_cancel(loop):
async def test_readers_cancel() -> None:
loop = asyncio.get_event_loop()
rwlock = RWLock()
rl = rwlock.reader
wl = rwlock.writer

async def coro(lock):
async def coro(lock: RWLock):
async with lock:
assert lock.locked
await asyncio.sleep(0.2, loop)
Expand All @@ -132,11 +136,11 @@ async def coro(lock):


@pytest.mark.asyncio
async def test_canceled_inside_acquire(loop):
async def test_canceled_inside_acquire() -> None:
rwlock = RWLock()
rl = rwlock.reader

async def coro(lock):
async def coro(lock: RWLock):
async with lock:
pass

Expand All @@ -153,24 +157,24 @@ async def coro(lock):


@pytest.mark.asyncio
async def test_race_multiple_writers(loop):
async def test_race_multiple_writers() -> None:
seq = []

async def write_wait(lock):
async def write_wait(lock: RWLock):
async with lock.reader:
await asyncio.sleep(0.1)
seq.append('READ')
seq.append("READ")
async with lock.writer:
seq.append('START1')
seq.append("START1")
await asyncio.sleep(0.1)
seq.append('FIN1')
seq.append("FIN1")

async def write(lock):
async def write(lock: RWLock):
async with lock.writer:
seq.append('START2')
seq.append("START2")
await asyncio.sleep(0.1)
seq.append('FIN2')
seq.append("FIN2")

lock = RWLock(fast=True)
await asyncio.gather(write_wait(lock), write(lock))
assert seq == ['READ', 'START2', 'FIN2', 'START1', 'FIN1']
assert seq == ["READ", "START2", "FIN2", "START1", "FIN1"]
Loading