Skip to content

Commit

Permalink
[Promptflow CLI] refine secret input handling (microsoft#287)
Browse files Browse the repository at this point in the history
# Description

Please add an informative description that covers that changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes]**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.


![image](https://github.com/microsoft/promptflow/assets/47586720/11b41a36-4e1c-47ff-a8b6-c2121c22f7d2)
  • Loading branch information
wangchao1230 authored Sep 4, 2023
1 parent 5925b3b commit adddc9b
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 7 deletions.
10 changes: 7 additions & 3 deletions src/promptflow/promptflow/_cli/_pf/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
# ---------------------------------------------------------

import argparse
import getpass
import json
import logging
from functools import partial

from promptflow._cli._params import add_param_set, logging_params
from promptflow._cli._utils import activate_action, confirm, exception_handler
from promptflow._cli._utils import activate_action, confirm, exception_handler, print_yellow_warning, get_secret_input
from promptflow._sdk._constants import LOGGER_NAME
from promptflow._sdk._load_functions import load_connection
from promptflow._sdk._pf_client import PFClient
Expand Down Expand Up @@ -149,7 +148,12 @@ def validate_and_interactive_get_secrets(connection, is_update=False):
if not missing_secrets_prompt:
print(prompt)
missing_secrets_prompt = True
connection.secrets[name] = getpass.getpass(prompt=f"{name}: ")
while True:
secret = get_secret_input(prompt=f"{name}: ")
if secret:
break
print_yellow_warning("Secret can't be empty.")
connection.secrets[name] = secret
if missing_secrets_prompt:
print("=================== Required secrets collected ===================")
return connection
Expand Down
63 changes: 62 additions & 1 deletion src/promptflow/promptflow/_cli/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dotenv import load_dotenv
from tabulate import tabulate

from promptflow._sdk._utils import print_red_error
from promptflow._sdk._utils import print_red_error, print_yellow_warning
from promptflow._utils.utils import is_in_ci_pipeline
from promptflow.exceptions import ErrorTarget, PromptflowException, UserErrorException

Expand Down Expand Up @@ -349,3 +349,64 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def get_secret_input(prompt, mask="*"):
"""Get secret input with mask printed on screen in CLI.
Provide better handling for control characters:
- Handle Ctrl-C as KeyboardInterrupt
- Ignore control characters and print warning message.
"""
if not isinstance(prompt, str):
raise TypeError(f"prompt must be a str, not ${type(prompt).__name__}")
if not isinstance(mask, str):
raise TypeError(f"mask argument must be a one-character str, not ${type(mask).__name__}")
if len(mask) != 1:
raise ValueError("mask argument must be a one-character str")

if sys.platform == "win32":
# For some reason, mypy reports that msvcrt doesn't have getch, ignore this warning:
from msvcrt import getch # type: ignore
else: # macOS and Linux
import tty
import termios

def getch():
fd = sys.stdin.fileno()
old_settings = termios.tcgetattr(fd)
try:
tty.setraw(sys.stdin.fileno())
ch = sys.stdin.read(1)
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
return ch

secret_input = []
sys.stdout.write(prompt)
sys.stdout.flush()

while True:
key = ord(getch())
if key == 13: # Enter key pressed.
sys.stdout.write("\n")
return "".join(secret_input)
elif key == 3: # Ctrl-C pressed.
raise KeyboardInterrupt()
elif key in (8, 127): # Backspace/Del key erases previous output.
if len(secret_input) > 0:
# Erases previous character.
sys.stdout.write("\b \b") # \b doesn't erase the character, it just moves the cursor back.
sys.stdout.flush()
secret_input = secret_input[:-1]
elif 0 <= key <= 31:
msg = "\nThe last user input got ignored as it is control character."
print_yellow_warning(msg)
sys.stdout.write(prompt + mask * len(secret_input))
sys.stdout.flush()
else:
# display the mask character.
char = chr(key)
sys.stdout.write(mask)
sys.stdout.flush()
secret_input.append(char)
2 changes: 1 addition & 1 deletion src/promptflow/promptflow/_sdk/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def _get_from_keyring():
raise StoreConnectionEncryptionKeyError(
"System keyring backend service not found in your operating system. "
"See https://pypi.org/project/keyring/ to install requirement for different operating system, "
"or 'pip install keyrings.alt' to use the third-party backend. Reach more detail about this error at"
"or 'pip install keyrings.alt' to use the third-party backend. Reach more detail about this error at "
"https://microsoft.github.io/promptflow/how-to-guides/faq.html#connection-creation-failed-with-storeconnectionencryptionkeyerror" # noqa: E501
) from e

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_validate_and_interactive_get_secrets(self):
name="test_connection",
secrets={"key1": SCRUBBED_VALUE, "key2": "", "key3": "<no-change>", "key4": "<user-input>", "key5": "**"},
)
with patch("getpass.getpass", new=lambda prompt: "test_value"):
with patch("promptflow._cli._pf._connection.get_secret_input", new=lambda prompt: "test_value"):
validate_and_interactive_get_secrets(connection, is_update=False)
assert connection.secrets == {
"key1": "test_value",
Expand All @@ -260,7 +260,7 @@ def test_validate_and_interactive_get_secrets(self):
name="test_connection",
secrets={"key1": SCRUBBED_VALUE, "key2": "", "key3": "<no-change>", "key4": "<user-input>", "key5": "**"},
)
with patch("getpass.getpass", new=lambda prompt: "test_value"):
with patch("promptflow._cli._pf._connection.get_secret_input", new=lambda prompt: "test_value"):
validate_and_interactive_get_secrets(connection, is_update=True)
assert connection.secrets == {
"key1": SCRUBBED_VALUE,
Expand Down

0 comments on commit adddc9b

Please sign in to comment.