Skip to content

Commit d84f3a1

Browse files
authored
connection/aws_ssm - refactor exec_command function to improve maintanability (#2226)
SUMMARY Refer to https://issues.redhat.com/browse/ACA-2093 Refactor exec_command() and add unit tests ISSUE TYPE Feature Pull Request COMPONENT NAME connection/aws_ssm Reviewed-by: GomathiselviS <[email protected]> Reviewed-by: Bikouo Aubin Reviewed-by: Mark Chappell
1 parent 56b0886 commit d84f3a1

File tree

5 files changed

+323
-79
lines changed

5 files changed

+323
-79
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
---
2+
minor_changes:
3+
- aws_ssm - Refactor exec_command Method for Improved Clarity and Efficiency (https://github.com/ansible-collections/community.aws/pull/2224).

plugins/connection/aws_ssm.py

+120-79
Original file line numberDiff line numberDiff line change
@@ -285,9 +285,9 @@
285285
name: nginx
286286
state: present
287287
"""
288-
import os
289288
import getpass
290289
import json
290+
import os
291291
import pty
292292
import random
293293
import re
@@ -296,6 +296,8 @@
296296
import subprocess
297297
import time
298298
from typing import Optional
299+
from typing import NoReturn
300+
from typing import Tuple
299301

300302
try:
301303
import boto3
@@ -304,18 +306,19 @@
304306
pass
305307

306308
from functools import wraps
307-
from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3
309+
308310
from ansible.errors import AnsibleConnectionFailure
309311
from ansible.errors import AnsibleError
310312
from ansible.errors import AnsibleFileNotFound
311-
from ansible.module_utils.basic import missing_required_lib
312-
from ansible.module_utils.six.moves import xrange
313313
from ansible.module_utils._text import to_bytes
314314
from ansible.module_utils._text import to_text
315+
from ansible.module_utils.basic import missing_required_lib
315316
from ansible.plugins.connection import ConnectionBase
316317
from ansible.plugins.shell.powershell import _common_args
317318
from ansible.utils.display import Display
318319

320+
from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3
321+
319322
display = Display()
320323

321324

@@ -375,6 +378,29 @@ def chunks(lst, n):
375378
yield lst[i:i + n] # fmt: skip
376379

377380

381+
def filter_ansi(line: str, is_windows: bool) -> str:
382+
"""Remove any ANSI terminal control codes.
383+
384+
:param line: The input line.
385+
:param is_windows: Whether the output is coming from a Windows host.
386+
:returns: The result line.
387+
"""
388+
line = to_text(line)
389+
390+
if is_windows:
391+
osc_filter = re.compile(r"\x1b\][^\x07]*\x07")
392+
line = osc_filter.sub("", line)
393+
ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
394+
line = ansi_filter.sub("", line)
395+
396+
# Replace or strip sequence (at terminal width)
397+
line = line.replace("\r\r\n", "\n")
398+
if len(line) == 201:
399+
line = line[:-1]
400+
401+
return line
402+
403+
378404
class Connection(ConnectionBase):
379405
"""AWS SSM based connections"""
380406

@@ -401,6 +427,9 @@ def __init__(self, *args, **kwargs):
401427
raise AnsibleError(missing_required_lib("boto3"))
402428

403429
self.host = self._play_context.remote_addr
430+
self._instance_id = None
431+
self._polling_obj = None
432+
self._has_timeout = False
404433

405434
if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
406435
self.delegate = None
@@ -541,14 +570,19 @@ def reset(self):
541570
self.close()
542571
return self.start_session()
543572

573+
@property
574+
def instance_id(self) -> str:
575+
if not self._instance_id:
576+
self._instance_id = self.host if self.get_option("instance_id") is None else self.get_option("instance_id")
577+
return self._instance_id
578+
579+
@instance_id.setter
580+
def instance_id(self, instance_id: str) -> NoReturn:
581+
self._instance_id = instance_id
582+
544583
def start_session(self):
545584
"""start ssm session"""
546585

547-
if self.get_option("instance_id") is None:
548-
self.instance_id = self.host
549-
else:
550-
self.instance_id = self.get_option("instance_id")
551-
552586
self._vvv(f"ESTABLISH SSM CONNECTION TO: {self.instance_id}")
553587

554588
executable = self.get_option("plugin")
@@ -592,8 +626,6 @@ def start_session(self):
592626
os.close(stdout_w)
593627
self._stdout = os.fdopen(stdout_r, "rb", 0)
594628
self._session = session
595-
self._poll_stdout = select.poll()
596-
self._poll_stdout.register(self._stdout, select.POLLIN)
597629

598630
# Disable command echo and prompt.
599631
self._prepare_terminal()
@@ -602,49 +634,56 @@ def start_session(self):
602634

603635
return session
604636

605-
@_ssm_retry
606-
def exec_command(self, cmd, in_data=None, sudoable=True):
607-
"""run a command on the ssm host"""
608-
609-
super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
610-
611-
self._vvv(f"EXEC: {to_text(cmd)}")
612-
613-
session = self._session
614-
615-
mark_begin = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
616-
if self.is_windows:
617-
mark_start = mark_begin + " $LASTEXITCODE"
618-
else:
619-
mark_start = mark_begin
620-
mark_end = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
637+
def poll_stdout(self, timeout: int = 1000) -> bool:
638+
"""Polls the stdout file descriptor.
621639
622-
# Wrap command in markers accordingly for the shell used
623-
cmd = self._wrap_command(cmd, sudoable, mark_start, mark_end)
624-
625-
self._flush_stderr(session)
640+
:param timeout: Specifies the length of time in milliseconds which the system will wait.
641+
:returns: A boolean to specify the polling result
642+
"""
643+
if self._polling_obj is None:
644+
self._polling_obj = select.poll()
645+
self._polling_obj.register(self._stdout, select.POLLIN)
646+
return bool(self._polling_obj.poll(timeout))
626647

627-
for chunk in chunks(cmd, 1024):
628-
session.stdin.write(to_bytes(chunk, errors="surrogate_or_strict"))
648+
def poll(self, label: str, cmd: str) -> NoReturn:
649+
"""Poll session to retrieve content from stdout.
629650
651+
:param label: A label for the display (EXEC, PRE...)
652+
:param cmd: The command being executed
653+
"""
654+
start = round(time.time())
655+
yield self.poll_stdout()
656+
timeout = self.get_option("ssm_timeout")
657+
while self._session.poll() is None:
658+
remaining = start + timeout - round(time.time())
659+
self._vvvv(f"{label} remaining: {remaining} second(s)")
660+
if remaining < 0:
661+
self._has_timeout = True
662+
raise AnsibleConnectionFailure(f"{label} command '{cmd}' timeout on host: {self.instance_id}")
663+
yield self.poll_stdout()
664+
665+
def exec_communicate(self, cmd: str, mark_start: str, mark_begin: str, mark_end: str) -> Tuple[int, str, str]:
666+
"""Interact with session.
667+
Read stdout between the markers until 'mark_end' is reached.
668+
669+
:param cmd: The command being executed.
670+
:param mark_start: The marker which starts the output.
671+
:param mark_begin: The begin marker.
672+
:param mark_end: The end marker.
673+
:returns: A tuple with the return code, the stdout and the stderr content.
674+
"""
630675
# Read stdout between the markers
631676
stdout = ""
632677
win_line = ""
633678
begin = False
634-
stop_time = int(round(time.time())) + self.get_option("ssm_timeout")
635-
while session.poll() is None:
636-
remaining = stop_time - int(round(time.time()))
637-
if remaining < 1:
638-
self._timeout = True
639-
self._vvvv(f"EXEC timeout stdout: \n{to_text(stdout)}")
640-
raise AnsibleConnectionFailure(f"SSM exec_command timeout on host: {self.instance_id}")
641-
if self._poll_stdout.poll(1000):
642-
line = self._filter_ansi(self._stdout.readline())
643-
self._vvvv(f"EXEC stdout line: \n{to_text(line)}")
644-
else:
645-
self._vvvv(f"EXEC remaining: {remaining}")
679+
returncode = None
680+
for poll_result in self.poll("EXEC", cmd):
681+
if not poll_result:
646682
continue
647683

684+
line = filter_ansi(self._stdout.readline(), self.is_windows)
685+
self._vvvv(f"EXEC stdout line: \n{line}")
686+
648687
if not begin and self.is_windows:
649688
win_line = win_line + line
650689
line = win_line
@@ -662,9 +701,33 @@ def exec_command(self, cmd, in_data=None, sudoable=True):
662701
break
663702
stdout = stdout + line
664703

665-
stderr = self._flush_stderr(session)
704+
# see https://github.com/pylint-dev/pylint/issues/8909)
705+
return (returncode, stdout, self._flush_stderr(self._session)) # pylint: disable=unreachable
706+
707+
@_ssm_retry
708+
def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> Tuple[int, str, str]:
709+
"""run a command on the ssm host"""
710+
711+
super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
666712

667-
return (returncode, stdout, stderr)
713+
self._vvv(f"EXEC: {to_text(cmd)}")
714+
715+
mark_begin = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
716+
if self.is_windows:
717+
mark_start = mark_begin + " $LASTEXITCODE"
718+
else:
719+
mark_start = mark_begin
720+
mark_end = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
721+
722+
# Wrap command in markers accordingly for the shell used
723+
cmd = self._wrap_command(cmd, mark_start, mark_end)
724+
725+
self._flush_stderr(self._session)
726+
727+
for chunk in chunks(cmd, 1024):
728+
self._session.stdin.write(to_bytes(chunk, errors="surrogate_or_strict"))
729+
730+
return self.exec_communicate(cmd, mark_start, mark_begin, mark_end)
668731

669732
def _prepare_terminal(self):
670733
"""perform any one-time terminal settings"""
@@ -682,7 +745,7 @@ def _prepare_terminal(self):
682745
disable_echo_cmd = to_bytes("stty -echo\n", errors="surrogate_or_strict")
683746

684747
disable_prompt_complete = None
685-
end_mark = "".join([random.choice(string.ascii_letters) for i in xrange(self.MARK_LENGTH)])
748+
end_mark = "".join([random.choice(string.ascii_letters) for i in range(self.MARK_LENGTH)])
686749
disable_prompt_cmd = to_bytes(
687750
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\n%s\\n' '" + end_mark + "'\n",
688751
errors="surrogate_or_strict",
@@ -691,18 +754,12 @@ def _prepare_terminal(self):
691754

692755
stdout = ""
693756
# Custom command execution for when we're waiting for startup
694-
stop_time = int(round(time.time())) + self.get_option("ssm_timeout")
695-
while (not disable_prompt_complete) and (self._session.poll() is None):
696-
remaining = stop_time - int(round(time.time()))
697-
if remaining < 1:
698-
self._timeout = True
699-
self._vvvv(f"PRE timeout stdout: \n{to_bytes(stdout)}")
700-
raise AnsibleConnectionFailure(f"SSM start_session timeout on host: {self.instance_id}")
701-
if self._poll_stdout.poll(1000):
757+
for poll_result in self.poll("PRE", "start_session"):
758+
if disable_prompt_complete:
759+
break
760+
if poll_result:
702761
stdout += to_text(self._stdout.read(1024))
703762
self._vvvv(f"PRE stdout line: \n{to_bytes(stdout)}")
704-
else:
705-
self._vvvv(f"PRE remaining: {remaining}")
706763

707764
# wait til prompt is ready
708765
if startup_complete is False:
@@ -734,12 +791,13 @@ def _prepare_terminal(self):
734791
stdout = stdout[match.end():] # fmt: skip
735792
disable_prompt_complete = True
736793

737-
if not disable_prompt_complete:
794+
# see https://github.com/pylint-dev/pylint/issues/8909)
795+
if not disable_prompt_complete: # pylint: disable=unreachable
738796
raise AnsibleConnectionFailure(f"SSM process closed during _prepare_terminal on host: {self.instance_id}")
739797
self._vvvv("PRE Terminal configured")
740798

741-
def _wrap_command(self, cmd, sudoable, mark_start, mark_end):
742-
"""wrap command so stdout and status can be extracted"""
799+
def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
800+
"""Wrap command so stdout and status can be extracted"""
743801

744802
if self.is_windows:
745803
if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
@@ -789,23 +847,6 @@ def _post_process(self, stdout, mark_begin):
789847

790848
return (returncode, stdout)
791849

792-
def _filter_ansi(self, line):
793-
"""remove any ANSI terminal control codes"""
794-
line = to_text(line)
795-
796-
if self.is_windows:
797-
osc_filter = re.compile(r"\x1b\][^\x07]*\x07")
798-
line = osc_filter.sub("", line)
799-
ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
800-
line = ansi_filter.sub("", line)
801-
802-
# Replace or strip sequence (at terminal width)
803-
line = line.replace("\r\r\n", "\n")
804-
if len(line) == 201:
805-
line = line[:-1]
806-
807-
return line
808-
809850
def _flush_stderr(self, session_process):
810851
"""read and return stderr with minimal blocking"""
811852

@@ -995,7 +1036,7 @@ def close(self):
9951036
"""terminate the connection"""
9961037
if self._session_id:
9971038
self._vvv(f"CLOSING SSM CONNECTION TO: {self.instance_id}")
998-
if self._timeout:
1039+
if self._has_timeout:
9991040
self._session.terminate()
10001041
else:
10011042
cmd = b"\nexit\n"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# This file is part of Ansible
4+
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
5+
6+
from unittest.mock import MagicMock
7+
8+
import pytest
9+
10+
from ansible_collections.community.aws.plugins.connection.aws_ssm import Connection
11+
from ansible_collections.community.aws.plugins.connection.aws_ssm import ConnectionBase
12+
13+
14+
@pytest.fixture(name="connection_aws_ssm")
15+
def fixture_connection_aws_ssm():
16+
play_context = MagicMock()
17+
play_context.shell = True
18+
19+
def connection_init(*args, **kwargs):
20+
pass
21+
22+
Connection.__init__ = connection_init
23+
ConnectionBase.exec_command = MagicMock()
24+
connection = Connection()
25+
26+
connection._instance_id = "i-0a1b2c3d4e5f"
27+
connection._polling_obj = None
28+
connection._has_timeout = False
29+
connection.is_windows = False
30+
31+
connection.poll_stdout = MagicMock()
32+
connection._session = MagicMock()
33+
connection._session.poll = MagicMock()
34+
connection._session.poll.side_effect = lambda: None
35+
connection._stdout = MagicMock()
36+
connection._flush_stderr = MagicMock()
37+
38+
def display_msg(msg):
39+
print("--- AWS SSM CONNECTION --- ", msg)
40+
41+
connection._v = MagicMock()
42+
connection._v.side_effect = display_msg
43+
44+
connection._vv = MagicMock()
45+
connection._vv.side_effect = display_msg
46+
47+
connection._vvv = MagicMock()
48+
connection._vvv.side_effect = display_msg
49+
50+
connection._vvvv = MagicMock()
51+
connection._vvvv.side_effect = display_msg
52+
53+
connection.get_option = MagicMock()
54+
return connection

0 commit comments

Comments
 (0)