285
285
name: nginx
286
286
state: present
287
287
"""
288
- import os
289
288
import getpass
290
289
import json
290
+ import os
291
291
import pty
292
292
import random
293
293
import re
296
296
import subprocess
297
297
import time
298
298
from typing import Optional
299
+ from typing import NoReturn
300
+ from typing import Tuple
299
301
300
302
try :
301
303
import boto3
304
306
pass
305
307
306
308
from functools import wraps
307
- from ansible_collections . amazon . aws . plugins . module_utils . botocore import HAS_BOTO3
309
+
308
310
from ansible .errors import AnsibleConnectionFailure
309
311
from ansible .errors import AnsibleError
310
312
from ansible .errors import AnsibleFileNotFound
311
- from ansible .module_utils .basic import missing_required_lib
312
- from ansible .module_utils .six .moves import xrange
313
313
from ansible .module_utils ._text import to_bytes
314
314
from ansible .module_utils ._text import to_text
315
+ from ansible .module_utils .basic import missing_required_lib
315
316
from ansible .plugins .connection import ConnectionBase
316
317
from ansible .plugins .shell .powershell import _common_args
317
318
from ansible .utils .display import Display
318
319
320
+ from ansible_collections .amazon .aws .plugins .module_utils .botocore import HAS_BOTO3
321
+
319
322
display = Display ()
320
323
321
324
@@ -375,6 +378,29 @@ def chunks(lst, n):
375
378
yield lst [i :i + n ] # fmt: skip
376
379
377
380
381
+ def filter_ansi (line : str , is_windows : bool ) -> str :
382
+ """Remove any ANSI terminal control codes.
383
+
384
+ :param line: The input line.
385
+ :param is_windows: Whether the output is coming from a Windows host.
386
+ :returns: The result line.
387
+ """
388
+ line = to_text (line )
389
+
390
+ if is_windows :
391
+ osc_filter = re .compile (r"\x1b\][^\x07]*\x07" )
392
+ line = osc_filter .sub ("" , line )
393
+ ansi_filter = re .compile (r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]" )
394
+ line = ansi_filter .sub ("" , line )
395
+
396
+ # Replace or strip sequence (at terminal width)
397
+ line = line .replace ("\r \r \n " , "\n " )
398
+ if len (line ) == 201 :
399
+ line = line [:- 1 ]
400
+
401
+ return line
402
+
403
+
378
404
class Connection (ConnectionBase ):
379
405
"""AWS SSM based connections"""
380
406
@@ -401,6 +427,9 @@ def __init__(self, *args, **kwargs):
401
427
raise AnsibleError (missing_required_lib ("boto3" ))
402
428
403
429
self .host = self ._play_context .remote_addr
430
+ self ._instance_id = None
431
+ self ._polling_obj = None
432
+ self ._has_timeout = False
404
433
405
434
if getattr (self ._shell , "SHELL_FAMILY" , "" ) == "powershell" :
406
435
self .delegate = None
@@ -541,14 +570,19 @@ def reset(self):
541
570
self .close ()
542
571
return self .start_session ()
543
572
573
+ @property
574
+ def instance_id (self ) -> str :
575
+ if not self ._instance_id :
576
+ self ._instance_id = self .host if self .get_option ("instance_id" ) is None else self .get_option ("instance_id" )
577
+ return self ._instance_id
578
+
579
+ @instance_id .setter
580
+ def instance_id (self , instance_id : str ) -> NoReturn :
581
+ self ._instance_id = instance_id
582
+
544
583
def start_session (self ):
545
584
"""start ssm session"""
546
585
547
- if self .get_option ("instance_id" ) is None :
548
- self .instance_id = self .host
549
- else :
550
- self .instance_id = self .get_option ("instance_id" )
551
-
552
586
self ._vvv (f"ESTABLISH SSM CONNECTION TO: { self .instance_id } " )
553
587
554
588
executable = self .get_option ("plugin" )
@@ -592,8 +626,6 @@ def start_session(self):
592
626
os .close (stdout_w )
593
627
self ._stdout = os .fdopen (stdout_r , "rb" , 0 )
594
628
self ._session = session
595
- self ._poll_stdout = select .poll ()
596
- self ._poll_stdout .register (self ._stdout , select .POLLIN )
597
629
598
630
# Disable command echo and prompt.
599
631
self ._prepare_terminal ()
@@ -602,49 +634,56 @@ def start_session(self):
602
634
603
635
return session
604
636
605
- @_ssm_retry
606
- def exec_command (self , cmd , in_data = None , sudoable = True ):
607
- """run a command on the ssm host"""
608
-
609
- super ().exec_command (cmd , in_data = in_data , sudoable = sudoable )
610
-
611
- self ._vvv (f"EXEC: { to_text (cmd )} " )
612
-
613
- session = self ._session
614
-
615
- mark_begin = "" .join ([random .choice (string .ascii_letters ) for i in xrange (self .MARK_LENGTH )])
616
- if self .is_windows :
617
- mark_start = mark_begin + " $LASTEXITCODE"
618
- else :
619
- mark_start = mark_begin
620
- mark_end = "" .join ([random .choice (string .ascii_letters ) for i in xrange (self .MARK_LENGTH )])
637
+ def poll_stdout (self , timeout : int = 1000 ) -> bool :
638
+ """Polls the stdout file descriptor.
621
639
622
- # Wrap command in markers accordingly for the shell used
623
- cmd = self ._wrap_command (cmd , sudoable , mark_start , mark_end )
624
-
625
- self ._flush_stderr (session )
640
+ :param timeout: Specifies the length of time in milliseconds which the system will wait.
641
+ :returns: A boolean to specify the polling result
642
+ """
643
+ if self ._polling_obj is None :
644
+ self ._polling_obj = select .poll ()
645
+ self ._polling_obj .register (self ._stdout , select .POLLIN )
646
+ return bool (self ._polling_obj .poll (timeout ))
626
647
627
- for chunk in chunks ( cmd , 1024 ) :
628
- session . stdin . write ( to_bytes ( chunk , errors = "surrogate_or_strict" ))
648
+ def poll ( self , label : str , cmd : str ) -> NoReturn :
649
+ """Poll session to retrieve content from stdout.
629
650
651
+ :param label: A label for the display (EXEC, PRE...)
652
+ :param cmd: The command being executed
653
+ """
654
+ start = round (time .time ())
655
+ yield self .poll_stdout ()
656
+ timeout = self .get_option ("ssm_timeout" )
657
+ while self ._session .poll () is None :
658
+ remaining = start + timeout - round (time .time ())
659
+ self ._vvvv (f"{ label } remaining: { remaining } second(s)" )
660
+ if remaining < 0 :
661
+ self ._has_timeout = True
662
+ raise AnsibleConnectionFailure (f"{ label } command '{ cmd } ' timeout on host: { self .instance_id } " )
663
+ yield self .poll_stdout ()
664
+
665
+ def exec_communicate (self , cmd : str , mark_start : str , mark_begin : str , mark_end : str ) -> Tuple [int , str , str ]:
666
+ """Interact with session.
667
+ Read stdout between the markers until 'mark_end' is reached.
668
+
669
+ :param cmd: The command being executed.
670
+ :param mark_start: The marker which starts the output.
671
+ :param mark_begin: The begin marker.
672
+ :param mark_end: The end marker.
673
+ :returns: A tuple with the return code, the stdout and the stderr content.
674
+ """
630
675
# Read stdout between the markers
631
676
stdout = ""
632
677
win_line = ""
633
678
begin = False
634
- stop_time = int (round (time .time ())) + self .get_option ("ssm_timeout" )
635
- while session .poll () is None :
636
- remaining = stop_time - int (round (time .time ()))
637
- if remaining < 1 :
638
- self ._timeout = True
639
- self ._vvvv (f"EXEC timeout stdout: \n { to_text (stdout )} " )
640
- raise AnsibleConnectionFailure (f"SSM exec_command timeout on host: { self .instance_id } " )
641
- if self ._poll_stdout .poll (1000 ):
642
- line = self ._filter_ansi (self ._stdout .readline ())
643
- self ._vvvv (f"EXEC stdout line: \n { to_text (line )} " )
644
- else :
645
- self ._vvvv (f"EXEC remaining: { remaining } " )
679
+ returncode = None
680
+ for poll_result in self .poll ("EXEC" , cmd ):
681
+ if not poll_result :
646
682
continue
647
683
684
+ line = filter_ansi (self ._stdout .readline (), self .is_windows )
685
+ self ._vvvv (f"EXEC stdout line: \n { line } " )
686
+
648
687
if not begin and self .is_windows :
649
688
win_line = win_line + line
650
689
line = win_line
@@ -662,9 +701,33 @@ def exec_command(self, cmd, in_data=None, sudoable=True):
662
701
break
663
702
stdout = stdout + line
664
703
665
- stderr = self ._flush_stderr (session )
704
+ # see https://github.com/pylint-dev/pylint/issues/8909)
705
+ return (returncode , stdout , self ._flush_stderr (self ._session )) # pylint: disable=unreachable
706
+
707
+ @_ssm_retry
708
+ def exec_command (self , cmd : str , in_data : bool = None , sudoable : bool = True ) -> Tuple [int , str , str ]:
709
+ """run a command on the ssm host"""
710
+
711
+ super ().exec_command (cmd , in_data = in_data , sudoable = sudoable )
666
712
667
- return (returncode , stdout , stderr )
713
+ self ._vvv (f"EXEC: { to_text (cmd )} " )
714
+
715
+ mark_begin = "" .join ([random .choice (string .ascii_letters ) for i in range (self .MARK_LENGTH )])
716
+ if self .is_windows :
717
+ mark_start = mark_begin + " $LASTEXITCODE"
718
+ else :
719
+ mark_start = mark_begin
720
+ mark_end = "" .join ([random .choice (string .ascii_letters ) for i in range (self .MARK_LENGTH )])
721
+
722
+ # Wrap command in markers accordingly for the shell used
723
+ cmd = self ._wrap_command (cmd , mark_start , mark_end )
724
+
725
+ self ._flush_stderr (self ._session )
726
+
727
+ for chunk in chunks (cmd , 1024 ):
728
+ self ._session .stdin .write (to_bytes (chunk , errors = "surrogate_or_strict" ))
729
+
730
+ return self .exec_communicate (cmd , mark_start , mark_begin , mark_end )
668
731
669
732
def _prepare_terminal (self ):
670
733
"""perform any one-time terminal settings"""
@@ -682,7 +745,7 @@ def _prepare_terminal(self):
682
745
disable_echo_cmd = to_bytes ("stty -echo\n " , errors = "surrogate_or_strict" )
683
746
684
747
disable_prompt_complete = None
685
- end_mark = "" .join ([random .choice (string .ascii_letters ) for i in xrange (self .MARK_LENGTH )])
748
+ end_mark = "" .join ([random .choice (string .ascii_letters ) for i in range (self .MARK_LENGTH )])
686
749
disable_prompt_cmd = to_bytes (
687
750
"PS1='' ; bind 'set enable-bracketed-paste off'; printf '\\ n%s\\ n' '" + end_mark + "'\n " ,
688
751
errors = "surrogate_or_strict" ,
@@ -691,18 +754,12 @@ def _prepare_terminal(self):
691
754
692
755
stdout = ""
693
756
# Custom command execution for when we're waiting for startup
694
- stop_time = int (round (time .time ())) + self .get_option ("ssm_timeout" )
695
- while (not disable_prompt_complete ) and (self ._session .poll () is None ):
696
- remaining = stop_time - int (round (time .time ()))
697
- if remaining < 1 :
698
- self ._timeout = True
699
- self ._vvvv (f"PRE timeout stdout: \n { to_bytes (stdout )} " )
700
- raise AnsibleConnectionFailure (f"SSM start_session timeout on host: { self .instance_id } " )
701
- if self ._poll_stdout .poll (1000 ):
757
+ for poll_result in self .poll ("PRE" , "start_session" ):
758
+ if disable_prompt_complete :
759
+ break
760
+ if poll_result :
702
761
stdout += to_text (self ._stdout .read (1024 ))
703
762
self ._vvvv (f"PRE stdout line: \n { to_bytes (stdout )} " )
704
- else :
705
- self ._vvvv (f"PRE remaining: { remaining } " )
706
763
707
764
# wait til prompt is ready
708
765
if startup_complete is False :
@@ -734,12 +791,13 @@ def _prepare_terminal(self):
734
791
stdout = stdout [match .end ():] # fmt: skip
735
792
disable_prompt_complete = True
736
793
737
- if not disable_prompt_complete :
794
+ # see https://github.com/pylint-dev/pylint/issues/8909)
795
+ if not disable_prompt_complete : # pylint: disable=unreachable
738
796
raise AnsibleConnectionFailure (f"SSM process closed during _prepare_terminal on host: { self .instance_id } " )
739
797
self ._vvvv ("PRE Terminal configured" )
740
798
741
- def _wrap_command (self , cmd , sudoable , mark_start , mark_end ) :
742
- """wrap command so stdout and status can be extracted"""
799
+ def _wrap_command (self , cmd : str , mark_start : str , mark_end : str ) -> str :
800
+ """Wrap command so stdout and status can be extracted"""
743
801
744
802
if self .is_windows :
745
803
if not cmd .startswith (" " .join (_common_args ) + " -EncodedCommand" ):
@@ -789,23 +847,6 @@ def _post_process(self, stdout, mark_begin):
789
847
790
848
return (returncode , stdout )
791
849
792
- def _filter_ansi (self , line ):
793
- """remove any ANSI terminal control codes"""
794
- line = to_text (line )
795
-
796
- if self .is_windows :
797
- osc_filter = re .compile (r"\x1b\][^\x07]*\x07" )
798
- line = osc_filter .sub ("" , line )
799
- ansi_filter = re .compile (r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]" )
800
- line = ansi_filter .sub ("" , line )
801
-
802
- # Replace or strip sequence (at terminal width)
803
- line = line .replace ("\r \r \n " , "\n " )
804
- if len (line ) == 201 :
805
- line = line [:- 1 ]
806
-
807
- return line
808
-
809
850
def _flush_stderr (self , session_process ):
810
851
"""read and return stderr with minimal blocking"""
811
852
@@ -995,7 +1036,7 @@ def close(self):
995
1036
"""terminate the connection"""
996
1037
if self ._session_id :
997
1038
self ._vvv (f"CLOSING SSM CONNECTION TO: { self .instance_id } " )
998
- if self ._timeout :
1039
+ if self ._has_timeout :
999
1040
self ._session .terminate ()
1000
1041
else :
1001
1042
cmd = b"\n exit\n "
0 commit comments