diff --git a/dvc/prompt.py b/dvc/prompt.py index cb12a73b93..e1ba18f90a 100644 --- a/dvc/prompt.py +++ b/dvc/prompt.py @@ -54,3 +54,16 @@ def password(statement): """ logger.info(f"{statement}: ") return getpass("") + + +def username(statement): + """Ask the user for a username. + + Args: + statement (str): string to prompt the user with. + + Returns: + str: username entered by the user. + """ + logger.info(f"{statement}: ") + return input("") diff --git a/dvc/tree/http.py b/dvc/tree/http.py index f3b762291e..6f6e6fb714 100644 --- a/dvc/tree/http.py +++ b/dvc/tree/http.py @@ -2,7 +2,7 @@ import os.path import threading -from funcy import cached_property, memoize, wrap_prop, wrap_with +from funcy import cached_property, wrap_prop, wrap_with import dvc.prompt as prompt from dvc.exceptions import DvcException, HTTPError @@ -17,7 +17,11 @@ @wrap_with(threading.Lock()) -@memoize +def ask_username(host): + return prompt.username("Username for `{host}`".format(host=host)) + + +@wrap_with(threading.Lock()) def ask_password(host, user): return prompt.password( "Enter a password for " @@ -33,41 +37,94 @@ class HTTPTree(BaseTree): # pylint:disable=abstract-method SESSION_RETRIES = 5 SESSION_BACKOFF_FACTOR = 0.1 + AUTHENTICATION_RETRIES = 3 REQUEST_TIMEOUT = 60 CHUNK_SIZE = 2 ** 16 def __init__(self, repo, config): super().__init__(repo, config) + self.path_info = None url = config.get("url") if url: self.path_info = self.PATH_CLS(url) - user = config.get("user", None) - if user: - self.path_info.user = user - else: - self.path_info = None + self.user = self._get_user() + self.password = self._get_password(self.user) self.auth = config.get("auth", None) self.custom_auth_header = config.get("custom_auth_header", None) - self.password = config.get("password", None) self.ask_password = config.get("ask_password", False) self.headers = {} self.ssl_verify = config.get("ssl_verify", True) self.method = config.get("method", "POST") + def _get_user(self): + user = self.config.get("user") + if user is not None: + return user + + path_info = self.path_info + if path_info is None: + return None + + import keyring + + host = path_info.host + return keyring.get_password(host, "user") + + def _get_password(self, user): + if user is None: + return None + + password = self.config.get("password") + if password is not None: + return password + + import keyring + + path_info = self.path_info + if path_info is None: + return None + host = path_info.host + return keyring.get_password(host, user) + + @wrap_with(threading.Lock()) + def _get_basic_auth(self, path_info): + from requests.auth import HTTPBasicAuth + + host = path_info.host + user = self.user + password = self.password + + if user is None: + user = ask_username(host) + self.user = user + if password is None or self.ask_password: + password = ask_password(host, user) + self.password = password + return HTTPBasicAuth(user, password) + + @wrap_with(threading.Lock()) + def _save_basic_auth(self, auth): + import keyring + + user, password = auth.username, auth.password + host = self.path_info.host if self.path_info else None + self.user = user + keyring.set_password(host, "user", user) + self.password = password + if not self.ask_password: + keyring.set_password(host, user, password) + def _auth_method(self, path_info=None): - from requests.auth import HTTPBasicAuth, HTTPDigestAuth + from requests.auth import HTTPDigestAuth if path_info is None: path_info = self.path_info if self.auth: - if self.ask_password and self.password is None: - host, user = path_info.host, path_info.user - self.password = ask_password(host, user) if self.auth == "basic": - return HTTPBasicAuth(path_info.user, self.password) + return self._get_basic_auth(path_info) if self.auth == "digest": return HTTPDigestAuth(path_info.user, self.password) if self.auth == "custom" and self.custom_auth_header: @@ -98,7 +155,32 @@ def _session(self): return session + @staticmethod + def requires_basic_auth(res): + if res.status_code != 401: + return False + auth_header = res.headers.get("WWW-Authenticate") + if auth_header is None: + return False + return auth_header.lower().startswith("basic ") + def request(self, method, url, **kwargs): + auth_method = self._auth_method() + res = self._request(method, url, auth_method, **kwargs) + if self.requires_basic_auth(res): + self.auth = "basic" + for _ in range(self.AUTHENTICATION_RETRIES): + auth_method = self._auth_method() + res = self._request(method, url, auth_method, **kwargs) + if res.status_code == 200: + self._save_basic_auth(auth_method) + break + else: + self.password = None + self.user = None + return res + + def _request(self, method, url, auth_method, **kwargs): import requests kwargs.setdefault("allow_redirects", True) @@ -106,11 +188,7 @@ def request(self, method, url, **kwargs): try: res = self._session.request( - method, - url, - auth=self._auth_method(), - headers=self.headers, - **kwargs, + method, url, auth=auth_method, headers=self.headers, **kwargs, ) redirect_no_location = ( diff --git a/setup.py b/setup.py index e324885ab6..8d780fad3a 100644 --- a/setup.py +++ b/setup.py @@ -87,6 +87,7 @@ def run(self): "python-benedict>=0.21.1", "pyparsing==2.4.7", "typing_extensions>=3.7.4", + "keyring==21.8.0", ]