diff --git a/src/promptflow/promptflow/_cli/_pf/_connection.py b/src/promptflow/promptflow/_cli/_pf/_connection.py index c16f4f67620..8988dac2a09 100644 --- a/src/promptflow/promptflow/_cli/_pf/_connection.py +++ b/src/promptflow/promptflow/_cli/_pf/_connection.py @@ -3,13 +3,12 @@ # --------------------------------------------------------- import argparse -import getpass import json import logging from functools import partial from promptflow._cli._params import add_param_set, logging_params -from promptflow._cli._utils import activate_action, confirm, exception_handler +from promptflow._cli._utils import activate_action, confirm, exception_handler, print_yellow_warning, get_secret_input from promptflow._sdk._constants import LOGGER_NAME from promptflow._sdk._load_functions import load_connection from promptflow._sdk._pf_client import PFClient @@ -149,7 +148,12 @@ def validate_and_interactive_get_secrets(connection, is_update=False): if not missing_secrets_prompt: print(prompt) missing_secrets_prompt = True - connection.secrets[name] = getpass.getpass(prompt=f"{name}: ") + while True: + secret = get_secret_input(prompt=f"{name}: ") + if secret: + break + print_yellow_warning("Secret can't be empty.") + connection.secrets[name] = secret if missing_secrets_prompt: print("=================== Required secrets collected ===================") return connection diff --git a/src/promptflow/promptflow/_cli/_utils.py b/src/promptflow/promptflow/_cli/_utils.py index 8cf8ec27808..d6cf454d61c 100644 --- a/src/promptflow/promptflow/_cli/_utils.py +++ b/src/promptflow/promptflow/_cli/_utils.py @@ -21,7 +21,7 @@ from dotenv import load_dotenv from tabulate import tabulate -from promptflow._sdk._utils import print_red_error +from promptflow._sdk._utils import print_red_error, print_yellow_warning from promptflow._utils.utils import is_in_ci_pipeline from promptflow.exceptions import ErrorTarget, PromptflowException, UserErrorException @@ -349,3 +349,64 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def get_secret_input(prompt, mask="*"): + """Get secret input with mask printed on screen in CLI. + + Provide better handling for control characters: + - Handle Ctrl-C as KeyboardInterrupt + - Ignore control characters and print warning message. + """ + if not isinstance(prompt, str): + raise TypeError(f"prompt must be a str, not ${type(prompt).__name__}") + if not isinstance(mask, str): + raise TypeError(f"mask argument must be a one-character str, not ${type(mask).__name__}") + if len(mask) != 1: + raise ValueError("mask argument must be a one-character str") + + if sys.platform == "win32": + # For some reason, mypy reports that msvcrt doesn't have getch, ignore this warning: + from msvcrt import getch # type: ignore + else: # macOS and Linux + import tty + import termios + + def getch(): + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(sys.stdin.fileno()) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch + + secret_input = [] + sys.stdout.write(prompt) + sys.stdout.flush() + + while True: + key = ord(getch()) + if key == 13: # Enter key pressed. + sys.stdout.write("\n") + return "".join(secret_input) + elif key == 3: # Ctrl-C pressed. + raise KeyboardInterrupt() + elif key in (8, 127): # Backspace/Del key erases previous output. + if len(secret_input) > 0: + # Erases previous character. + sys.stdout.write("\b \b") # \b doesn't erase the character, it just moves the cursor back. + sys.stdout.flush() + secret_input = secret_input[:-1] + elif 0 <= key <= 31: + msg = "\nThe last user input got ignored as it is control character." + print_yellow_warning(msg) + sys.stdout.write(prompt + mask * len(secret_input)) + sys.stdout.flush() + else: + # display the mask character. + char = chr(key) + sys.stdout.write(mask) + sys.stdout.flush() + secret_input.append(char) diff --git a/src/promptflow/promptflow/_sdk/_utils.py b/src/promptflow/promptflow/_sdk/_utils.py index 11af27e3dff..c72c0ca7046 100644 --- a/src/promptflow/promptflow/_sdk/_utils.py +++ b/src/promptflow/promptflow/_sdk/_utils.py @@ -157,7 +157,7 @@ def _get_from_keyring(): raise StoreConnectionEncryptionKeyError( "System keyring backend service not found in your operating system. " "See https://pypi.org/project/keyring/ to install requirement for different operating system, " - "or 'pip install keyrings.alt' to use the third-party backend. Reach more detail about this error at" + "or 'pip install keyrings.alt' to use the third-party backend. Reach more detail about this error at " "https://microsoft.github.io/promptflow/how-to-guides/faq.html#connection-creation-failed-with-storeconnectionencryptionkeyerror" # noqa: E501 ) from e diff --git a/src/promptflow/tests/sdk_cli_test/unittests/test_connection.py b/src/promptflow/tests/sdk_cli_test/unittests/test_connection.py index 07f25864b18..290010f8f50 100644 --- a/src/promptflow/tests/sdk_cli_test/unittests/test_connection.py +++ b/src/promptflow/tests/sdk_cli_test/unittests/test_connection.py @@ -245,7 +245,7 @@ def test_validate_and_interactive_get_secrets(self): name="test_connection", secrets={"key1": SCRUBBED_VALUE, "key2": "", "key3": "", "key4": "", "key5": "**"}, ) - with patch("getpass.getpass", new=lambda prompt: "test_value"): + with patch("promptflow._cli._pf._connection.get_secret_input", new=lambda prompt: "test_value"): validate_and_interactive_get_secrets(connection, is_update=False) assert connection.secrets == { "key1": "test_value", @@ -260,7 +260,7 @@ def test_validate_and_interactive_get_secrets(self): name="test_connection", secrets={"key1": SCRUBBED_VALUE, "key2": "", "key3": "", "key4": "", "key5": "**"}, ) - with patch("getpass.getpass", new=lambda prompt: "test_value"): + with patch("promptflow._cli._pf._connection.get_secret_input", new=lambda prompt: "test_value"): validate_and_interactive_get_secrets(connection, is_update=True) assert connection.secrets == { "key1": SCRUBBED_VALUE,