diff --git a/beekeepy/beekeepy/_communication/abc/overseer.py b/beekeepy/beekeepy/_communication/abc/overseer.py index 1dbf7548b44d20e8c9f4e6bad301b755c842d51e..b47b9c23fce307664f2b2673fe25a0716724f7b0 100644 --- a/beekeepy/beekeepy/_communication/abc/overseer.py +++ b/beekeepy/beekeepy/_communication/abc/overseer.py @@ -4,9 +4,9 @@ import asyncio import json import time from abc import ABC, abstractmethod -from enum import IntEnum from typing import TYPE_CHECKING, Any, Callable, ClassVar, Sequence +from beekeepy._communication.abc.rules import ContinueMode from beekeepy._utilities.context import SelfContextSync from beekeepy.exceptions import GroupedErrorsError, Json, UnknownDecisionPathError @@ -23,12 +23,6 @@ if TYPE_CHECKING: __all__ = ["AbstractOverseer"] -class ContinueMode(IntEnum): - BREAK = 0 - CONTINUE = 1 - INF = 2 - - class _OverseerExceptionManager(SelfContextSync): """This class should be considered as part of AbstractOverseer. @@ -42,6 +36,7 @@ class _OverseerExceptionManager(SelfContextSync): super().__init__() self._owner = owner self._rules = rules + self._exception_rules = rules.grouped_exceptions() self._exceptions: Sequence[OverseerError] = [] self._last_status = ContinueMode.INF self._counter = 0 @@ -85,7 +80,17 @@ class _OverseerExceptionManager(SelfContextSync): return self.CONTINUE_LOOP def should_sleep(self) -> bool: - return len(self._exceptions) > 0 + if len(self._exceptions) == 0: + # No exceptions, no need to sleep + return False + for exception in self._exceptions: + if type(exception) in self._exception_rules.preliminary: + # Preliminary exception, no need to sleep + return False + if self._counter <= 0 and type(exception) in self._exception_rules.finitely_repeatable: + # Finitely repeatable exception, but no retries left, no need to sleep + return False + return True def _finally(self) -> None: if not self._response_read_or_exception_occurred: @@ -154,21 +159,17 @@ class AbstractOverseer(ABC): ) -> tuple[list[OverseerError], ContinueMode]: exceptions: list[OverseerError] = [] - for rules_category, status in ( - (rules.infinitely_repeatable, ContinueMode.INF), - (rules.preliminary, ContinueMode.BREAK), - (rules.finitely_repeatable, ContinueMode.CONTINUE), - ): - for rule in rules_category: - exceptions_to_add = rule.check(response=response, response_raw=response_raw) - exceptions.extend(exceptions_to_add) - for ex in exceptions_to_add: - if not ex.retry(): - return (exceptions, status) - - if bool(exceptions): + for rule, status in rules.resolved_rules(): + exceptions_to_add = rule.check(response=response, response_raw=response_raw) + exceptions.extend(exceptions_to_add) + for ex in exceptions_to_add: + if not ex.retry(): return (exceptions, status) - return ([], ContinueMode.CONTINUE) + + if bool(exceptions): + return (exceptions, status) + + return (exceptions, ContinueMode.CONTINUE) def _parse(self, response: str) -> Json | list[Json] | Exception: response_parsed: Json | list[Json] | None = None diff --git a/beekeepy/beekeepy/_communication/abc/rules.py b/beekeepy/beekeepy/_communication/abc/rules.py index 5ee395b2f19b18ad1950f62807ce1e7728fc138f..eebb1ad41915e449098bf52898360a6bc514be2a 100644 --- a/beekeepy/beekeepy/_communication/abc/rules.py +++ b/beekeepy/beekeepy/_communication/abc/rules.py @@ -2,13 +2,20 @@ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Sequence +from enum import IntEnum +from typing import TYPE_CHECKING, Iterator, Sequence if TYPE_CHECKING: from beekeepy._communication.url import HttpUrl from beekeepy.exceptions import Json, OverseerError +class ContinueMode(IntEnum): + BREAK = 0 + CONTINUE = 1 + INF = 2 + + @dataclass(kw_only=True) class RulesClassifier: preliminary: Sequence[type[OverseerRule]] @@ -29,10 +36,47 @@ class RulesClassifier: @dataclass(kw_only=True) +class RulesExceptions: + preliminary: Sequence[type[OverseerError]] + infinitely_repeatable: Sequence[type[OverseerError]] + finitely_repeatable: Sequence[type[OverseerError]] + + +@dataclass(kw_only=True, frozen=True) class Rules: preliminary: Sequence[OverseerRule] + """If rules in this category detect error, retries won't be attempted.""" + infinitely_repeatable: Sequence[OverseerRule] + """If rules in this category detect error, retries will be attempted indefinitely.""" + finitely_repeatable: Sequence[OverseerRule] + """If rules in this category detect error, retries will be attempted up to settings.max_retries times.""" + + def resolved_rules(self) -> Iterator[tuple[OverseerRule, ContinueMode]]: + """Yields all rules in proper order with their associated ContinueMode.""" + from beekeepy._communication.rules import ErrorInResponse + + fallback_rule: tuple[OverseerRule, ContinueMode] | None = None + for rule, mode in ( + *((rule, ContinueMode.INF) for rule in self.infinitely_repeatable), + *((rule, ContinueMode.BREAK) for rule in self.preliminary), + *((rule, ContinueMode.CONTINUE) for rule in self.finitely_repeatable), + ): + if type(rule) is ErrorInResponse: + fallback_rule = (rule, mode) + continue + yield rule, mode + + assert fallback_rule is not None, "ErrorInResponse rule was not found among rules" + yield fallback_rule + + def grouped_exceptions(self) -> RulesExceptions: + return RulesExceptions( + preliminary=[rule.expected_exception() for rule in self.preliminary], + infinitely_repeatable=[rule.expected_exception() for rule in self.infinitely_repeatable], + finitely_repeatable=[rule.expected_exception() for rule in self.finitely_repeatable], + ) class OverseerRule(ABC): @@ -76,15 +120,19 @@ class OverseerRule(ABC): assert isinstance(self.request, dict), f"self.request is not a dict, nor list, but is `{type(self.request)}`" return self.request + @classmethod + @abstractmethod + def expected_exception(cls) -> type[OverseerError]: + """Overload this method to specify which exception should be raised in case of detection.""" + def _construct_exception( self, - error_cls: type[OverseerError], response: Json | list[Json] | Exception, whole_response: Json | list[Json], request_id: int | None, message: str = "", ) -> OverseerError: - return error_cls( + return self.expected_exception()( url=self.url, request=self.request, request_id=request_id, diff --git a/beekeepy/beekeepy/_communication/overseers.py b/beekeepy/beekeepy/_communication/overseers.py index bb7eab6022c7534282139cc5c0f4fd3b5e12df15..8ce5a5abd548b8ceb3608cb020fb570e1b9c3299 100644 --- a/beekeepy/beekeepy/_communication/overseers.py +++ b/beekeepy/beekeepy/_communication/overseers.py @@ -31,6 +31,7 @@ class CommonOverseer(AbstractOverseer): UnableToOpenWallet, InvalidPassword, UnlockIsNotAccessible, + ErrorInResponse, ], infinitely_repeatable=[ UnableToAcquireDatabaseLock, @@ -41,7 +42,6 @@ class CommonOverseer(AbstractOverseer): JussiResponse, DifferenceBetweenAmountOfRequestsAndResponses, NullResult, - ErrorInResponse, ], ) diff --git a/beekeepy/beekeepy/_communication/rules.py b/beekeepy/beekeepy/_communication/rules.py index 4601c37f0f1497dc0e83185448aaeeee99f7c548..a01c81e94af5da4be13590aa4a708205aaea58e5 100644 --- a/beekeepy/beekeepy/_communication/rules.py +++ b/beekeepy/beekeepy/_communication/rules.py @@ -37,7 +37,6 @@ class UnableToAcquireDatabaseLock(OverseerRule): if self.LOOKUP_MESSAGE in str(parsed_response): return [ self._construct_exception( - error_cls=UnableToAcquireDatabaseLockError, request_id=parsed_response.get("id"), response=parsed_response, message=f"Found `{self.LOOKUP_MESSAGE}` in response", @@ -46,6 +45,10 @@ class UnableToAcquireDatabaseLock(OverseerRule): ] return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return UnableToAcquireDatabaseLockError + class UnableToAcquireForkdbLock(OverseerRule): LOOKUP_MESSAGE: ClassVar[str] = "Unable to acquire forkdb lock" @@ -54,7 +57,6 @@ class UnableToAcquireForkdbLock(OverseerRule): if self.LOOKUP_MESSAGE in str(parsed_response): return [ self._construct_exception( - error_cls=UnableToAcquireForkdbLockError, message=f"Found `{self.LOOKUP_MESSAGE}` in response", response=parsed_response, request_id=parsed_response.get("id"), @@ -63,6 +65,10 @@ class UnableToAcquireForkdbLock(OverseerRule): ] return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return UnableToAcquireForkdbLockError + class NullResult(OverseerRule): def _check_single(self, parsed_response: Json, whole_response: Json | list[Json]) -> list[OverseerError]: @@ -73,7 +79,6 @@ class NullResult(OverseerRule): return [ self._construct_exception( - error_cls=NullResultError, message="`result` field in response is null", response=parsed_response, request_id=request_id, @@ -91,6 +96,10 @@ class NullResult(OverseerRule): "condenser_api.get_escrow", ] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return NullResultError + class ApiNotFound(OverseerRule): _API_NOT_FOUND_REGEX: ClassVar[re.Pattern[str]] = re.compile( @@ -102,7 +111,6 @@ class ApiNotFound(OverseerRule): if search_result is not None: return [ self._construct_exception( - error_cls=ApiNotFoundError, message=f"Requested api not found: {search_result.group(1)}", response=parsed_response, request_id=parsed_response.get("id"), @@ -111,13 +119,16 @@ class ApiNotFound(OverseerRule): ] return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return ApiNotFoundError + class JussiResponse(OverseerRule): def _check_single(self, parsed_response: Json, whole_response: Json | list[Json]) -> list[OverseerError]: if "jussi_request_id" in str(parsed_response): return [ self._construct_exception( - error_cls=JussiResponseError, message="Jussi responded instead of target service", response=parsed_response, request_id=parsed_response.get("id"), @@ -126,13 +137,16 @@ class JussiResponse(OverseerRule): ] return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return JussiResponseError + class UnparsableResponse(OverseerRule): def _check_non_json_response(self, parsed_response: Exception, response_raw: str) -> list[OverseerError]: if isinstance(parsed_response, json.JSONDecodeError): return [ self._construct_exception( - error_cls=UnparsableResponseError, message=( "Received response is not parsable, " f"probably plaintext or invalid json: {response_raw}" ), @@ -150,12 +164,15 @@ class UnparsableResponse(OverseerRule): ) -> list[OverseerError]: return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return UnparsableResponseError + class DifferenceBetweenAmountOfRequestsAndResponses(OverseerRule): def _check_batch(self, parsed_response: list[Json]) -> list[OverseerError]: def exception_factory(msg: str) -> OverseerError: return self._construct_exception( - error_cls=DifferenceBetweenAmountOfRequestsAndResponsesError, message=msg, response=parsed_response, whole_response=parsed_response, @@ -184,13 +201,16 @@ class DifferenceBetweenAmountOfRequestsAndResponses(OverseerRule): def _check_single(self, parsed_response: Json, whole_response: Json | list[Json]) -> list[OverseerError]: # noqa: ARG002 return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return DifferenceBetweenAmountOfRequestsAndResponsesError + class ErrorInResponse(OverseerRule): def _check_single(self, parsed_response: Json, whole_response: Json | list[Json]) -> list[OverseerError]: if (error := parsed_response.get("error")) is not None: return [ self._construct_exception( - error_cls=ErrorInResponseError, message=f"Error found in response: {error=}", response=parsed_response, request_id=parsed_response.get("id"), @@ -199,13 +219,16 @@ class ErrorInResponse(OverseerRule): ] return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return ErrorInResponseError + class UnlockIsNotAccessible(OverseerRule): def _check_single(self, parsed_response: Json, whole_response: Json | list[Json]) -> list[OverseerError]: if "unlock is not accessible" in str(parsed_response): return [ self._construct_exception( - error_cls=UnlockIsNotAccessibleError, message="You tried to unlock wallet too fast", response=parsed_response, request_id=parsed_response.get("id"), @@ -214,6 +237,10 @@ class UnlockIsNotAccessible(OverseerRule): ] return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return UnlockIsNotAccessibleError + class WalletIsAlreadyUnlocked(OverseerRule): _WALLET_IS_ALREADY_UNLOCKED_REGEXES: ClassVar[list[re.Pattern[str]]] = [ @@ -226,7 +253,6 @@ class WalletIsAlreadyUnlocked(OverseerRule): if (match := regex.search(str(parsed_response))) is not None: return [ self._construct_exception( - error_cls=WalletIsAlreadyUnlockedError, message=f"You tried to unlock already unlocked wallet: `{match.group(1)}`", response=parsed_response, request_id=parsed_response.get("id"), @@ -235,6 +261,10 @@ class WalletIsAlreadyUnlocked(OverseerRule): ] return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return WalletIsAlreadyUnlockedError + class UnableToOpenWallet(OverseerRule): _UNABLE_TO_OPEN_WALLET_REGEX: ClassVar[re.Pattern[str]] = re.compile( @@ -246,7 +276,6 @@ class UnableToOpenWallet(OverseerRule): if (match := self._UNABLE_TO_OPEN_WALLET_REGEX.search(str(parsed_response))) is not None: return [ self._construct_exception( - error_cls=UnableToOpenWalletError, message=f"No such wallet: {match.group(1)}", response=parsed_response, request_id=parsed_response.get("id"), @@ -255,6 +284,10 @@ class UnableToOpenWallet(OverseerRule): ] return [] + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return UnableToOpenWalletError + class InvalidPassword(OverseerRule): _INVALID_PASSWORD_REGEX: ClassVar[re.Pattern[str]] = re.compile( @@ -267,7 +300,6 @@ class InvalidPassword(OverseerRule): ) is not None: return [ self._construct_exception( - error_cls=OverseerInvalidPasswordError, message=f"Invalid password for wallet: {match.group(1)}", response=parsed_response, request_id=parsed_response.get("id"), @@ -275,3 +307,7 @@ class InvalidPassword(OverseerRule): ) ] return [] + + @classmethod + def expected_exception(cls) -> type[OverseerError]: + return OverseerInvalidPasswordError diff --git a/tests/beekeepy_test/communicator/test_overseer.py b/tests/beekeepy_test/communicator/test_overseer.py new file mode 100644 index 0000000000000000000000000000000000000000..014ac831a71cb8303aaa1f45ee4411fad97da25e --- /dev/null +++ b/tests/beekeepy_test/communicator/test_overseer.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Final + +import pytest +from local_tools.beekeepy.testing_server import run_simple_server + +from beekeepy.communication import ( + CommonOverseer, + CommunicationSettings, + StrictOverseer, + get_communicator_cls, +) +from beekeepy.exceptions import ( + ApiNotFoundError, + JussiResponseError, + NullResultError, + OverseerError, + UnparsableResponseError, +) + +if TYPE_CHECKING: + from beekeepy.communication import AbstractCommunicator, AbstractOverseer + + +ERRORS_TO_DETECT: Final[list[tuple[type[OverseerError], str]]] = [ + ( + NullResultError, + """{"jsonrpc": "2.0", "result": null, "id": 1}""", + ), + ( + ApiNotFoundError, + """{"jsonrpc": "2.0", "error": {"code": -32003, "message": + "Assert Exception:api_itr != data._registered_apis.end(): Could not find API debug_node_api" + }, "id": 1}""", + ), + ( + JussiResponseError, + """{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message": + "Internal Error","data":{"error_id":"b6384d8c-95ad-4af0-92dc-dd7828d3c707", + "jussi_request_id":"000312363819934224"}}}""", + ), + (UnparsableResponseError, """404: Not Found"""), +] + +SYNC_COMMUNICATORS: Final[list[type[AbstractCommunicator]]] = [ + get_communicator_cls("request"), + get_communicator_cls("httpx"), +] +ASYNC_COMMUNICATORS: Final[list[type[AbstractCommunicator]]] = [ + get_communicator_cls("aiohttp"), + get_communicator_cls("httpx"), +] +OVERSEERS: Final[list[type[AbstractOverseer]]] = [CommonOverseer, StrictOverseer] + +REQUEST: Final[str] = """{"method": "aaa", "id": 1, "jsonrpc": "2.0"}""" + + +@pytest.mark.parametrize("error_and_message", ERRORS_TO_DETECT) +@pytest.mark.parametrize("overseer_cls", OVERSEERS) +@pytest.mark.parametrize("communicator", SYNC_COMMUNICATORS) +def test_sync_overseer( + error_and_message: tuple[type[OverseerError], str], + overseer_cls: type[AbstractOverseer], + communicator: type[AbstractCommunicator], +) -> None: + error, message = error_and_message + overseer = overseer_cls(communicator=communicator(settings=CommunicationSettings())) + try: + with run_simple_server(message) as url, pytest.raises(error): + overseer.send(url=url, method="POST", data=REQUEST) + finally: + overseer.teardown() + + +@pytest.mark.parametrize("error_and_message", ERRORS_TO_DETECT) +@pytest.mark.parametrize("overseer_cls", OVERSEERS) +@pytest.mark.parametrize("communicator", ASYNC_COMMUNICATORS) +async def test_async_overseer( + error_and_message: tuple[type[OverseerError], str], + overseer_cls: type[AbstractOverseer], + communicator: type[AbstractCommunicator], +) -> None: + error, message = error_and_message + overseer = overseer_cls(communicator=communicator(settings=CommunicationSettings())) + try: + with run_simple_server(message) as url, pytest.raises(error): + await overseer.async_send(url=url, method="POST", data=REQUEST) + finally: + overseer.teardown() diff --git a/tests/local-tools/local_tools/beekeepy/testing_server.py b/tests/local-tools/local_tools/beekeepy/testing_server.py new file mode 100644 index 0000000000000000000000000000000000000000..2f14a5d3e4342fae25f820b82f3d65e40ae63f67 --- /dev/null +++ b/tests/local-tools/local_tools/beekeepy/testing_server.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import asyncio +import time +from abc import ABC, abstractmethod +from contextlib import contextmanager +from http import HTTPStatus +from threading import Thread +from typing import TYPE_CHECKING + +from aiohttp import web + +from beekeepy.interfaces import HttpUrl, SelfContextAsync + +if TYPE_CHECKING: + from socket import socket + from typing import Any, Iterator + + from typing_extensions import Self + + +class AsyncHttpServerError(Exception): + pass + + +class ServerNotRunningError(AsyncHttpServerError): + def __init__(self) -> None: + super().__init__("Server is not running. Call run() first.") + + +class ServerAlreadyRunningError(AsyncHttpServerError): + def __init__(self) -> None: + super().__init__("Server is already running. Call close() first.") + + +class ServerSetupError(AsyncHttpServerError): + def __init__(self, message: str) -> None: + super().__init__(message) + + +class HttpServerObserver(ABC): + @abstractmethod + async def data_received(self, data: dict[str, Any]) -> None: + """Called when any data is received via PUT method. + + Args: + data: data received as body + + Returns: + Nothing. + """ + + +class AsyncHttpServer(SelfContextAsync): + __ADDRESS = HttpUrl("0.0.0.0:0") + + def __init__(self, observer: HttpServerObserver, notification_endpoint: HttpUrl | None) -> None: + self.__observer = observer + self._app = web.Application() + self.__site: web.TCPSite | None = None + self.__running: bool = False + self.__notification_endpoint = notification_endpoint + self._setup_routes() + + def _setup_routes(self) -> None: + async def handle_put_method(request: web.Request) -> web.Response: + await self.__observer.data_received(await request.json()) + return web.Response(status=HTTPStatus.NO_CONTENT) + + self._app.router.add_route("PUT", "/", handle_put_method) + + @property + def port(self) -> int: + if not self.__site: + raise ServerNotRunningError + server: asyncio.base_events.Server | None = self.__site._server # type: ignore[assignment] + if server is None: + raise ServerSetupError("self.__site.server is None") + + server_socket: socket = server.sockets[0] + address_tuple: tuple[str, int] = server_socket.getsockname() + + if not ( + isinstance(address_tuple, tuple) and isinstance(address_tuple[0], str) and isinstance(address_tuple[1], int) + ): + raise ServerSetupError(f"address_tuple has not recognizable types: {address_tuple}") + + return address_tuple[1] + + async def run(self) -> None: + if self.__site: + raise ServerAlreadyRunningError + + time_between_checks_is_server_running = 0.5 + + runner = web.AppRunner(self._app, access_log=False) + await runner.setup() + address = self.__notification_endpoint or self.__ADDRESS + self.__site = web.TCPSite(runner, address.address, address.port) + await self.__site.start() + self.__running = True + try: + while self.__running: # noqa: ASYNC110 + await asyncio.sleep(time_between_checks_is_server_running) + finally: + await self.__site.stop() + self.__site = None + + def close(self) -> None: + if not self.__site: + raise ServerNotRunningError + self.__running = False + + async def _aenter(self) -> Self: + await self.run() + return self + + async def _afinally(self) -> None: + self.close() + + +class DummyObserver(HttpServerObserver): + async def data_received(self, data: dict[str, Any]) -> None: # noqa: ARG002 + return None + + +class TestAsyncHttpServer(AsyncHttpServer): + def __init__(self, response: str) -> None: + self.__response = response + super().__init__(DummyObserver(), None) + + def _setup_routes(self) -> None: + async def handle_post_method(request: web.Request) -> web.Response: # noqa: ARG001 + return web.Response(text=self.__response) + + self._app.router.add_route("POST", "/", handle_post_method) + + +def create_simple_server(response: str) -> TestAsyncHttpServer: + return TestAsyncHttpServer(response=response) + + +@contextmanager +def run_simple_server(response: str) -> Iterator[HttpUrl]: + server = create_simple_server(response) + + worker = Thread(target=asyncio.run, args=(server.run(),)) + worker.start() + time.sleep(0.5) + + try: + yield HttpUrl(f"http://127.0.0.1:{server.port}") + finally: + server.close() + worker.join()