Skip to content
Snippets Groups Projects

Async guard

Merged Mateusz Żebrak requested to merge mzebrak/async-guard into develop
4 files
+ 52
29
Compare changes
  • Side-by-side
  • Inline
Files
4
from __future__ import annotations
from __future__ import annotations
import asyncio
import contextlib
import contextlib
from typing import TYPE_CHECKING, Final, Generator
from typing import AsyncGenerator, Final
from clive.exceptions import CliveError
from textual.rlock import RLock
if TYPE_CHECKING:
from clive.exceptions import CliveError
from types import TracebackType
class AsyncGuardNotAvailableError(CliveError):
class AsyncGuardNotAvailableError(CliveError):
@@ -21,33 +19,36 @@ class AsyncGuardNotAvailableError(CliveError):
@@ -21,33 +19,36 @@ class AsyncGuardNotAvailableError(CliveError):
class AsyncGuard:
class AsyncGuard:
"""
"""
A helper class to manage an asynchronous event-like lock, ensuring exclusive execution.
Lock-like guard, ensuring exclusive execution.
Use this for scenarios where you want to prevent concurrent execution of an async task.
Use this for scenarios where you want to prevent concurrent execution of an async task.
When the guard is acquired by some other task, the guarded block could not execute, error will be raised instead.
When the guard is acquired by some other task, the guarded block won't be executed.
 
Can be used to wait until the guard is available again or for skipping code execution.
Can be used together with `suppress`. Look into its documentation for more details.
Can be used together with `suppress`. Look into its documentation for more details.
Usage:
Usage:
```
```
async_guard = AsyncGuard()
async_guard = AsyncGuard()
with async_guard:
async with async_guard.guard():
# Protected code which shouldn't be executed concurrently
# Protected code which shouldn't be executed concurrently
```
# Will run later, when the guard is available again
"""
def __init__(self) -> None:
async with async_guard.guard_or_error():
self._event = asyncio.Event()
# Protected code which shouldn't be executed concurrently
 
# Will raise an error if the guard is already acquired
def __enter__(self) -> None:
async with async_guard.suppress(), async_guard.guard_or_error():
if not self.is_available:
# Code that should be skipped when guard is acquired, but also should acquire the guard if available
raise AsyncGuardNotAvailableError
self.acquire()
if async_guard.is_available:
 
# Code that should be skipped when guard is acquired
 
```
 
"""
def __exit__(self, _: type[BaseException] | None, ex: BaseException | None, ___: TracebackType | None) -> None:
def __init__(self) -> None:
self.release()
self._lock = RLock()
@property
@property
def is_available(self) -> bool:
def is_available(self) -> bool:
@@ -56,17 +57,39 @@ class AsyncGuard:
@@ -56,17 +57,39 @@ class AsyncGuard:
Use this to determine if an instruction can proceed with no conflicts.
Use this to determine if an instruction can proceed with no conflicts.
"""
"""
return not self._event.is_set()
return not self._lock.is_locked
 
 
async def acquire(self) -> None:
 
await self._lock.acquire()
 
 
async def acquire_or_error(self) -> None:
 
if not self.is_available:
 
raise AsyncGuardNotAvailableError
def acquire(self) -> None:
await self.acquire()
self._event.set()
def release(self) -> None:
def release(self) -> None:
self._event.clear()
self._lock.release()
 
 
@contextlib.asynccontextmanager
 
async def guard(self) -> AsyncGenerator[None]:
 
await self.acquire()
 
try:
 
yield
 
finally:
 
self.release()
 
 
@contextlib.asynccontextmanager
 
async def guard_or_error(self) -> AsyncGenerator[None]:
 
await self.acquire_or_error()
 
try:
 
yield
 
finally:
 
self.release()
@staticmethod
@staticmethod
@contextlib.contextmanager
@contextlib.asynccontextmanager
def suppress() -> Generator[None]:
async def suppress() -> AsyncGenerator[None]:
"""
"""
Suppresses the AsyncGuardNotAvailable error raised by the guard.
Suppresses the AsyncGuardNotAvailable error raised by the guard.
@@ -76,7 +99,7 @@ class AsyncGuard:
@@ -76,7 +99,7 @@ class AsyncGuard:
```
```
async_guard = AsyncGuard()
async_guard = AsyncGuard()
with async_guard.suppress(), async_guard:
async with async_guard.suppress(), async_guard.guard_or_error():
# Code that should be skipped when guard is acquired
# Code that should be skipped when guard is acquired
```
```
"""
"""
Loading