Skip to content
Snippets Groups Projects

Async guard

Merged Mateusz Żebrak requested to merge mzebrak/async-guard into develop
4 files
+ 52
29
Compare changes
  • Side-by-side
  • Inline
Files
4
from __future__ import annotations
import asyncio
import contextlib
from typing import TYPE_CHECKING, Final, Generator
from typing import AsyncGenerator, Final
from clive.exceptions import CliveError
from textual.rlock import RLock
if TYPE_CHECKING:
from types import TracebackType
from clive.exceptions import CliveError
class AsyncGuardNotAvailableError(CliveError):
@@ -21,33 +19,36 @@ class AsyncGuardNotAvailableError(CliveError):
class AsyncGuard:
"""
A helper class to manage an asynchronous event-like lock, ensuring exclusive execution.
Lock-like guard, ensuring exclusive execution.
Use this for scenarios where you want to prevent concurrent execution of an async task.
When the guard is acquired by some other task, the guarded block could not execute, error will be raised instead.
When the guard is acquired by some other task, the guarded block won't be executed.
Can be used to wait until the guard is available again or for skipping code execution.
Can be used together with `suppress`. Look into its documentation for more details.
Usage:
```
async_guard = AsyncGuard()
with async_guard:
async with async_guard.guard():
# Protected code which shouldn't be executed concurrently
```
"""
# Will run later, when the guard is available again
def __init__(self) -> None:
self._event = asyncio.Event()
async with async_guard.guard_or_error():
# Protected code which shouldn't be executed concurrently
# Will raise an error if the guard is already acquired
def __enter__(self) -> None:
if not self.is_available:
raise AsyncGuardNotAvailableError
async with async_guard.suppress(), async_guard.guard_or_error():
# Code that should be skipped when guard is acquired, but also should acquire the guard if available
self.acquire()
if async_guard.is_available:
# Code that should be skipped when guard is acquired
```
"""
def __exit__(self, _: type[BaseException] | None, ex: BaseException | None, ___: TracebackType | None) -> None:
self.release()
def __init__(self) -> None:
self._lock = RLock()
@property
def is_available(self) -> bool:
@@ -56,17 +57,39 @@ class AsyncGuard:
Use this to determine if an instruction can proceed with no conflicts.
"""
return not self._event.is_set()
return not self._lock.is_locked
async def acquire(self) -> None:
await self._lock.acquire()
async def acquire_or_error(self) -> None:
if not self.is_available:
raise AsyncGuardNotAvailableError
def acquire(self) -> None:
self._event.set()
await self.acquire()
def release(self) -> None:
self._event.clear()
self._lock.release()
@contextlib.asynccontextmanager
async def guard(self) -> AsyncGenerator[None]:
await self.acquire()
try:
yield
finally:
self.release()
@contextlib.asynccontextmanager
async def guard_or_error(self) -> AsyncGenerator[None]:
await self.acquire_or_error()
try:
yield
finally:
self.release()
@staticmethod
@contextlib.contextmanager
def suppress() -> Generator[None]:
@contextlib.asynccontextmanager
async def suppress() -> AsyncGenerator[None]:
"""
Suppresses the AsyncGuardNotAvailable error raised by the guard.
@@ -76,7 +99,7 @@ class AsyncGuard:
```
async_guard = AsyncGuard()
with async_guard.suppress(), async_guard:
async with async_guard.suppress(), async_guard.guard_or_error():
# Code that should be skipped when guard is acquired
```
"""
Loading