From 6d636f2dc43efeb60c9a806625cd374b0a98429f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=2EM=2E=20Fern=C3=A1ndez?= Date: Tue, 17 Nov 2020 11:39:18 +0100 Subject: [PATCH] Fix private keys are not found on Linux/Windows (#216) * Fix #161 * Fix #145 * use ephemeral ports for tests Co-authored-by: Pavel White --- appveyor.yml | 26 +++++++++++--- sshtunnel.py | 75 +++++++++++++++++++++-------------------- tests/test_forwarder.py | 28 ++++++++------- 3 files changed, 76 insertions(+), 53 deletions(-) diff --git a/appveyor.yml b/appveyor.yml index b48e044a..ad754699 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -11,6 +11,14 @@ environment: PYTHON_VERSION: "2.7.x" PYTHON_ARCH: "64" + - PYTHON: "C:\\Python35" + PYTHON_VERSION: "3.5.x" + PYTHON_ARCH: "32" + + - PYTHON: "C:\\Python35-x64" + PYTHON_VERSION: "3.5.x" + PYTHON_ARCH: "64" + - PYTHON: "C:\\Python36" PYTHON_VERSION: "3.6.x" PYTHON_ARCH: "32" @@ -26,18 +34,26 @@ environment: - PYTHON: "C:\\Python37-x64" PYTHON_VERSION: "3.7.x" PYTHON_ARCH: "64" + + - PYTHON: "C:\\Python38" + PYTHON_VERSION: "3.8.x" + PYTHON_ARCH: "32" + + - PYTHON: "C:\\Python38-x64" + PYTHON_VERSION: "3.8.x" + PYTHON_ARCH: "64" init: - "ECHO %PYTHON% %PYTHON_VERSION% %PYTHON_ARCH%" install: - set "PATH=%PYTHON%;%PYTHON%\\Scripts;%PYTHON%\\Tools\\Scripts;%PATH%" - - "python -m pip install --upgrade pip" - - "pip install paramiko" - - "pip install mock pytest pytest-cov pytest-xdist" + - python -m pip install --upgrade pip + - pip install paramiko + - pip install mock pytest pytest-cov pytest-xdist build: off test_script: - - "python setup.py install" - - "py.test --showlocals --durations=10 -n4 tests" + - python setup.py install + - py.test --showlocals --durations=10 -n4 tests diff --git a/sshtunnel.py b/sshtunnel.py index 0c6d1ca3..15de4ccf 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -59,13 +59,10 @@ DEFAULT_LOGLEVEL = logging.ERROR #: default level if no logger passed (ERROR) TRACE_LEVEL = 1 logging.addLevelName(TRACE_LEVEL, 'TRACE') +DEFAULT_SSH_DIRECTORY = '~/.ssh' -if os.name == 'posix': - DEFAULT_SSH_DIRECTORY = '~/.ssh' - UnixStreamServer = socketserver.UnixStreamServer -else: - DEFAULT_SSH_DIRECTORY = '~/ssh' - UnixStreamServer = socketserver.TCPServer +StreamServer = socketserver.UnixStreamServer if os.name == 'posix' \ + else socketserver.TCPServer #: Path of optional ssh configuration file SSH_CONFIG_FILE = os.path.join(DEFAULT_SSH_DIRECTORY, 'config') @@ -352,7 +349,7 @@ def handle(self): src_addr=src_address, timeout=TUNNEL_TIMEOUT ) - except Exception as e: + except Exception as e: # pragma: no cover msg_tupe = 'ssh ' if isinstance(e, paramiko.SSHException) else '' exc_msg = 'open new channel {0}error: {1}'.format(msg_tupe, e) log_msg = '{0} {1}'.format(self.info, exc_msg) @@ -437,15 +434,15 @@ class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer): daemon_threads = _DAEMON -class _UnixStreamForwardServer(UnixStreamServer): +class _StreamForwardServer(StreamServer): """ - Serve over UNIX domain sockets (does not work on Windows) + Serve over domain sockets (does not work on Windows) """ def __init__(self, *args, **kwargs): self.logger = create_logger(kwargs.pop('logger', None)) self.tunnel_ok = queue.Queue(1) - UnixStreamServer.__init__(self, *args, **kwargs) + StreamServer.__init__(self, *args, **kwargs) @property def local_address(self): @@ -472,8 +469,8 @@ def remote_port(self): return self.RequestHandlerClass.remote_address[1] -class _ThreadingUnixStreamForwardServer(socketserver.ThreadingMixIn, - _UnixStreamForwardServer): +class _ThreadingStreamForwardServer(socketserver.ThreadingMixIn, + _StreamForwardServer): """ Allow concurrent connections to each tunnel """ @@ -832,9 +829,9 @@ class Handler(_ForwardHandler): def _make_ssh_forward_server_class(self, remote_address_): return _ThreadingForwardServer if self._threaded else _ForwardServer - def _make_unix_ssh_forward_server_class(self, remote_address_): - return _ThreadingUnixStreamForwardServer if \ - self._threaded else _UnixStreamForwardServer + def _make_stream_ssh_forward_server_class(self, remote_address_): + return _ThreadingStreamForwardServer if self._threaded \ + else _StreamForwardServer def _make_ssh_forward_server(self, remote_address, local_bind_address): """ @@ -842,10 +839,9 @@ def _make_ssh_forward_server(self, remote_address, local_bind_address): """ _Handler = self._make_ssh_forward_handler_class(remote_address) try: - if isinstance(local_bind_address, string_types): - forward_maker_class = self._make_unix_ssh_forward_server_class - else: - forward_maker_class = self._make_ssh_forward_server_class + forward_maker_class = self._make_stream_ssh_forward_server_class \ + if isinstance(local_bind_address, string_types) \ + else self._make_ssh_forward_server_class _Server = forward_maker_class(remote_address) ssh_forward_server = _Server( local_bind_address, @@ -1085,16 +1081,19 @@ def get_keys(logger=None, host_pkey_directories=None, allow_agent=False): keys = SSHTunnelForwarder.get_agent_keys(logger=logger) \ if allow_agent else [] - if host_pkey_directories is not None: - paramiko_key_types = {'rsa': paramiko.RSAKey, - 'dsa': paramiko.DSSKey, - 'ecdsa': paramiko.ECDSAKey, - 'ed25519': paramiko.Ed25519Key} - for directory in host_pkey_directories or [DEFAULT_SSH_DIRECTORY]: - for keytype in paramiko_key_types.keys(): - ssh_pkey_expanded = os.path.expanduser( - os.path.join(directory, 'id_{}'.format(keytype)) - ) + if host_pkey_directories is None: + host_pkey_directories = [DEFAULT_SSH_DIRECTORY] + + paramiko_key_types = {'rsa': paramiko.RSAKey, + 'dsa': paramiko.DSSKey, + 'ecdsa': paramiko.ECDSAKey, + 'ed25519': paramiko.Ed25519Key} + for directory in host_pkey_directories: + for keytype in paramiko_key_types.keys(): + ssh_pkey_expanded = os.path.expanduser( + os.path.join(directory, 'id_{}'.format(keytype)) + ) + try: if os.path.isfile(ssh_pkey_expanded): ssh_pkey = SSHTunnelForwarder.read_private_key_file( pkey_file=ssh_pkey_expanded, @@ -1103,11 +1102,12 @@ def get_keys(logger=None, host_pkey_directories=None, allow_agent=False): ) if ssh_pkey: keys.append(ssh_pkey) + except OSError as exc: + if logger: + logger.warning('Private key file {0} check error: {1}' + .format(ssh_pkey_expanded, exc)) if logger: - logger.info('{0} keys loaded from host directory'.format( - len(keys)) - ) - + logger.info('{0} key(s) loaded'.format(len(keys))) return keys @staticmethod @@ -1455,12 +1455,12 @@ def _stop_transport(self, force=False): _srv.shutdown() _srv.server_close() # clean up the UNIX domain socket if we're using one - if isinstance(_srv, _UnixStreamForwardServer): + if isinstance(_srv, _StreamForwardServer): try: os.unlink(_srv.local_address) except Exception as e: self.logger.error('Unable to unlink socket {0}: {1}' - .format(self.local_address, repr(e))) + .format(_srv.local_address, repr(e))) self.is_alive = False if self.is_active: self.logger.info('Closing ssh transport') @@ -1862,7 +1862,7 @@ def _parse_arguments(args=None): return vars(parser.parse_args(args)) -def _cli_main(args=None): +def _cli_main(args=None, **extras): """ Pass input arguments to open_tunnel Mandatory: ssh_address, -R (remote bind address list) @@ -1894,6 +1894,9 @@ def _cli_main(args=None): logging.DEBUG, TRACE_LEVEL] arguments.setdefault('debug_level', levels[verbosity]) + # do this while supporting py27/py34 instead of merging dicts + for (extra, value) in extras.items(): + arguments.setdefault(extra, value) with open_tunnel(**arguments) as tunnel: if tunnel.is_alive: input_(''' diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 78a11d3e..40662d08 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -74,6 +74,7 @@ def capture_stdout_stderr(): ssh_config_file=None, allow_agent=False, skip_tunnel_checkup=True, + host_pkey_directories=[], ) # CONSTANTS @@ -409,9 +410,9 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): self.log.info('<<< forward-server received STOP signal') except socket.error: self.log.critical('{0} sending RST'.format(info)) - except Exception as e: - # we reach this point usually when schan is None (paramiko bug?) - self.log.critical(repr(e)) + # except Exception as e: + # # we reach this point usually when schan is None (paramiko bug?) + # self.log.critical(repr(e)) finally: if schan: self.log.debug('{0} closing connection...'.format(info)) @@ -420,7 +421,7 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT): self.log.debug('{0} connection closed.'.format(info)) def randomize_eport(self): - return self.eport + random.randint(1, 999) + return random.randint(49152, 65535) def test_echo_server(self): with self._test_server( @@ -487,7 +488,8 @@ def test_open_tunnel(self): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ssh_config_file=None, - allow_agent=False + allow_agent=False, + host_pkey_directories=[], ) self.assertEqual(server.ssh_host, self.saddr) self.assertEqual(server.ssh_port, self.sport) @@ -1005,7 +1007,8 @@ def test_cli_main_exits_when_pressing_enter(self, input): '-R', '{0}:{1}'.format(self.eaddr, self.eport), '-c', '', - '-n']) + '-n'], + host_pkey_directories=[]) self.stop_echo_and_ssh_server() @unittest.skipIf(sys.version_info < (2, 7), @@ -1121,7 +1124,8 @@ def check_make_ssh_forward_server_sets_daemon(self, case): remote_bind_address=(self.eaddr, self.eport), logger=self.log, ssh_config_file=None, - allow_agent=False + allow_agent=False, + host_pkey_directories=[], ) try: tunnel.daemon_forward_servers = case @@ -1181,7 +1185,7 @@ def test_get_keys(self): ) self.assertIsInstance(keys, list) self.assertTrue( - any('1 keys loaded from host directory' in msg + any('1 key(s) loaded' in msg for msg in self.sshtunnel_log_messages['info']) ) shutil.rmtree(tmp_dir) @@ -1382,16 +1386,16 @@ def test_process_deprecations(self): 'item', kwargs.copy()) - def check_address(self): + def test_check_address(self): """ Test that an exception is raised with incorrect bind addresses """ address_list = [('10.0.0.1', 10000), ('10.0.0.1', 10001)] if os.name == 'posix': # UNIX sockets supported by the platform address_list.append('/tmp/unix-socket') + # UNIX sockets not supported on remote addresses + with self.assertRaises(AssertionError): + sshtunnel.check_addresses(address_list, is_remote=True) self.assertIsNone(sshtunnel.check_addresses(address_list)) - # UNIX sockets not supported on remote addresses - with self.assertRaises(AssertionError): - sshtunnel.check_addresses(address_list, is_remote=True) with self.assertRaises(ValueError): sshtunnel.check_address('this is not valid') with self.assertRaises(ValueError):