Skip to content

Commit

Permalink
Fix private keys are not found on Linux/Windows (#216)
Browse files Browse the repository at this point in the history
* Fix #161
* Fix #145 
* use ephemeral ports for tests

Co-authored-by: Pavel White <[email protected]>
  • Loading branch information
fernandezcuesta and pahaz authored Nov 17, 2020
1 parent cd374d5 commit 6d636f2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 53 deletions.
26 changes: 21 additions & 5 deletions appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@ environment:
PYTHON_VERSION: "2.7.x"
PYTHON_ARCH: "64"

- PYTHON: "C:\\Python35"
PYTHON_VERSION: "3.5.x"
PYTHON_ARCH: "32"

- PYTHON: "C:\\Python35-x64"
PYTHON_VERSION: "3.5.x"
PYTHON_ARCH: "64"

- PYTHON: "C:\\Python36"
PYTHON_VERSION: "3.6.x"
PYTHON_ARCH: "32"
Expand All @@ -26,18 +34,26 @@ environment:
- PYTHON: "C:\\Python37-x64"
PYTHON_VERSION: "3.7.x"
PYTHON_ARCH: "64"

- PYTHON: "C:\\Python38"
PYTHON_VERSION: "3.8.x"
PYTHON_ARCH: "32"

- PYTHON: "C:\\Python38-x64"
PYTHON_VERSION: "3.8.x"
PYTHON_ARCH: "64"
init:
- "ECHO %PYTHON% %PYTHON_VERSION% %PYTHON_ARCH%"

install:
- set "PATH=%PYTHON%;%PYTHON%\\Scripts;%PYTHON%\\Tools\\Scripts;%PATH%"
- "python -m pip install --upgrade pip"
- "pip install paramiko"
- "pip install mock pytest pytest-cov pytest-xdist"
- python -m pip install --upgrade pip
- pip install paramiko
- pip install mock pytest pytest-cov pytest-xdist

build: off

test_script:
- "python setup.py install"
- "py.test --showlocals --durations=10 -n4 tests"
- python setup.py install
- py.test --showlocals --durations=10 -n4 tests

75 changes: 39 additions & 36 deletions sshtunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,10 @@
DEFAULT_LOGLEVEL = logging.ERROR #: default level if no logger passed (ERROR)
TRACE_LEVEL = 1
logging.addLevelName(TRACE_LEVEL, 'TRACE')
DEFAULT_SSH_DIRECTORY = '~/.ssh'

if os.name == 'posix':
DEFAULT_SSH_DIRECTORY = '~/.ssh'
UnixStreamServer = socketserver.UnixStreamServer
else:
DEFAULT_SSH_DIRECTORY = '~/ssh'
UnixStreamServer = socketserver.TCPServer
StreamServer = socketserver.UnixStreamServer if os.name == 'posix' \
else socketserver.TCPServer

#: Path of optional ssh configuration file
SSH_CONFIG_FILE = os.path.join(DEFAULT_SSH_DIRECTORY, 'config')
Expand Down Expand Up @@ -352,7 +349,7 @@ def handle(self):
src_addr=src_address,
timeout=TUNNEL_TIMEOUT
)
except Exception as e:
except Exception as e: # pragma: no cover
msg_tupe = 'ssh ' if isinstance(e, paramiko.SSHException) else ''
exc_msg = 'open new channel {0}error: {1}'.format(msg_tupe, e)
log_msg = '{0} {1}'.format(self.info, exc_msg)
Expand Down Expand Up @@ -437,15 +434,15 @@ class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer):
daemon_threads = _DAEMON


class _UnixStreamForwardServer(UnixStreamServer):
class _StreamForwardServer(StreamServer):
"""
Serve over UNIX domain sockets (does not work on Windows)
Serve over domain sockets (does not work on Windows)
"""

def __init__(self, *args, **kwargs):
self.logger = create_logger(kwargs.pop('logger', None))
self.tunnel_ok = queue.Queue(1)
UnixStreamServer.__init__(self, *args, **kwargs)
StreamServer.__init__(self, *args, **kwargs)

@property
def local_address(self):
Expand All @@ -472,8 +469,8 @@ def remote_port(self):
return self.RequestHandlerClass.remote_address[1]


class _ThreadingUnixStreamForwardServer(socketserver.ThreadingMixIn,
_UnixStreamForwardServer):
class _ThreadingStreamForwardServer(socketserver.ThreadingMixIn,
_StreamForwardServer):
"""
Allow concurrent connections to each tunnel
"""
Expand Down Expand Up @@ -832,20 +829,19 @@ class Handler(_ForwardHandler):
def _make_ssh_forward_server_class(self, remote_address_):
return _ThreadingForwardServer if self._threaded else _ForwardServer

def _make_unix_ssh_forward_server_class(self, remote_address_):
return _ThreadingUnixStreamForwardServer if \
self._threaded else _UnixStreamForwardServer
def _make_stream_ssh_forward_server_class(self, remote_address_):
return _ThreadingStreamForwardServer if self._threaded \
else _StreamForwardServer

def _make_ssh_forward_server(self, remote_address, local_bind_address):
"""
Make SSH forward proxy Server class
"""
_Handler = self._make_ssh_forward_handler_class(remote_address)
try:
if isinstance(local_bind_address, string_types):
forward_maker_class = self._make_unix_ssh_forward_server_class
else:
forward_maker_class = self._make_ssh_forward_server_class
forward_maker_class = self._make_stream_ssh_forward_server_class \
if isinstance(local_bind_address, string_types) \
else self._make_ssh_forward_server_class
_Server = forward_maker_class(remote_address)
ssh_forward_server = _Server(
local_bind_address,
Expand Down Expand Up @@ -1085,16 +1081,19 @@ def get_keys(logger=None, host_pkey_directories=None, allow_agent=False):
keys = SSHTunnelForwarder.get_agent_keys(logger=logger) \
if allow_agent else []

if host_pkey_directories is not None:
paramiko_key_types = {'rsa': paramiko.RSAKey,
'dsa': paramiko.DSSKey,
'ecdsa': paramiko.ECDSAKey,
'ed25519': paramiko.Ed25519Key}
for directory in host_pkey_directories or [DEFAULT_SSH_DIRECTORY]:
for keytype in paramiko_key_types.keys():
ssh_pkey_expanded = os.path.expanduser(
os.path.join(directory, 'id_{}'.format(keytype))
)
if host_pkey_directories is None:
host_pkey_directories = [DEFAULT_SSH_DIRECTORY]

paramiko_key_types = {'rsa': paramiko.RSAKey,
'dsa': paramiko.DSSKey,
'ecdsa': paramiko.ECDSAKey,
'ed25519': paramiko.Ed25519Key}
for directory in host_pkey_directories:
for keytype in paramiko_key_types.keys():
ssh_pkey_expanded = os.path.expanduser(
os.path.join(directory, 'id_{}'.format(keytype))
)
try:
if os.path.isfile(ssh_pkey_expanded):
ssh_pkey = SSHTunnelForwarder.read_private_key_file(
pkey_file=ssh_pkey_expanded,
Expand All @@ -1103,11 +1102,12 @@ def get_keys(logger=None, host_pkey_directories=None, allow_agent=False):
)
if ssh_pkey:
keys.append(ssh_pkey)
except OSError as exc:
if logger:
logger.warning('Private key file {0} check error: {1}'
.format(ssh_pkey_expanded, exc))
if logger:
logger.info('{0} keys loaded from host directory'.format(
len(keys))
)

logger.info('{0} key(s) loaded'.format(len(keys)))
return keys

@staticmethod
Expand Down Expand Up @@ -1455,12 +1455,12 @@ def _stop_transport(self, force=False):
_srv.shutdown()
_srv.server_close()
# clean up the UNIX domain socket if we're using one
if isinstance(_srv, _UnixStreamForwardServer):
if isinstance(_srv, _StreamForwardServer):
try:
os.unlink(_srv.local_address)
except Exception as e:
self.logger.error('Unable to unlink socket {0}: {1}'
.format(self.local_address, repr(e)))
.format(_srv.local_address, repr(e)))
self.is_alive = False
if self.is_active:
self.logger.info('Closing ssh transport')
Expand Down Expand Up @@ -1862,7 +1862,7 @@ def _parse_arguments(args=None):
return vars(parser.parse_args(args))


def _cli_main(args=None):
def _cli_main(args=None, **extras):
""" Pass input arguments to open_tunnel
Mandatory: ssh_address, -R (remote bind address list)
Expand Down Expand Up @@ -1894,6 +1894,9 @@ def _cli_main(args=None):
logging.DEBUG,
TRACE_LEVEL]
arguments.setdefault('debug_level', levels[verbosity])
# do this while supporting py27/py34 instead of merging dicts
for (extra, value) in extras.items():
arguments.setdefault(extra, value)
with open_tunnel(**arguments) as tunnel:
if tunnel.is_alive:
input_('''
Expand Down
28 changes: 16 additions & 12 deletions tests/test_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def capture_stdout_stderr():
ssh_config_file=None,
allow_agent=False,
skip_tunnel_checkup=True,
host_pkey_directories=[],
)

# CONSTANTS
Expand Down Expand Up @@ -409,9 +410,9 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT):
self.log.info('<<< forward-server received STOP signal')
except socket.error:
self.log.critical('{0} sending RST'.format(info))
except Exception as e:
# we reach this point usually when schan is None (paramiko bug?)
self.log.critical(repr(e))
# except Exception as e:
# # we reach this point usually when schan is None (paramiko bug?)
# self.log.critical(repr(e))
finally:
if schan:
self.log.debug('{0} closing connection...'.format(info))
Expand All @@ -420,7 +421,7 @@ def _do_forwarding(self, timeout=sshtunnel.SSH_TIMEOUT):
self.log.debug('{0} connection closed.'.format(info))

def randomize_eport(self):
return self.eport + random.randint(1, 999)
return random.randint(49152, 65535)

def test_echo_server(self):
with self._test_server(
Expand Down Expand Up @@ -487,7 +488,8 @@ def test_open_tunnel(self):
remote_bind_address=(self.eaddr, self.eport),
logger=self.log,
ssh_config_file=None,
allow_agent=False
allow_agent=False,
host_pkey_directories=[],
)
self.assertEqual(server.ssh_host, self.saddr)
self.assertEqual(server.ssh_port, self.sport)
Expand Down Expand Up @@ -1005,7 +1007,8 @@ def test_cli_main_exits_when_pressing_enter(self, input):
'-R', '{0}:{1}'.format(self.eaddr,
self.eport),
'-c', '',
'-n'])
'-n'],
host_pkey_directories=[])
self.stop_echo_and_ssh_server()

@unittest.skipIf(sys.version_info < (2, 7),
Expand Down Expand Up @@ -1121,7 +1124,8 @@ def check_make_ssh_forward_server_sets_daemon(self, case):
remote_bind_address=(self.eaddr, self.eport),
logger=self.log,
ssh_config_file=None,
allow_agent=False
allow_agent=False,
host_pkey_directories=[],
)
try:
tunnel.daemon_forward_servers = case
Expand Down Expand Up @@ -1181,7 +1185,7 @@ def test_get_keys(self):
)
self.assertIsInstance(keys, list)
self.assertTrue(
any('1 keys loaded from host directory' in msg
any('1 key(s) loaded' in msg
for msg in self.sshtunnel_log_messages['info'])
)
shutil.rmtree(tmp_dir)
Expand Down Expand Up @@ -1382,16 +1386,16 @@ def test_process_deprecations(self):
'item',
kwargs.copy())

def check_address(self):
def test_check_address(self):
""" Test that an exception is raised with incorrect bind addresses """
address_list = [('10.0.0.1', 10000),
('10.0.0.1', 10001)]
if os.name == 'posix': # UNIX sockets supported by the platform
address_list.append('/tmp/unix-socket')
# UNIX sockets not supported on remote addresses
with self.assertRaises(AssertionError):
sshtunnel.check_addresses(address_list, is_remote=True)
self.assertIsNone(sshtunnel.check_addresses(address_list))
# UNIX sockets not supported on remote addresses
with self.assertRaises(AssertionError):
sshtunnel.check_addresses(address_list, is_remote=True)
with self.assertRaises(ValueError):
sshtunnel.check_address('this is not valid')
with self.assertRaises(ValueError):
Expand Down

0 comments on commit 6d636f2

Please sign in to comment.