Skip to content

Commit

Permalink
Merge pull request #988 from Pythagora-io/console-input
Browse files Browse the repository at this point in the history
console ui: add multiline, readline, paste, initial/default text support
  • Loading branch information
LeonOstrez authored Jun 3, 2024
2 parents b9becc8 + 4927bee commit 6b5c5c5
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
12 changes: 10 additions & 2 deletions core/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from core.log import get_logger
from core.state.state_manager import StateManager
from core.telemetry import telemetry
from core.ui.base import UIBase, UIClosedError, pythagora_source
from core.ui.base import UIBase, UIClosedError, UserInput, pythagora_source

log = get_logger(__name__)

Expand Down Expand Up @@ -112,7 +112,15 @@ async def start_new_project(sm: StateManager, ui: UIBase) -> bool:
:param ui: User interface.
:return: True if the project was created successfully, False otherwise.
"""
user_input = await ui.ask_question("What is the project name?", allow_empty=False, source=pythagora_source)
try:
user_input = await ui.ask_question(
"What is the project name?",
allow_empty=False,
source=pythagora_source,
)
except (KeyboardInterrupt, UIClosedError):
user_input = UserInput(cancelled=True)

if user_input.cancelled:
return False

Expand Down
7 changes: 6 additions & 1 deletion core/ui/console.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional

from prompt_toolkit.shortcuts import PromptSession

from core.log import get_logger
from core.ui.base import ProjectStage, UIBase, UIClosedError, UISource, UserInput

Expand Down Expand Up @@ -57,9 +59,12 @@ async def ask_question(
default_str = " (default)" if k == default else ""
print(f" [{k}]: {v}{default_str}")

session = PromptSession("> ")

while True:
try:
choice = input("> ").strip()
choice = await session.prompt_async(default=initial_text or "")
choice = choice.strip()
except KeyboardInterrupt:
raise UIClosedError()
if not choice and default:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ psutil = "^5.9.8"
httpx = "^0.27.0"
alembic = "^1.13.1"
python-dotenv = "^1.0.1"
prompt-toolkit = "^3.0.45"

[tool.poetry.group.dev.dependencies]
pytest = "^8.1.1"
Expand Down
23 changes: 13 additions & 10 deletions tests/ui/test_console.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import AsyncMock, patch

import pytest

Expand Down Expand Up @@ -35,8 +35,9 @@ async def test_stream(capsys):


@pytest.mark.asyncio
@patch("builtins.input", return_value="awesome")
async def test_ask_question_simple(mock_input):
@patch("core.ui.console.PromptSession")
async def test_ask_question_simple(mock_PromptSession):
prompt_async = mock_PromptSession.return_value.prompt_async = AsyncMock(return_value="awesome")
ui = PlainConsoleUI()

await ui.start()
Expand All @@ -48,12 +49,13 @@ async def test_ask_question_simple(mock_input):

await ui.stop()

mock_input.assert_called_once()
prompt_async.assert_awaited_once()


@pytest.mark.asyncio
@patch("builtins.input", return_value="yes")
async def test_ask_question_with_buttons(mock_input):
@patch("core.ui.console.PromptSession")
async def test_ask_question_with_buttons(mock_PromptSession):
prompt_async = mock_PromptSession.return_value.prompt_async = AsyncMock(return_value="yes")
ui = PlainConsoleUI()

await ui.start()
Expand All @@ -68,12 +70,13 @@ async def test_ask_question_with_buttons(mock_input):

await ui.stop()

mock_input.assert_called_once()
prompt_async.assert_awaited_once()


@pytest.mark.asyncio
@patch("builtins.input", side_effect=KeyboardInterrupt())
async def test_ask_question_interrupted(mock_input):
@patch("core.ui.console.PromptSession")
async def test_ask_question_interrupted(mock_PromptSession):
prompt_async = mock_PromptSession.return_value.prompt_async = AsyncMock(side_effect=KeyboardInterrupt)
ui = PlainConsoleUI()

await ui.start()
Expand All @@ -82,4 +85,4 @@ async def test_ask_question_interrupted(mock_input):

await ui.stop()

mock_input.assert_called_once()
prompt_async.assert_awaited_once()

0 comments on commit 6b5c5c5

Please sign in to comment.