From d96b2a2799280a1dec885f6b55e53576d018e1b7 Mon Sep 17 00:00:00 2001
From: KOLANICH <kolan_n@mail.ru>
Date: Tue, 12 Jan 2021 16:52:23 +0300
Subject: [PATCH] Implemented SOCKS5 proxies and fixed h2.settings.ENABLE_PUSH
 to h2.settings.SettingCodes.ENABLE_PUSH. The code was taken from urllib3 and
 adapted to hyper.

---
 LICENSE                    |  1 +
 hyper/common/connection.py | 17 ++++++++--
 hyper/contrib.py           | 66 ++++++++++++++++++++++++++----------
 hyper/http11/connection.py | 66 ++++++++++++++++++++++++++----------
 hyper/http20/connection.py | 69 +++++++++++++++++++++++++++-----------
 5 files changed, 162 insertions(+), 57 deletions(-)

diff --git a/LICENSE b/LICENSE
index 7ef4aca0..a351a2a2 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,7 @@
 The MIT License (MIT)
 
 Copyright (c) 2014 Cory Benfield, Google Inc
+Copyright (c) 2008-2020 Andrey Petrov and urllib3 contributors (see https://github.com/urllib3/urllib3/blob/master/CONTRIBUTORS.txt) (socks5 code was borrowed from there)
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/hyper/common/connection.py b/hyper/common/connection.py
index 855994f8..81a81c04 100644
--- a/hyper/common/connection.py
+++ b/hyper/common/connection.py
@@ -46,6 +46,7 @@ class HTTPConnection(object):
     :param proxy_port: (optional) The proxy port to connect to. If not provided
         and one also isn't provided in the ``proxy_host`` parameter, defaults
         to 8080.
+    :param proxy_type: (optional) type of the proxy to use. Allows usage of socks proxies
     :param proxy_headers: (optional) The headers to send to a proxy.
     """
     def __init__(self,
@@ -57,6 +58,7 @@ def __init__(self,
                  ssl_context=None,
                  proxy_host=None,
                  proxy_port=None,
+                 proxy_type=None,
                  proxy_headers=None,
                  timeout=None,
                  **kwargs):
@@ -66,14 +68,14 @@ def __init__(self,
         self._h1_kwargs = {
             'secure': secure, 'ssl_context': ssl_context,
             'proxy_host': proxy_host, 'proxy_port': proxy_port,
-            'proxy_headers': proxy_headers, 'enable_push': enable_push,
+            'proxy_headers': proxy_headers, "proxy_type": proxy_type, 'enable_push': enable_push,
             'timeout': timeout
         }
         self._h2_kwargs = {
             'window_manager': window_manager, 'enable_push': enable_push,
             'secure': secure, 'ssl_context': ssl_context,
             'proxy_host': proxy_host, 'proxy_port': proxy_port,
-            'proxy_headers': proxy_headers,
+            'proxy_headers': proxy_headers, "proxy_type": proxy_type,
             'timeout': timeout
         }
 
@@ -150,6 +152,17 @@ def get_response(self, *args, **kwargs):
 
             return self._conn.get_response(1)
 
+    def reanimate(self):
+        """Reanimate connection reset because of proxy"""
+        if hasattr(self, "streams"):
+            for stream in list(self.streams.values()):
+                stream.remote_closed = True
+                stream.local_closed = True
+        self._conn.close()
+        self._conn = HTTP11Connection(
+            self._host, self._port, **self._h1_kwargs,
+        )
+
     # The following two methods are the implementation of the context manager
     # protocol.
     def __enter__(self):  # pragma: no cover
diff --git a/hyper/contrib.py b/hyper/contrib.py
index 79aa7d12..e1abf0b3 100644
--- a/hyper/contrib.py
+++ b/hyper/contrib.py
@@ -56,11 +56,24 @@ def get_connection(self, host, port, scheme, cert=None, verify=True,
             ssl_context = init_context(cert_path=verify, cert=cert)
 
         if proxy:
+            proxy = prepend_scheme_if_needed(proxy, 'http')
             proxy_headers = self.proxy_headers(proxy)
-            proxy_netloc = urlparse(proxy).netloc
+            parsed = urlparse(proxy)
+            proxy_netloc = parsed.netloc
+            proxy_host_port = proxy_netloc.split(":")
+            if len(proxy_host_port) == 2:
+                proxy_host, proxy_port = proxy_host_port
+                proxy_port = int(proxy_port)
+            elif len(proxy_port_host) == 1:
+                raise ValueError("Specify proxy port!")
+            else:
+                raise ValueError("Invalid proxy netloc: ", repr(proxy_netloc))
+            proxy_type = parsed.scheme
         else:
             proxy_headers = None
-            proxy_netloc = None
+            proxy_host = None
+            proxy_port = None
+            proxy_type = None
 
         # We put proxy headers in the connection_key, because
         # ``proxy_headers`` method might be overridden, so we can't
@@ -68,7 +81,7 @@ def get_connection(self, host, port, scheme, cert=None, verify=True,
         proxy_headers_key = (frozenset(proxy_headers.items())
                              if proxy_headers else None)
         connection_key = (host, port, scheme, cert, verify,
-                          proxy_netloc, proxy_headers_key)
+                          proxy_host, proxy_port, proxy_type, proxy_headers_key)
         try:
             conn = self.connections[connection_key]
         except KeyError:
@@ -78,9 +91,13 @@ def get_connection(self, host, port, scheme, cert=None, verify=True,
                 secure=secure,
                 window_manager=self.window_manager,
                 ssl_context=ssl_context,
-                proxy_host=proxy_netloc,
+                proxy_host=proxy_host,
+                proxy_port=proxy_port,
                 proxy_headers=proxy_headers,
-                timeout=timeout)
+                proxy_type=proxy_type,
+                timeout=timeout,
+                enable_push=self.enable_push
+            )
             self.connections[connection_key] = conn
 
         return conn
@@ -95,6 +112,12 @@ def send(self, request, stream=False, cert=None, verify=True, proxies=None,
             proxy = prepend_scheme_if_needed(proxy, 'http')
 
         parsed = urlparse(request.url)
+
+        # Build the selector.
+        selector = parsed.path
+        selector += '?' + parsed.query if parsed.query else ''
+        selector += '#' + parsed.fragment if parsed.fragment else ''
+
         conn = self.get_connection(
             parsed.hostname,
             parsed.port,
@@ -104,18 +127,27 @@ def send(self, request, stream=False, cert=None, verify=True, proxies=None,
             proxy=proxy,
             timeout=timeout)
 
-        # Build the selector.
-        selector = parsed.path
-        selector += '?' + parsed.query if parsed.query else ''
-        selector += '#' + parsed.fragment if parsed.fragment else ''
-
-        conn.request(
-            request.method,
-            selector,
-            request.body,
-            request.headers
-        )
-        resp = conn.get_response()
+        def do_req():
+            conn.request(
+                request.method,
+                selector,
+                request.body,
+                request.headers
+            )
+            resp = conn.get_response()
+            return conn, resp
+
+        retried = 0
+        max_retries = 1
+        while True:
+            try:
+                conn, resp = do_req()
+                break
+            except ConnectionAbortedError as e:
+                if retried < max_retries:
+                    conn.reanimate()
+                else:
+                    raise
 
         r = self.build_response(request, resp)
 
diff --git a/hyper/http11/connection.py b/hyper/http11/connection.py
index e9523745..cba86b33 100644
--- a/hyper/http11/connection.py
+++ b/hyper/http11/connection.py
@@ -101,7 +101,7 @@ class HTTP11Connection(object):
 
     def __init__(self, host, port=None, secure=None, ssl_context=None,
                  proxy_host=None, proxy_port=None, proxy_headers=None,
-                 timeout=None, **kwargs):
+                 proxy_type=None, timeout=None, **kwargs):
         if port is None:
             self.host, self.port = to_host_port_tuple(host, default_port=80)
         else:
@@ -133,11 +133,14 @@ def __init__(self, host, port=None, secure=None, ssl_context=None,
             self.proxy_host, self.proxy_port = to_host_port_tuple(
                 proxy_host, default_port=8080
             )
+            self.proxy_type = proxy_type
         elif proxy_host:
-            self.proxy_host, self.proxy_port = proxy_host, proxy_port
+            self.proxy_host, self.proxy_port, self.proxy_type = proxy_host, proxy_port, proxy_type
         else:
             self.proxy_host = None
             self.proxy_port = None
+            self.proxy_type = None
+            raise ValueError("No proxy was set!")
         self.proxy_headers = proxy_headers
 
         #: The size of the in-memory buffer used to store data from the
@@ -169,22 +172,48 @@ def connect(self):
                 connect_timeout = self._timeout
                 read_timeout = self._timeout
 
-            if self.proxy_host and self.secure:
-                # Send http CONNECT method to a proxy and acquire the socket
-                sock = _create_tunnel(
-                    self.proxy_host,
-                    self.proxy_port,
-                    self.host,
-                    self.port,
-                    proxy_headers=self.proxy_headers,
-                    timeout=self._timeout
-                )
-            elif self.proxy_host:
-                # Simple http proxy
-                sock = socket.create_connection(
-                    (self.proxy_host, self.proxy_port),
-                    timeout=connect_timeout
-                )
+            if self.proxy_host:
+                if self.proxy_type.startswith("socks"):
+                    import socks
+                    rdns = (self.proxy_type[-1]=="h")
+                    # any error will result in silently connecting without a proxy.
+                    # IDK why it is done this way
+                    if not rdns:
+                        raise ValueError("RDNS is disabled. Proxying dns queries is disabled. NSA is spying you.")
+                    if rdns and self.proxy_type.startswith("socks4"):
+                        raise ValueError("RDNS is not supported for socks4. socks.create_connection ignores it silently.")
+                    if not isinstance(self.proxy_host, str):
+                        raise ValueError("self.proxy_host", repr(self.proxy_host), "is not str")
+                    if not isinstance(self.proxy_port, int):
+                        raise ValueError("self.proxy_port", repr(self.proxy_port), "is not int")
+                    socks_version_char = self.proxy_type[5]
+                    sock = socks.create_connection(
+                        (self.host, self.port),
+                        proxy_type=getattr(socks, "PROXY_TYPE_SOCKS" + socks_version_char),
+                        proxy_addr=self.proxy_host,
+                        proxy_port=self.proxy_port,
+                        #proxy_username=username,
+                        #proxy_password=password,
+                        proxy_rdns=rdns,
+                    )
+                elif self.proxy_host and self.secure:
+                    # Send http CONNECT method to a proxy and acquire the socket
+                    sock = _create_tunnel(
+                        self.proxy_host,
+                        self.proxy_port,
+                        self.host,
+                        self.port,
+                        proxy_headers=self.proxy_headers,
+                        timeout=self._timeout
+                    )
+                elif self.proxy_host:
+                    # Simple http proxy
+                    sock = socket.create_connection(
+                        (self.proxy_host, self.proxy_port),
+                        timeout=connect_timeout
+                    )
+                else:
+                    raise Exception("Unsupported proxy type: "+repr(proxy_type))
             else:
                 sock = socket.create_connection((self.host, self.port),
                                                 timeout=connect_timeout)
@@ -454,6 +483,7 @@ def _send_file_like_obj(self, fobj):
 
         return
 
+
     def close(self):
         """
         Closes the connection. This closes the socket and then abandons the
diff --git a/hyper/http20/connection.py b/hyper/http20/connection.py
index b8be292b..0d2981fa 100644
--- a/hyper/http20/connection.py
+++ b/hyper/http20/connection.py
@@ -101,7 +101,7 @@ class HTTP20Connection(object):
 
     def __init__(self, host, port=None, secure=None, window_manager=None,
                  enable_push=False, ssl_context=None, proxy_host=None,
-                 proxy_port=None, force_proto=None, proxy_headers=None,
+                 proxy_port=None, proxy_type=None, force_proto=None, proxy_headers=None,
                  timeout=None, **kwargs):
         """
         Creates an HTTP/2 connection to a specific server.
@@ -126,11 +126,13 @@ def __init__(self, host, port=None, secure=None, window_manager=None,
             self.proxy_host, self.proxy_port = to_host_port_tuple(
                 proxy_host, default_port=8080
             )
+            self.proxy_type = proxy_type
         elif proxy_host:
-            self.proxy_host, self.proxy_port = proxy_host, proxy_port
+            self.proxy_host, self.proxy_port, self.proxy_type = proxy_host, proxy_port, proxy_type
         else:
             self.proxy_host = None
             self.proxy_port = None
+            self.proxy_type = None
         self.proxy_headers = proxy_headers
 
         #: The size of the in-memory buffer used to store data from the
@@ -353,22 +355,49 @@ def connect(self):
                 connect_timeout = self._timeout
                 read_timeout = self._timeout
 
-            if self.proxy_host and self.secure:
-                # Send http CONNECT method to a proxy and acquire the socket
-                sock = _create_tunnel(
-                    self.proxy_host,
-                    self.proxy_port,
-                    self.host,
-                    self.port,
-                    proxy_headers=self.proxy_headers,
-                    timeout=self._timeout
-                )
-            elif self.proxy_host:
-                # Simple http proxy
-                sock = socket.create_connection(
-                    (self.proxy_host, self.proxy_port),
-                    timeout=connect_timeout
-                )
+            if self.proxy_host:
+                if self.proxy_type.startswith("socks"):
+                    import socks
+                    rdns = (self.proxy_type[-1]=="h")
+                    # any error will result in silently connecting without a proxy.
+                    # IDK why it is done this way.
+                    if not rdns:
+                        raise ValueError("RDNS is disabled. Proxying dns queries is disabled. NSA is spying you.")
+                    if rdns and self.proxy_type.startswith("socks4"):
+                        raise ValueError("RDNS is not supported for socks4. socks.create_connection ignores it silently.")
+                    if not isinstance(self.proxy_host, str):
+                        raise ValueError("self.proxy_host", repr(self.proxy_host), "is not str")
+                    if not isinstance(self.proxy_port, int):
+                        raise ValueError("self.proxy_port", repr(self.proxy_port), "is not int")
+                    socks_version_char = self.proxy_type[5]
+                    sock = socks.create_connection(
+                        (self.host, self.port),
+                        proxy_type=getattr(socks, "PROXY_TYPE_SOCKS" + socks_version_char),
+                        proxy_addr=self.proxy_host,
+                        proxy_port=self.proxy_port,
+                        #proxy_username=username,
+                        #proxy_password=password,
+                        proxy_rdns=rdns,
+                    )
+                    #sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+                elif self.proxy_host and self.secure:
+                    # Send http CONNECT method to a proxy and acquire the socket
+                    sock = _create_tunnel(
+                        self.proxy_host,
+                        self.proxy_port,
+                        self.host,
+                        self.port,
+                        proxy_headers=self.proxy_headers,
+                        timeout=self._timeout
+                    )
+                elif self.proxy_host:
+                    # Simple http proxy
+                    sock = socket.create_connection(
+                        (self.proxy_host, self.proxy_port),
+                        timeout=connect_timeout
+                    )
+                else:
+                    raise Exception("Unsupported proxy type: "+repr(proxy_type))
             else:
                 sock = socket.create_connection((self.host, self.port),
                                                 timeout=connect_timeout)
@@ -403,7 +432,7 @@ def _connect_upgrade(self, sock):
         with self._conn as conn:
             conn.initiate_upgrade_connection()
             conn.update_settings(
-                {h2.settings.ENABLE_PUSH: int(self._enable_push)}
+                {h2.settings.SettingCodes.ENABLE_PUSH: int(self._enable_push)}
             )
         self._send_outstanding_data()
 
@@ -424,7 +453,7 @@ def _send_preamble(self):
         with self._conn as conn:
             conn.initiate_connection()
             conn.update_settings(
-                {h2.settings.ENABLE_PUSH: int(self._enable_push)}
+                {h2.settings.SettingCodes.ENABLE_PUSH: int(self._enable_push)}
             )
         self._send_outstanding_data()