From e20ade7c42b03433ed9a889633faf5b6b3701551 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mateusz=20=C5=BBebrak?= <mzebrak@syncad.com>
Date: Mon, 17 Mar 2025 15:18:00 +0100
Subject: [PATCH] Use AsyncGuard to protect from race condition and crash when
 screen is removed

---
 clive/__private/core/app_state.py             |  6 +-
 clive/__private/core/commands/lock.py         |  2 +-
 .../commands/sync_state_with_beekeeper.py     |  2 +-
 .../core/error_handlers/tui_error_handler.py  |  5 +-
 clive/__private/core/world.py                 | 41 +++--------
 clive/__private/ui/app.py                     | 71 ++++++++++++++++++-
 .../create_profile/create_profile_form.py     | 32 +++++----
 .../ui/screens/dashboard/dashboard.py         |  4 +-
 clive/__private/ui/screens/unlock/unlock.py   | 57 +++++++--------
 .../ui/widgets/clive_basic/clive_header.py    |  5 +-
 10 files changed, 136 insertions(+), 89 deletions(-)

diff --git a/clive/__private/core/app_state.py b/clive/__private/core/app_state.py
index ace862eb20..ecb2f02859 100644
--- a/clive/__private/core/app_state.py
+++ b/clive/__private/core/app_state.py
@@ -34,16 +34,16 @@ class AppState:
         self._is_unlocked = True
         if wallets:
             await self.world.beekeeper_manager.set_wallets(wallets)
-        self.world.on_going_into_unlocked_mode()
+        await self.world.on_going_into_unlocked_mode()
         logger.info("Mode switched to UNLOCKED.")
 
-    def lock(self, source: LockSource = "unknown") -> None:
+    async def lock(self, source: LockSource = "unknown") -> None:
         if not self._is_unlocked:
             return
 
         self._is_unlocked = False
         self.world.beekeeper_manager.clear_wallets()
-        self.world.on_going_into_locked_mode(source)
+        await self.world.on_going_into_locked_mode(source)
         logger.info("Mode switched to LOCKED.")
 
     def __hash__(self) -> int:
diff --git a/clive/__private/core/commands/lock.py b/clive/__private/core/commands/lock.py
index 894b1fd92e..18df59c4b9 100644
--- a/clive/__private/core/commands/lock.py
+++ b/clive/__private/core/commands/lock.py
@@ -21,4 +21,4 @@ class Lock(Command):
     async def _execute(self) -> None:
         await self.session.lock_all()
         if self.app_state:
-            self.app_state.lock()
+            await self.app_state.lock()
diff --git a/clive/__private/core/commands/sync_state_with_beekeeper.py b/clive/__private/core/commands/sync_state_with_beekeeper.py
index d98c766acc..57113323df 100644
--- a/clive/__private/core/commands/sync_state_with_beekeeper.py
+++ b/clive/__private/core/commands/sync_state_with_beekeeper.py
@@ -59,6 +59,6 @@ class SyncStateWithBeekeeper(Command):
         if user_wallet and encryption_wallet:
             await self.app_state.unlock(WalletContainer(user_wallet, encryption_wallet))
         elif not user_wallet and not encryption_wallet:
-            self.app_state.lock(self.source)
+            await self.app_state.lock(self.source)
         else:
             raise InvalidWalletStateError(self)
diff --git a/clive/__private/core/error_handlers/tui_error_handler.py b/clive/__private/core/error_handlers/tui_error_handler.py
index b146b6361b..f20acb1887 100644
--- a/clive/__private/core/error_handlers/tui_error_handler.py
+++ b/clive/__private/core/error_handlers/tui_error_handler.py
@@ -36,4 +36,7 @@ class TUIErrorHandler(ErrorHandlerContextManager[Exception]):
         return ResultNotAvailable(error)
 
     def _switch_to_locked_mode(self) -> None:
-        self._app.world.app_state.lock()
+        async def impl() -> None:
+            await self._app.switch_mode_into_locked(save_profile=False)
+
+        self._app.run_worker_with_screen_remove_guard(impl())
diff --git a/clive/__private/core/world.py b/clive/__private/core/world.py
index d2d9d1acc9..c02536ad07 100644
--- a/clive/__private/core/world.py
+++ b/clive/__private/core/world.py
@@ -16,9 +16,6 @@ from clive.__private.core.node import Node
 from clive.__private.core.profile import Profile
 from clive.__private.core.wallet_container import WalletContainer
 from clive.__private.ui.clive_dom_node import CliveDOMNode
-from clive.__private.ui.forms.create_profile.create_profile_form import CreateProfileForm
-from clive.__private.ui.screens.dashboard import Dashboard
-from clive.__private.ui.screens.unlock import Unlock
 from clive.exceptions import ProfileNotLoadedError
 
 if TYPE_CHECKING:
@@ -116,7 +113,7 @@ class World:
                 self._node.teardown()
             self._beekeeper_manager.teardown()
 
-            self.app_state.lock()
+            await self.app_state.lock()
 
             self._profile = None
             self._node = None
@@ -173,22 +170,22 @@ class World:
         self._profile = new_profile
         await self._update_node()
 
-    def on_going_into_locked_mode(self, source: LockSource) -> None:
+    async def on_going_into_locked_mode(self, source: LockSource) -> None:
         """Triggered when the application is going into the locked mode."""
         if self._is_during_setup or self._is_during_closure:
             return
-        self._on_going_into_locked_mode(source)
+        await self._on_going_into_locked_mode(source)
 
-    def on_going_into_unlocked_mode(self) -> None:
+    async def on_going_into_unlocked_mode(self) -> None:
         """Triggered when the application is going into the unlocked mode."""
         if self._is_during_setup or self._is_during_closure:
             return
-        self._on_going_into_unlocked_mode()
+        await self._on_going_into_unlocked_mode()
 
-    def _on_going_into_locked_mode(self, _: LockSource) -> None:
+    async def _on_going_into_locked_mode(self, _: LockSource) -> None:
         """Override this method to hook when clive goes into the locked mode."""
 
-    def _on_going_into_unlocked_mode(self) -> None:
+    async def _on_going_into_unlocked_mode(self) -> None:
         """Override this method to hook when clive goes into the unlocked mode."""
 
     @asynccontextmanager
@@ -270,32 +267,12 @@ class TUIWorld(World, CliveDOMNode):
     def _watch_profile(self, profile: Profile) -> None:
         self.node.change_related_profile(profile)
 
-    def _on_going_into_locked_mode(self, source: LockSource) -> None:
-        if source == "beekeeper_wallet_lock_status_update_worker":
-            self.app.notify("Switched to the LOCKED mode due to timeout.", timeout=10)
-        self.app.pause_refresh_node_data_interval()
-        self.app.pause_refresh_alarms_data_interval()
-        self.node.cached.clear()
-
-        async def lock() -> None:
-            self._add_welcome_modes()
-            await self.app.switch_mode("unlock")
-            await self._restart_dashboard_mode()
-            await self.switch_profile(None)
-
-        self.app.run_worker(lock())
+    async def _on_going_into_locked_mode(self, source: LockSource) -> None:
+        await self.app._switch_mode_into_locked(source)
 
     def _setup_commands(self) -> TUICommands:
         return TUICommands(self)
 
-    def _add_welcome_modes(self) -> None:
-        self.app.add_mode("create_profile", CreateProfileForm)
-        self.app.add_mode("unlock", Unlock)
-
-    async def _restart_dashboard_mode(self) -> None:
-        await self.app.remove_mode("dashboard")
-        self.app.add_mode("dashboard", Dashboard)
-
     def _update_profile_related_reactive_attributes(self) -> None:
         # There's no proper way to add some proxy reactive property on textual reactives that could raise error if
         # not set yet, and still can be watched. See: https://github.com/Textualize/textual/discussions/4007
diff --git a/clive/__private/ui/app.py b/clive/__private/ui/app.py
index 04dd639616..6e24ece10a 100644
--- a/clive/__private/ui/app.py
+++ b/clive/__private/ui/app.py
@@ -1,20 +1,22 @@
 from __future__ import annotations
 
 import asyncio
+import contextlib
 import math
 import traceback
 from contextlib import asynccontextmanager, contextmanager
-from typing import TYPE_CHECKING, Any, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Awaitable, TypeVar, cast
 
 from beekeepy.exceptions import CommunicationError
 from textual import on, work
 from textual._context import active_app
-from textual.app import App
+from textual.app import App, UnknownModeError
 from textual.binding import Binding
 from textual.notifications import Notification, Notify, SeverityLevel
 from textual.reactive import var
 from textual.worker import WorkerCancelled
 
+from clive.__private.core.async_guard import AsyncGuard
 from clive.__private.core.constants.terminal import TERMINAL_HEIGHT, TERMINAL_WIDTH
 from clive.__private.core.constants.tui.bindings import APP_QUIT_KEY_BINDING
 from clive.__private.core.profile import Profile
@@ -37,6 +39,7 @@ if TYPE_CHECKING:
     from textual.screen import Screen, ScreenResultType
     from textual.worker import Worker
 
+    from clive.__private.core.app_state import LockSource
 
 UpdateScreenResultT = TypeVar("UpdateScreenResultT")
 
@@ -85,6 +88,16 @@ class Clive(App[int]):
         super().__init__(*args, **kwargs)
         self._world: TUIWorld | None = None
 
+        self._screen_remove_guard = AsyncGuard()
+        """
+        Due to https://github.com/Textualize/textual/issues/5008.
+
+        Any action that involves removing a screen like remove_mode/switch_screen/pop_screen
+        cannot be awaited in the @on handler like Button.Pressed because it will deadlock the app.
+        Workaround is to not await mentioned action or run it in a separate task if something later needs to await it.
+        This workaround can create race conditions, so we need to guard against it.
+        """
+
     @property
     def world(self) -> TUIWorld:
         assert self._world is not None, "World is not set yet."
@@ -293,6 +306,27 @@ class Clive(App[int]):
         if self.world._beekeeper_manager:
             await self.world.commands.sync_state_with_beekeeper("beekeeper_wallet_lock_status_update_worker")
 
+    async def switch_mode_into_locked(self, *, save_profile: bool = True) -> None:
+        if save_profile:
+            await self.world.commands.save_profile()
+        await self.world.commands.lock()
+
+    def run_worker_with_guard(self, awaitable: Awaitable[None], guard: AsyncGuard) -> None:
+        """Run work in a worker with a guard. It means that the work will be executed only if the guard is available."""
+
+        async def work_with_release() -> None:
+            try:
+                await awaitable
+            finally:
+                guard.release()
+
+        with guard.suppress():
+            guard.acquire()
+            self.run_worker(work_with_release())
+
+    def run_worker_with_screen_remove_guard(self, awaitable: Awaitable[None]) -> None:
+        self.run_worker_with_guard(awaitable, self._screen_remove_guard)
+
     async def __debug_log(self) -> None:
         logger.debug("===================== DEBUG =====================")
         logger.debug(f"Currently focused: {self.focused}")
@@ -335,3 +369,36 @@ class Clive(App[int]):
     def _retrigger_update_alarms_data(self) -> None:
         if self.is_worker_group_empty("alarms_data"):
             self.update_alarms_data()
+
+    async def _switch_mode_into_locked(self, source: LockSource) -> None:
+        async def restart_dashboard_mode() -> None:
+            await self.remove_mode("dashboard")
+            self.add_mode("dashboard", Dashboard)
+
+        def add_welcome_modes() -> None:
+            self.add_mode("create_profile", CreateProfileForm)
+            self.add_mode("unlock", Unlock)
+
+        if source == "beekeeper_wallet_lock_status_update_worker":
+            self.notify("Switched to the LOCKED mode due to timeout.", timeout=10)
+
+        self.pause_refresh_node_data_interval()
+        self.pause_refresh_alarms_data_interval()
+        self.world.node.cached.clear()
+        add_welcome_modes()
+        await self.switch_mode("unlock")
+        await restart_dashboard_mode()
+        await self.world.switch_profile(None)
+
+    async def _switch_mode_into_unlocked(self) -> None:
+        async def remove_welcome_modes() -> None:
+            with contextlib.suppress(UnknownModeError):
+                await self.remove_mode("create_profile")
+            with contextlib.suppress(UnknownModeError):
+                await self.remove_mode("unlock")
+
+        await self.switch_mode("dashboard")
+        await remove_welcome_modes()
+        self.update_alarms_data_on_newest_node_data(suppress_cancelled_error=True)
+        self.resume_refresh_node_data_interval()
+        self.resume_refresh_alarms_data_interval()
diff --git a/clive/__private/ui/forms/create_profile/create_profile_form.py b/clive/__private/ui/forms/create_profile/create_profile_form.py
index a09cff3951..029b937f6e 100644
--- a/clive/__private/ui/forms/create_profile/create_profile_form.py
+++ b/clive/__private/ui/forms/create_profile/create_profile_form.py
@@ -26,20 +26,22 @@ class CreateProfileForm(Form):
     async def exit_form(self) -> None:
         # when this form is displayed during onboarding, there is no previous screen to go back to
         # so this method won't be called
-        await self.app.switch_mode("unlock")
-        self.app.remove_mode("create_profile")
+
+        async def impl() -> None:
+            await self.app.switch_mode("unlock")
+            await self.app.remove_mode("create_profile")
+
+        # Has to be done in a separate task to avoid deadlock.
+        # More: https://github.com/Textualize/textual/issues/5008
+        self.app.run_worker_with_screen_remove_guard(impl())
 
     async def finish_form(self) -> None:
-        async def handle_modes() -> None:
-            await self.app.switch_mode("dashboard")
-            self.app.remove_mode("create_profile")
-            self.app.remove_mode("unlock")
-
-        self.add_post_action(
-            lambda: self.app.update_alarms_data_on_newest_node_data(suppress_cancelled_error=True),
-            self.app.resume_refresh_alarms_data_interval,
-        )
-        await self.execute_post_actions()
-        await handle_modes()
-        self.profile.enable_saving()
-        await self.commands.save_profile()
+        async def impl() -> None:
+            await self.execute_post_actions()
+            self.profile.enable_saving()
+            await self.commands.save_profile()
+            await self.app._switch_mode_into_unlocked()
+
+        # Has to be done in a separate task to avoid deadlock.
+        # More: https://github.com/Textualize/textual/issues/5008
+        self.app.run_worker_with_screen_remove_guard(impl())
diff --git a/clive/__private/ui/screens/dashboard/dashboard.py b/clive/__private/ui/screens/dashboard/dashboard.py
index 8769f5468a..3c7285021d 100644
--- a/clive/__private/ui/screens/dashboard/dashboard.py
+++ b/clive/__private/ui/screens/dashboard/dashboard.py
@@ -329,8 +329,8 @@ class Dashboard(BaseScreen):
         self.app.push_screen(AddTrackedAccountDialog())
 
     async def action_switch_mode_into_locked(self) -> None:
-        await self.app.world.commands.save_profile()
-        await self.app.world.commands.lock()
+        with self.app._screen_remove_guard.suppress(), self.app._screen_remove_guard.guard():
+            await self.app.switch_mode_into_locked()
 
     @property
     def has_working_account(self) -> bool:
diff --git a/clive/__private/ui/screens/unlock/unlock.py b/clive/__private/ui/screens/unlock/unlock.py
index 09886023e1..82350ad2d5 100644
--- a/clive/__private/ui/screens/unlock/unlock.py
+++ b/clive/__private/ui/screens/unlock/unlock.py
@@ -108,32 +108,33 @@ class Unlock(BaseScreen):
     @on(Button.Pressed, "#unlock-button")
     @on(CliveInput.Submitted)
     async def unlock(self) -> None:
-        password_input = self.query_exactly_one(PasswordInput)
-        select_profile = self.query_exactly_one(SelectProfile)
-        lock_after_time = self.query_exactly_one(LockAfterTime)
-
-        if not password_input.validate_passed() or not lock_after_time.is_valid:
-            return
-
-        try:
-            await self.world.load_profile(
-                profile_name=select_profile.value_ensure,
-                password=password_input.value_or_error,
-                permanent=lock_after_time.should_stay_unlocked,
-                time=lock_after_time.lock_duration,
-            )
-        except InvalidPasswordError:
-            logger.error(
-                f"Profile `{select_profile.value_ensure}` was not unlocked "
-                "because entered password is invalid, skipping switching modes"
-            )
-            return
-
-        await self.app.switch_mode("dashboard")
-        self._remove_welcome_modes()
-        self.app.update_alarms_data_on_newest_node_data(suppress_cancelled_error=True)
-        self.app.resume_refresh_node_data_interval()
-        self.app.resume_refresh_alarms_data_interval()
+        async def impl() -> None:
+            password_input = self.query_exactly_one(PasswordInput)
+            select_profile = self.query_exactly_one(SelectProfile)
+            lock_after_time = self.query_exactly_one(LockAfterTime)
+
+            if not password_input.validate_passed() or not lock_after_time.is_valid:
+                return
+
+            try:
+                await self.world.load_profile(
+                    profile_name=select_profile.value_ensure,
+                    password=password_input.value_or_error,
+                    permanent=lock_after_time.should_stay_unlocked,
+                    time=lock_after_time.lock_duration,
+                )
+            except InvalidPasswordError:
+                logger.error(
+                    f"Profile `{select_profile.value_ensure}` was not unlocked "
+                    "because entered password is invalid, skipping switching modes"
+                )
+                return
+
+            await self.app._switch_mode_into_unlocked()
+
+        # Has to be done in a separate task to avoid deadlock.
+        # More: https://github.com/Textualize/textual/issues/5008
+        self.app.run_worker_with_screen_remove_guard(impl())
 
     @on(Button.Pressed, "#new-profile-button")
     async def create_new_profile(self) -> None:
@@ -146,7 +147,3 @@ class Unlock(BaseScreen):
     @on(SelectProfile.Changed)
     def clear_password_input(self) -> None:
         self.query_exactly_one(PasswordInput).clear_validation()
-
-    def _remove_welcome_modes(self) -> None:
-        self.app.remove_mode("unlock")
-        self.app.remove_mode("create_profile")
diff --git a/clive/__private/ui/widgets/clive_basic/clive_header.py b/clive/__private/ui/widgets/clive_basic/clive_header.py
index a223159639..8734569bf3 100644
--- a/clive/__private/ui/widgets/clive_basic/clive_header.py
+++ b/clive/__private/ui/widgets/clive_basic/clive_header.py
@@ -203,8 +203,9 @@ class LockStatus(DynamicOneLineButtonUnfocusable):
 
     @on(OneLineButton.Pressed)
     async def lock_wallet(self) -> None:
-        await self.commands.save_profile()
-        await self.commands.lock()
+        # Has to be done in a separate task to avoid deadlock.
+        # More: https://github.com/Textualize/textual/issues/5008
+        self.app.run_worker_with_screen_remove_guard(self.app.switch_mode_into_locked())
 
     def _wallet_to_locked_changed(self) -> None:
         self.post_message(self.WalletLocked())
-- 
GitLab