Skip to content
Snippets Groups Projects
Commit e20ade7c authored by Mateusz Żebrak's avatar Mateusz Żebrak
Browse files

Use AsyncGuard to protect from race condition and crash when screen is

removed
parent 262f7f9a
No related branches found
No related tags found
2 merge requests!601Async guard,!600v1.27.5.21 Release
...@@ -34,16 +34,16 @@ class AppState: ...@@ -34,16 +34,16 @@ class AppState:
self._is_unlocked = True self._is_unlocked = True
if wallets: if wallets:
await self.world.beekeeper_manager.set_wallets(wallets) await self.world.beekeeper_manager.set_wallets(wallets)
self.world.on_going_into_unlocked_mode() await self.world.on_going_into_unlocked_mode()
logger.info("Mode switched to UNLOCKED.") logger.info("Mode switched to UNLOCKED.")
def lock(self, source: LockSource = "unknown") -> None: async def lock(self, source: LockSource = "unknown") -> None:
if not self._is_unlocked: if not self._is_unlocked:
return return
self._is_unlocked = False self._is_unlocked = False
self.world.beekeeper_manager.clear_wallets() self.world.beekeeper_manager.clear_wallets()
self.world.on_going_into_locked_mode(source) await self.world.on_going_into_locked_mode(source)
logger.info("Mode switched to LOCKED.") logger.info("Mode switched to LOCKED.")
def __hash__(self) -> int: def __hash__(self) -> int:
......
...@@ -21,4 +21,4 @@ class Lock(Command): ...@@ -21,4 +21,4 @@ class Lock(Command):
async def _execute(self) -> None: async def _execute(self) -> None:
await self.session.lock_all() await self.session.lock_all()
if self.app_state: if self.app_state:
self.app_state.lock() await self.app_state.lock()
...@@ -59,6 +59,6 @@ class SyncStateWithBeekeeper(Command): ...@@ -59,6 +59,6 @@ class SyncStateWithBeekeeper(Command):
if user_wallet and encryption_wallet: if user_wallet and encryption_wallet:
await self.app_state.unlock(WalletContainer(user_wallet, encryption_wallet)) await self.app_state.unlock(WalletContainer(user_wallet, encryption_wallet))
elif not user_wallet and not encryption_wallet: elif not user_wallet and not encryption_wallet:
self.app_state.lock(self.source) await self.app_state.lock(self.source)
else: else:
raise InvalidWalletStateError(self) raise InvalidWalletStateError(self)
...@@ -36,4 +36,7 @@ class TUIErrorHandler(ErrorHandlerContextManager[Exception]): ...@@ -36,4 +36,7 @@ class TUIErrorHandler(ErrorHandlerContextManager[Exception]):
return ResultNotAvailable(error) return ResultNotAvailable(error)
def _switch_to_locked_mode(self) -> None: def _switch_to_locked_mode(self) -> None:
self._app.world.app_state.lock() async def impl() -> None:
await self._app.switch_mode_into_locked(save_profile=False)
self._app.run_worker_with_screen_remove_guard(impl())
...@@ -16,9 +16,6 @@ from clive.__private.core.node import Node ...@@ -16,9 +16,6 @@ from clive.__private.core.node import Node
from clive.__private.core.profile import Profile from clive.__private.core.profile import Profile
from clive.__private.core.wallet_container import WalletContainer from clive.__private.core.wallet_container import WalletContainer
from clive.__private.ui.clive_dom_node import CliveDOMNode from clive.__private.ui.clive_dom_node import CliveDOMNode
from clive.__private.ui.forms.create_profile.create_profile_form import CreateProfileForm
from clive.__private.ui.screens.dashboard import Dashboard
from clive.__private.ui.screens.unlock import Unlock
from clive.exceptions import ProfileNotLoadedError from clive.exceptions import ProfileNotLoadedError
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -116,7 +113,7 @@ class World: ...@@ -116,7 +113,7 @@ class World:
self._node.teardown() self._node.teardown()
self._beekeeper_manager.teardown() self._beekeeper_manager.teardown()
self.app_state.lock() await self.app_state.lock()
self._profile = None self._profile = None
self._node = None self._node = None
...@@ -173,22 +170,22 @@ class World: ...@@ -173,22 +170,22 @@ class World:
self._profile = new_profile self._profile = new_profile
await self._update_node() await self._update_node()
def on_going_into_locked_mode(self, source: LockSource) -> None: async def on_going_into_locked_mode(self, source: LockSource) -> None:
"""Triggered when the application is going into the locked mode.""" """Triggered when the application is going into the locked mode."""
if self._is_during_setup or self._is_during_closure: if self._is_during_setup or self._is_during_closure:
return return
self._on_going_into_locked_mode(source) await self._on_going_into_locked_mode(source)
def on_going_into_unlocked_mode(self) -> None: async def on_going_into_unlocked_mode(self) -> None:
"""Triggered when the application is going into the unlocked mode.""" """Triggered when the application is going into the unlocked mode."""
if self._is_during_setup or self._is_during_closure: if self._is_during_setup or self._is_during_closure:
return return
self._on_going_into_unlocked_mode() await self._on_going_into_unlocked_mode()
def _on_going_into_locked_mode(self, _: LockSource) -> None: async def _on_going_into_locked_mode(self, _: LockSource) -> None:
"""Override this method to hook when clive goes into the locked mode.""" """Override this method to hook when clive goes into the locked mode."""
def _on_going_into_unlocked_mode(self) -> None: async def _on_going_into_unlocked_mode(self) -> None:
"""Override this method to hook when clive goes into the unlocked mode.""" """Override this method to hook when clive goes into the unlocked mode."""
@asynccontextmanager @asynccontextmanager
...@@ -270,32 +267,12 @@ class TUIWorld(World, CliveDOMNode): ...@@ -270,32 +267,12 @@ class TUIWorld(World, CliveDOMNode):
def _watch_profile(self, profile: Profile) -> None: def _watch_profile(self, profile: Profile) -> None:
self.node.change_related_profile(profile) self.node.change_related_profile(profile)
def _on_going_into_locked_mode(self, source: LockSource) -> None: async def _on_going_into_locked_mode(self, source: LockSource) -> None:
if source == "beekeeper_wallet_lock_status_update_worker": await self.app._switch_mode_into_locked(source)
self.app.notify("Switched to the LOCKED mode due to timeout.", timeout=10)
self.app.pause_refresh_node_data_interval()
self.app.pause_refresh_alarms_data_interval()
self.node.cached.clear()
async def lock() -> None:
self._add_welcome_modes()
await self.app.switch_mode("unlock")
await self._restart_dashboard_mode()
await self.switch_profile(None)
self.app.run_worker(lock())
def _setup_commands(self) -> TUICommands: def _setup_commands(self) -> TUICommands:
return TUICommands(self) return TUICommands(self)
def _add_welcome_modes(self) -> None:
self.app.add_mode("create_profile", CreateProfileForm)
self.app.add_mode("unlock", Unlock)
async def _restart_dashboard_mode(self) -> None:
await self.app.remove_mode("dashboard")
self.app.add_mode("dashboard", Dashboard)
def _update_profile_related_reactive_attributes(self) -> None: def _update_profile_related_reactive_attributes(self) -> None:
# There's no proper way to add some proxy reactive property on textual reactives that could raise error if # There's no proper way to add some proxy reactive property on textual reactives that could raise error if
# not set yet, and still can be watched. See: https://github.com/Textualize/textual/discussions/4007 # not set yet, and still can be watched. See: https://github.com/Textualize/textual/discussions/4007
......
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
import math import math
import traceback import traceback
from contextlib import asynccontextmanager, contextmanager from contextlib import asynccontextmanager, contextmanager
from typing import TYPE_CHECKING, Any, TypeVar, cast from typing import TYPE_CHECKING, Any, Awaitable, TypeVar, cast
from beekeepy.exceptions import CommunicationError from beekeepy.exceptions import CommunicationError
from textual import on, work from textual import on, work
from textual._context import active_app from textual._context import active_app
from textual.app import App from textual.app import App, UnknownModeError
from textual.binding import Binding from textual.binding import Binding
from textual.notifications import Notification, Notify, SeverityLevel from textual.notifications import Notification, Notify, SeverityLevel
from textual.reactive import var from textual.reactive import var
from textual.worker import WorkerCancelled from textual.worker import WorkerCancelled
from clive.__private.core.async_guard import AsyncGuard
from clive.__private.core.constants.terminal import TERMINAL_HEIGHT, TERMINAL_WIDTH from clive.__private.core.constants.terminal import TERMINAL_HEIGHT, TERMINAL_WIDTH
from clive.__private.core.constants.tui.bindings import APP_QUIT_KEY_BINDING from clive.__private.core.constants.tui.bindings import APP_QUIT_KEY_BINDING
from clive.__private.core.profile import Profile from clive.__private.core.profile import Profile
...@@ -37,6 +39,7 @@ if TYPE_CHECKING: ...@@ -37,6 +39,7 @@ if TYPE_CHECKING:
from textual.screen import Screen, ScreenResultType from textual.screen import Screen, ScreenResultType
from textual.worker import Worker from textual.worker import Worker
from clive.__private.core.app_state import LockSource
UpdateScreenResultT = TypeVar("UpdateScreenResultT") UpdateScreenResultT = TypeVar("UpdateScreenResultT")
...@@ -85,6 +88,16 @@ class Clive(App[int]): ...@@ -85,6 +88,16 @@ class Clive(App[int]):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._world: TUIWorld | None = None self._world: TUIWorld | None = None
self._screen_remove_guard = AsyncGuard()
"""
Due to https://github.com/Textualize/textual/issues/5008.
Any action that involves removing a screen like remove_mode/switch_screen/pop_screen
cannot be awaited in the @on handler like Button.Pressed because it will deadlock the app.
Workaround is to not await mentioned action or run it in a separate task if something later needs to await it.
This workaround can create race conditions, so we need to guard against it.
"""
@property @property
def world(self) -> TUIWorld: def world(self) -> TUIWorld:
assert self._world is not None, "World is not set yet." assert self._world is not None, "World is not set yet."
...@@ -293,6 +306,27 @@ class Clive(App[int]): ...@@ -293,6 +306,27 @@ class Clive(App[int]):
if self.world._beekeeper_manager: if self.world._beekeeper_manager:
await self.world.commands.sync_state_with_beekeeper("beekeeper_wallet_lock_status_update_worker") await self.world.commands.sync_state_with_beekeeper("beekeeper_wallet_lock_status_update_worker")
async def switch_mode_into_locked(self, *, save_profile: bool = True) -> None:
if save_profile:
await self.world.commands.save_profile()
await self.world.commands.lock()
def run_worker_with_guard(self, awaitable: Awaitable[None], guard: AsyncGuard) -> None:
"""Run work in a worker with a guard. It means that the work will be executed only if the guard is available."""
async def work_with_release() -> None:
try:
await awaitable
finally:
guard.release()
with guard.suppress():
guard.acquire()
self.run_worker(work_with_release())
def run_worker_with_screen_remove_guard(self, awaitable: Awaitable[None]) -> None:
self.run_worker_with_guard(awaitable, self._screen_remove_guard)
async def __debug_log(self) -> None: async def __debug_log(self) -> None:
logger.debug("===================== DEBUG =====================") logger.debug("===================== DEBUG =====================")
logger.debug(f"Currently focused: {self.focused}") logger.debug(f"Currently focused: {self.focused}")
...@@ -335,3 +369,36 @@ class Clive(App[int]): ...@@ -335,3 +369,36 @@ class Clive(App[int]):
def _retrigger_update_alarms_data(self) -> None: def _retrigger_update_alarms_data(self) -> None:
if self.is_worker_group_empty("alarms_data"): if self.is_worker_group_empty("alarms_data"):
self.update_alarms_data() self.update_alarms_data()
async def _switch_mode_into_locked(self, source: LockSource) -> None:
async def restart_dashboard_mode() -> None:
await self.remove_mode("dashboard")
self.add_mode("dashboard", Dashboard)
def add_welcome_modes() -> None:
self.add_mode("create_profile", CreateProfileForm)
self.add_mode("unlock", Unlock)
if source == "beekeeper_wallet_lock_status_update_worker":
self.notify("Switched to the LOCKED mode due to timeout.", timeout=10)
self.pause_refresh_node_data_interval()
self.pause_refresh_alarms_data_interval()
self.world.node.cached.clear()
add_welcome_modes()
await self.switch_mode("unlock")
await restart_dashboard_mode()
await self.world.switch_profile(None)
async def _switch_mode_into_unlocked(self) -> None:
async def remove_welcome_modes() -> None:
with contextlib.suppress(UnknownModeError):
await self.remove_mode("create_profile")
with contextlib.suppress(UnknownModeError):
await self.remove_mode("unlock")
await self.switch_mode("dashboard")
await remove_welcome_modes()
self.update_alarms_data_on_newest_node_data(suppress_cancelled_error=True)
self.resume_refresh_node_data_interval()
self.resume_refresh_alarms_data_interval()
...@@ -26,20 +26,22 @@ class CreateProfileForm(Form): ...@@ -26,20 +26,22 @@ class CreateProfileForm(Form):
async def exit_form(self) -> None: async def exit_form(self) -> None:
# when this form is displayed during onboarding, there is no previous screen to go back to # when this form is displayed during onboarding, there is no previous screen to go back to
# so this method won't be called # so this method won't be called
await self.app.switch_mode("unlock")
self.app.remove_mode("create_profile") async def impl() -> None:
await self.app.switch_mode("unlock")
await self.app.remove_mode("create_profile")
# Has to be done in a separate task to avoid deadlock.
# More: https://github.com/Textualize/textual/issues/5008
self.app.run_worker_with_screen_remove_guard(impl())
async def finish_form(self) -> None: async def finish_form(self) -> None:
async def handle_modes() -> None: async def impl() -> None:
await self.app.switch_mode("dashboard") await self.execute_post_actions()
self.app.remove_mode("create_profile") self.profile.enable_saving()
self.app.remove_mode("unlock") await self.commands.save_profile()
await self.app._switch_mode_into_unlocked()
self.add_post_action(
lambda: self.app.update_alarms_data_on_newest_node_data(suppress_cancelled_error=True), # Has to be done in a separate task to avoid deadlock.
self.app.resume_refresh_alarms_data_interval, # More: https://github.com/Textualize/textual/issues/5008
) self.app.run_worker_with_screen_remove_guard(impl())
await self.execute_post_actions()
await handle_modes()
self.profile.enable_saving()
await self.commands.save_profile()
...@@ -329,8 +329,8 @@ class Dashboard(BaseScreen): ...@@ -329,8 +329,8 @@ class Dashboard(BaseScreen):
self.app.push_screen(AddTrackedAccountDialog()) self.app.push_screen(AddTrackedAccountDialog())
async def action_switch_mode_into_locked(self) -> None: async def action_switch_mode_into_locked(self) -> None:
await self.app.world.commands.save_profile() with self.app._screen_remove_guard.suppress(), self.app._screen_remove_guard.guard():
await self.app.world.commands.lock() await self.app.switch_mode_into_locked()
@property @property
def has_working_account(self) -> bool: def has_working_account(self) -> bool:
......
...@@ -108,32 +108,33 @@ class Unlock(BaseScreen): ...@@ -108,32 +108,33 @@ class Unlock(BaseScreen):
@on(Button.Pressed, "#unlock-button") @on(Button.Pressed, "#unlock-button")
@on(CliveInput.Submitted) @on(CliveInput.Submitted)
async def unlock(self) -> None: async def unlock(self) -> None:
password_input = self.query_exactly_one(PasswordInput) async def impl() -> None:
select_profile = self.query_exactly_one(SelectProfile) password_input = self.query_exactly_one(PasswordInput)
lock_after_time = self.query_exactly_one(LockAfterTime) select_profile = self.query_exactly_one(SelectProfile)
lock_after_time = self.query_exactly_one(LockAfterTime)
if not password_input.validate_passed() or not lock_after_time.is_valid:
return if not password_input.validate_passed() or not lock_after_time.is_valid:
return
try:
await self.world.load_profile( try:
profile_name=select_profile.value_ensure, await self.world.load_profile(
password=password_input.value_or_error, profile_name=select_profile.value_ensure,
permanent=lock_after_time.should_stay_unlocked, password=password_input.value_or_error,
time=lock_after_time.lock_duration, permanent=lock_after_time.should_stay_unlocked,
) time=lock_after_time.lock_duration,
except InvalidPasswordError: )
logger.error( except InvalidPasswordError:
f"Profile `{select_profile.value_ensure}` was not unlocked " logger.error(
"because entered password is invalid, skipping switching modes" f"Profile `{select_profile.value_ensure}` was not unlocked "
) "because entered password is invalid, skipping switching modes"
return )
return
await self.app.switch_mode("dashboard")
self._remove_welcome_modes() await self.app._switch_mode_into_unlocked()
self.app.update_alarms_data_on_newest_node_data(suppress_cancelled_error=True)
self.app.resume_refresh_node_data_interval() # Has to be done in a separate task to avoid deadlock.
self.app.resume_refresh_alarms_data_interval() # More: https://github.com/Textualize/textual/issues/5008
self.app.run_worker_with_screen_remove_guard(impl())
@on(Button.Pressed, "#new-profile-button") @on(Button.Pressed, "#new-profile-button")
async def create_new_profile(self) -> None: async def create_new_profile(self) -> None:
...@@ -146,7 +147,3 @@ class Unlock(BaseScreen): ...@@ -146,7 +147,3 @@ class Unlock(BaseScreen):
@on(SelectProfile.Changed) @on(SelectProfile.Changed)
def clear_password_input(self) -> None: def clear_password_input(self) -> None:
self.query_exactly_one(PasswordInput).clear_validation() self.query_exactly_one(PasswordInput).clear_validation()
def _remove_welcome_modes(self) -> None:
self.app.remove_mode("unlock")
self.app.remove_mode("create_profile")
...@@ -203,8 +203,9 @@ class LockStatus(DynamicOneLineButtonUnfocusable): ...@@ -203,8 +203,9 @@ class LockStatus(DynamicOneLineButtonUnfocusable):
@on(OneLineButton.Pressed) @on(OneLineButton.Pressed)
async def lock_wallet(self) -> None: async def lock_wallet(self) -> None:
await self.commands.save_profile() # Has to be done in a separate task to avoid deadlock.
await self.commands.lock() # More: https://github.com/Textualize/textual/issues/5008
self.app.run_worker_with_screen_remove_guard(self.app.switch_mode_into_locked())
def _wallet_to_locked_changed(self) -> None: def _wallet_to_locked_changed(self) -> None:
self.post_message(self.WalletLocked()) self.post_message(self.WalletLocked())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment