diff --git a/setup.py b/setup.py index 8b5ebea5..9be5ef82 100755 --- a/setup.py +++ b/setup.py @@ -15,6 +15,17 @@ setup( version=version, cmdclass={ - "build_py": unasync.cmdclass_build_py(rules=[unasync.Rule("/ahip/", "/hip/")]) + "build_py": unasync.cmdclass_build_py( + rules=[ + unasync.Rule( + "/ahip/", + "/hip/", + additional_replacements={ + "anext": "next", + "await_if_coro": "return_non_coro", + }, + ) + ] + ) }, ) diff --git a/src/ahip/_backends/anyio_backend.py b/src/ahip/_backends/anyio_backend.py index 6daf254d..988e3e38 100644 --- a/src/ahip/_backends/anyio_backend.py +++ b/src/ahip/_backends/anyio_backend.py @@ -44,7 +44,7 @@ async def start_tls(self, server_hostname, ssl_context: SSLContext): def getpeercert(self, binary_form=False): return self._stream.getpeercert(binary_form=binary_form) - async def receive_some(self): + async def receive_some(self, read_timeout): return await self._stream.receive_some(BUFSIZE) async def send_and_receive_for_a_while( diff --git a/src/ahip/connection.py b/src/ahip/connection.py index 7f5d6436..57bbcb8f 100644 --- a/src/ahip/connection.py +++ b/src/ahip/connection.py @@ -16,7 +16,6 @@ import collections import datetime -import itertools import socket import warnings @@ -35,6 +34,7 @@ ) from .packages import six from .util import ssl_ as ssl_util +from .util.unasync import await_if_coro, anext, ASYNC_MODE from ._backends._common import LoopAbort from ._backends._loader import load_backend, normalize_backend @@ -44,24 +44,6 @@ ssl = None -def is_async_mode(): - """Tests if we're in the async part of the code or not""" - - async def f(): - """Unasync transforms async functions in sync functions""" - return None - - obj = f() - if obj is None: - return False - else: - obj.close() # prevent unawaited coroutine warning - return True - - -_ASYNC_MODE = is_async_mode() - - # When it comes time to update this value as a part of regular maintenance # (ie test_recent_date is failing) update it to ~6 months before the current date. RECENT_DATE = datetime.date(2019, 1, 1) @@ -106,17 +88,16 @@ def _stringify_headers(headers): yield (name, value) -def _read_readable(readable): +async def _read_readable(readable): # TODO: reconsider this block size blocksize = 8192 while True: - datablock = readable.read(blocksize) + datablock = await await_if_coro(readable.read(blocksize)) if not datablock: break yield datablock -# XX this should return an async iterator def _make_body_iterable(body): """ This function turns all possible body types that Hip supports into an @@ -134,63 +115,83 @@ def _make_body_iterable(body): is deliberate: users must make choices about the encoding of the data they use. """ - if body is None: - return [] - elif isinstance(body, bytes): - return [body] - elif hasattr(body, "read"): - return _read_readable(body) - elif isinstance(body, collections.Iterable) and not isinstance(body, six.text_type): - return body - else: - raise InvalidBodyError("Unacceptable body type: %s" % type(body)) + + async def generator(): + if body is None: + return + elif isinstance(body, bytes): + yield body + elif hasattr(body, "read"): + async for chunk in _read_readable(body): + yield chunk + elif isinstance(body, collections.Iterable) and not isinstance( + body, six.text_type + ): + for chunk in body: + yield chunk + else: + raise InvalidBodyError("Unacceptable body type: %s" % type(body)) + + return generator().__aiter__() -# XX this should return an async iterator def _request_bytes_iterable(request, state_machine): """ An iterable that serialises a set of bytes for the body. """ def all_pieces_iter(): - h11_request = h11.Request( - method=request.method, - target=request.target, - headers=_stringify_headers(request.headers.items()), + async def generator(): + h11_request = h11.Request( + method=request.method, + target=request.target, + headers=_stringify_headers(request.headers.items()), + ) + yield state_machine.send(h11_request) + + async for chunk in _make_body_iterable(request.body): + yield state_machine.send(h11.Data(data=chunk)) + + yield state_machine.send(h11.EndOfMessage()) + + return generator().__aiter__() + + async def generator(): + + # Try to combine the header bytes + (first set of body bytes or end of + # message bytes) into one packet. + # As long as all_pieces_iter() yields at least two messages, this should + # never raise StopIteration. + remaining_pieces = all_pieces_iter() + first_packet_bytes = (await anext(remaining_pieces)) + ( + await anext(remaining_pieces) ) - yield state_machine.send(h11_request) - - for chunk in _make_body_iterable(request.body): - yield state_machine.send(h11.Data(data=chunk)) - - yield state_machine.send(h11.EndOfMessage()) - - # Try to combine the header bytes + (first set of body bytes or end of - # message bytes) into one packet. - # As long as all_pieces_iter() yields at least two messages, this should - # never raise StopIteration. - remaining_pieces = all_pieces_iter() - first_packet_bytes = next(remaining_pieces) + next(remaining_pieces) - all_pieces_combined_iter = itertools.chain([first_packet_bytes], remaining_pieces) - - # We filter out any empty strings, because we don't want to call - # send(b""). You might think this is a no-op, so it shouldn't matter - # either way. But this isn't true. For example, if we're sending a request - # with Content-Length framing, we could have this sequence: - # - # - We send the last Data event. - # - The peer immediately sends its response and closes the socket. - # - We attempt to send the EndOfMessage event, which (b/c this request has - # Content-Length framing) is encoded as b"". - # - We call send(b""). - # - This triggers the kernel / SSL layer to discover that the socket is - # closed, so it raises an exception. - # - # It's easier to fix this once here instead of worrying about it in all - # the different backends. - for piece in all_pieces_combined_iter: - if piece: - yield piece + + async def all_pieces_combined_iter(): + yield first_packet_bytes + async for piece in remaining_pieces: + yield piece + + # We filter out any empty strings, because we don't want to call + # send(b""). You might think this is a no-op, so it shouldn't matter + # either way. But this isn't true. For example, if we're sending a request + # with Content-Length framing, we could have this sequence: + # + # - We send the last Data event. + # - The peer immediately sends its response and closes the socket. + # - We attempt to send the EndOfMessage event, which (b/c this request has + # Content-Length framing) is encoded as b"". + # - We call send(b""). + # - This triggers the kernel / SSL layer to discover that the socket is + # closed, so it raises an exception. + # + # It's easier to fix this once here instead of worrying about it in all + # the different backends. + async for piece in all_pieces_combined_iter(): + if piece: + yield piece + + return generator().__aiter__() def _response_from_h11(h11_response, body_object): @@ -259,8 +260,8 @@ async def _start_http_request(request, state_machine, sock, read_timeout=None): async def produce_bytes(): try: - return next(request_bytes_iterable) - except StopIteration: + return await anext(request_bytes_iterable) + except StopAsyncIteration: # We successfully sent the whole body! context["send_aborted"] = False return None @@ -346,7 +347,7 @@ def __init__( ): self.is_verified = False self.read_timeout = None - self._backend = load_backend(normalize_backend(backend, _ASYNC_MODE)) + self._backend = load_backend(normalize_backend(backend, ASYNC_MODE)) self._host = host self._port = port self._socket_options = ( diff --git a/src/ahip/connectionpool.py b/src/ahip/connectionpool.py index a1da0844..40adb1df 100644 --- a/src/ahip/connectionpool.py +++ b/src/ahip/connectionpool.py @@ -620,7 +620,7 @@ async def urlopen( # Rewind body position, if needed. Record current position # for future rewinds in the event of a redirect/retry. - body_pos = set_file_position(body, body_pos) + body_pos = await set_file_position(body, body_pos) if body is not None: _add_transport_headers(headers) diff --git a/src/ahip/poolmanager.py b/src/ahip/poolmanager.py index 609248bf..a37fc9d6 100644 --- a/src/ahip/poolmanager.py +++ b/src/ahip/poolmanager.py @@ -323,7 +323,7 @@ async def urlopen(self, method, url, redirect=True, **kw): # for future rewinds in the event of a redirect/retry. body = kw.get("body") body_pos = kw.get("body_pos") - kw["body_pos"] = set_file_position(body, body_pos) + kw["body_pos"] = await set_file_position(body, body_pos) if "headers" not in kw: kw["headers"] = self.headers.copy() diff --git a/src/ahip/util/request.py b/src/ahip/util/request.py index fbfe8f92..c7911260 100644 --- a/src/ahip/util/request.py +++ b/src/ahip/util/request.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from base64 import b64encode +from .unasync import await_if_coro from ..packages.six import b, integer_types from ..exceptions import UnrewindableBodyError @@ -87,16 +88,16 @@ def make_headers( return headers -def set_file_position(body, pos): +async def set_file_position(body, pos): """ If a position is provided, move file to that point. Otherwise, we'll attempt to record a position for future use. """ if pos is not None: - rewind_body(body, pos) + await rewind_body(body, pos) elif getattr(body, "tell", None) is not None: try: - pos = body.tell() + pos = await await_if_coro(body.tell()) except (IOError, OSError): # This differentiates from None, allowing us to catch # a failed `tell()` later when trying to rewind the body. @@ -105,7 +106,7 @@ def set_file_position(body, pos): return pos -def rewind_body(body, body_pos): +async def rewind_body(body, body_pos): """ Attempt to rewind body to a certain position. Primarily used for request redirects and retries. @@ -119,7 +120,7 @@ def rewind_body(body, body_pos): body_seek = getattr(body, "seek", None) if body_seek is not None and isinstance(body_pos, integer_types): try: - body_seek(body_pos) + await await_if_coro(body_seek(body_pos)) except (IOError, OSError): raise UnrewindableBodyError( "An error occurred when rewinding request body for redirect/retry." diff --git a/src/ahip/util/unasync.py b/src/ahip/util/unasync.py new file mode 100644 index 00000000..79ea98ba --- /dev/null +++ b/src/ahip/util/unasync.py @@ -0,0 +1,40 @@ +"""Set of utility functions for unasync that transform into sync counterparts cleanly""" + +import inspect + +_original_next = next + + +def is_async_mode(): + """Tests if we're in the async part of the code or not""" + + async def f(): + """Unasync transforms async functions in sync functions""" + return None + + obj = f() + if obj is None: + return False + else: + obj.close() # prevent unawaited coroutine warning + return True + + +ASYNC_MODE = is_async_mode() + + +async def anext(x): + return await x.__anext__() + + +async def await_if_coro(x): + if inspect.iscoroutine(x): + return await x + return x + + +next = _original_next + + +def return_non_coro(x): + return x diff --git a/test/conftest.py b/test/conftest.py index 3d77bee9..01be389c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -14,7 +14,7 @@ # We support Python 3.6+ for async code if sys.version_info[:2] < (3, 6): - collect_ignore_glob = ["async/*.py", "with_dummyserver/async/*.py"] + collect_ignore_glob = ["async/*.py", "with_dummyserver/async*/*.py"] # The Python 3.8+ default loop on Windows breaks Tornado @pytest.fixture(scope="session", autouse=True) diff --git a/test/with_dummyserver/async/test_poolmanager.py b/test/with_dummyserver/async/test_poolmanager.py index af06e5d2..b8a5ed31 100644 --- a/test/with_dummyserver/async/test_poolmanager.py +++ b/test/with_dummyserver/async/test_poolmanager.py @@ -1,4 +1,7 @@ -from ahip import PoolManager +import io +import pytest +from ahip import PoolManager, Retry +from ahip.exceptions import UnrewindableBodyError from test.with_dummyserver import conftest @@ -22,3 +25,97 @@ async def test_redirect(self, http_server, backend, anyio_backend): assert r.status == 200 assert r.data == b"Dummy server!" + + +class TestFileUploads: + @conftest.test_all_backends + async def test_redirect_put_file(self, http_server, backend, anyio_backend): + """PUT with file object should work with a redirection response""" + base_url = "http://{}:{}".format(http_server.host, http_server.port) + retry = Retry(total=3, status_forcelist=[418]) + # httplib reads in 8k chunks; use a larger content length + content_length = 65535 + data = b"A" * content_length + uploaded_file = io.BytesIO(data) + headers = { + "test-name": "test_redirect_put_file", + "Content-Length": str(content_length), + } + url = "%s/redirect?target=/echo&status=307" % base_url + + with PoolManager(backend=backend) as http: + resp = await http.urlopen( + "PUT", url, headers=headers, retries=retry, body=uploaded_file + ) + assert resp.status == 200 + assert resp.data == data + + @conftest.test_all_backends + async def test_retries_put_filehandle(self, http_server, backend, anyio_backend): + """HTTP PUT retry with a file-like object should not timeout""" + base_url = "http://{}:{}".format(http_server.host, http_server.port) + retry = Retry(total=3, status_forcelist=[418]) + # httplib reads in 8k chunks; use a larger content length + content_length = 65535 + data = b"A" * content_length + uploaded_file = io.BytesIO(data) + headers = { + "test-name": "test_retries_put_filehandle", + "Content-Length": str(content_length), + } + + with PoolManager(backend=backend) as http: + resp = await http.urlopen( + "PUT", + "%s/successful_retry" % base_url, + headers=headers, + retries=retry, + body=uploaded_file, + redirect=False, + ) + assert resp.status == 200 + + @conftest.test_all_backends + async def test_redirect_with_failed_tell(self, http_server, backend, anyio_backend): + """Abort request if failed to get a position from tell()""" + + base_url = "http://{}:{}".format(http_server.host, http_server.port) + + class BadTellObject(io.BytesIO): + def tell(self): + raise IOError + + body = BadTellObject(b"the data") + url = "%s/redirect?target=/successful_retry" % base_url + # httplib uses fileno if Content-Length isn't supplied, + # which is unsupported by BytesIO. + headers = {"Content-Length": "8"} + + with PoolManager() as http: + with pytest.raises(UnrewindableBodyError) as e: + await http.urlopen("PUT", url, headers=headers, body=body) + assert "Unable to record file position for" in str(e.value) + + @conftest.test_all_backends + async def test_redirect_with_failed_seek(self, http_server, backend, anyio_backend): + """Abort request if failed to restore position with seek()""" + + base_url = "http://{}:{}".format(http_server.host, http_server.port) + + class BadSeekObject(io.BytesIO): + def seek(self, *_): + raise IOError + + body = BadSeekObject(b"the data") + url = "%s/redirect?target=/successful_retry" % base_url + # httplib uses fileno if Content-Length isn't supplied, + # which is unsupported by BytesIO. + headers = {"Content-Length": "8"} + + with PoolManager() as http: + with pytest.raises(UnrewindableBodyError) as e: + await http.urlopen("PUT", url, headers=headers, body=body) + assert ( + "An error occurred when rewinding request body for redirect/retry." + == str(e.value) + ) diff --git a/test/with_dummyserver/async_only/__init__.py b/test/with_dummyserver/async_only/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/with_dummyserver/async_only/test_poolmanager.py b/test/with_dummyserver/async_only/test_poolmanager.py new file mode 100644 index 00000000..4c23ffc8 --- /dev/null +++ b/test/with_dummyserver/async_only/test_poolmanager.py @@ -0,0 +1,100 @@ +import io +import pytest +import trio +import anyio +from ahip import PoolManager +from ahip.exceptions import UnrewindableBodyError + +from test.with_dummyserver import conftest + + +class TestFileUploads: + @conftest.test_all_backends + async def test_upload_anyio_async_files(self, http_server, backend, anyio_backend): + """Uploading a file opened via 'anyio.aopen()' should be possible""" + base_url = "http://{}:{}".format(http_server.host, http_server.port) + + with open(__file__, mode="rb") as f: + data = f.read() + content_length = len(data) + + headers = { + "Content-Length": str(content_length), + } + url = "%s/echo" % base_url + + with PoolManager(backend=backend) as http: + async with await anyio.aopen(__file__, mode="rb") as f: + resp = await http.urlopen("PUT", url, headers=headers, body=f) + assert resp.status == 200 + assert resp.data == data + + @pytest.mark.trio + async def test_upload_trio_wrapped_files(self, http_server): + """Uploading a file wrapped via 'trio.wrap_file()' should be possible""" + base_url = "http://{}:{}".format(http_server.host, http_server.port) + + with open(__file__, mode="rb") as f: + data = f.read() + content_length = len(data) + + headers = { + "Content-Length": str(content_length), + } + url = "%s/echo" % base_url + + with PoolManager(backend="trio") as http: + with open(__file__, mode="rb") as f: + f = trio.wrap_file(f) + resp = await http.urlopen("PUT", url, headers=headers, body=f) + assert resp.status == 200 + assert resp.data == data + + @conftest.test_all_backends + async def test_redirect_with_failed_async_tell( + self, http_server, backend, anyio_backend + ): + """Abort request if failed to get a position from async tell()""" + + base_url = "http://{}:{}".format(http_server.host, http_server.port) + + class BadTellObject(io.BytesIO): + async def tell(self): + raise IOError + + body = BadTellObject(b"the data") + url = "%s/redirect?target=/successful_retry" % base_url + # httplib uses fileno if Content-Length isn't supplied, + # which is unsupported by BytesIO. + headers = {"Content-Length": "8"} + + with PoolManager() as http: + with pytest.raises(UnrewindableBodyError) as e: + await http.urlopen("PUT", url, headers=headers, body=body) + assert "Unable to record file position for" in str(e.value) + + @conftest.test_all_backends + async def test_redirect_with_failed_async_seek( + self, http_server, backend, anyio_backend + ): + """Abort request if failed to restore position with async seek()""" + + base_url = "http://{}:{}".format(http_server.host, http_server.port) + + class BadSeekObject(io.BytesIO): + async def seek(self, *_): + raise IOError + + body = BadSeekObject(b"the data") + url = "%s/redirect?target=/successful_retry" % base_url + # httplib uses fileno if Content-Length isn't supplied, + # which is unsupported by BytesIO. + headers = {"Content-Length": "8"} + + with PoolManager() as http: + with pytest.raises(UnrewindableBodyError) as e: + await http.urlopen("PUT", url, headers=headers, body=body) + assert ( + "An error occurred when rewinding request body for redirect/retry." + == str(e.value) + ) diff --git a/test/with_dummyserver/test_poolmanager.py b/test/with_dummyserver/test_poolmanager.py index f7187f5b..98a621ac 100644 --- a/test/with_dummyserver/test_poolmanager.py +++ b/test/with_dummyserver/test_poolmanager.py @@ -1,4 +1,3 @@ -import io import json import time @@ -8,7 +7,7 @@ from dummyserver.testcase import HTTPDummyServerTestCase, IPv6HTTPDummyServerTestCase from hip.base import DEFAULT_PORTS from hip.poolmanager import PoolManager -from hip.exceptions import MaxRetryError, NewConnectionError, UnrewindableBodyError +from hip.exceptions import MaxRetryError, NewConnectionError from hip.util.retry import Retry, RequestHistory from test import LONG_TIMEOUT @@ -255,6 +254,25 @@ def test_raise_on_status(self): assert r.status == 500 + @pytest.mark.parametrize( + ["target", "expected_target"], + [ + ("/echo_uri?q=1#fragment", b"/echo_uri?q=1"), + ("/echo_uri?#", b"/echo_uri?"), + ("/echo_uri#?", b"/echo_uri"), + ("/echo_uri#?#", b"/echo_uri"), + ("/echo_uri??#", b"/echo_uri??"), + ("/echo_uri?%3f#", b"/echo_uri?%3F"), + ("/echo_uri?%3F#", b"/echo_uri?%3F"), + ("/echo_uri?[]", b"/echo_uri?%5B%5D"), + ], + ) + def test_encode_http_target(self, target, expected_target): + with PoolManager() as http: + url = "http://%s:%d%s" % (self.host, self.port, target) + r = http.request("GET", url) + assert r.data == expected_target + def test_missing_port(self): # Can a URL that lacks an explicit port like ':80' succeed, or # will all such URLs fail with an error? @@ -542,26 +560,6 @@ def test_multi_redirect_history(self): ] assert actual == expected - def test_redirect_put_file(self): - """PUT with file object should work with a redirection response""" - retry = Retry(total=3, status_forcelist=[418]) - # httplib reads in 8k chunks; use a larger content length - content_length = 65535 - data = b"A" * content_length - uploaded_file = io.BytesIO(data) - headers = { - "test-name": "test_redirect_put_file", - "Content-Length": str(content_length), - } - url = "%s/redirect?target=/echo&status=307" % self.base_url - - with PoolManager() as http: - resp = http.urlopen( - "PUT", url, headers=headers, retries=retry, body=uploaded_file - ) - assert resp.status == 200 - assert resp.data == data - class TestRetryAfter(HTTPDummyServerTestCase): @classmethod @@ -639,73 +637,6 @@ def test_redirect_after(self): assert delta < 1 -class TestFileBodiesOnRetryOrRedirect(HTTPDummyServerTestCase): - def setup_class(self): - super(TestFileBodiesOnRetryOrRedirect, self).setup_class() - self.base_url = "http://%s:%d" % (self.host, self.port) - self.base_url_alt = "http://%s:%d" % (self.host_alt, self.port) - - def test_retries_put_filehandle(self): - """HTTP PUT retry with a file-like object should not timeout""" - retry = Retry(total=3, status_forcelist=[418]) - # httplib reads in 8k chunks; use a larger content length - content_length = 65535 - data = b"A" * content_length - uploaded_file = io.BytesIO(data) - headers = { - "test-name": "test_retries_put_filehandle", - "Content-Length": str(content_length), - } - - with PoolManager() as http: - resp = http.urlopen( - "PUT", - "%s/successful_retry" % self.base_url, - headers=headers, - retries=retry, - body=uploaded_file, - redirect=False, - ) - assert resp.status == 200 - - def test_redirect_with_failed_tell(self): - """Abort request if failed to get a position from tell()""" - - class BadTellObject(io.BytesIO): - def tell(self): - raise IOError - - body = BadTellObject(b"the data") - url = "%s/redirect?target=/successful_retry" % self.base_url - # httplib uses fileno if Content-Length isn't supplied, - # which is unsupported by BytesIO. - headers = {"Content-Length": "8"} - - with PoolManager() as http: - with pytest.raises(UnrewindableBodyError) as e: - http.urlopen("PUT", url, headers=headers, body=body) - assert "Unable to record file position for" in str(e.value) - - @pytest.mark.parametrize( - ["target", "expected_target"], - [ - ("/echo_uri?q=1#fragment", b"/echo_uri?q=1"), - ("/echo_uri?#", b"/echo_uri?"), - ("/echo_uri#?", b"/echo_uri"), - ("/echo_uri#?#", b"/echo_uri"), - ("/echo_uri??#", b"/echo_uri??"), - ("/echo_uri?%3f#", b"/echo_uri?%3F"), - ("/echo_uri?%3F#", b"/echo_uri?%3F"), - ("/echo_uri?[]", b"/echo_uri?%5B%5D"), - ], - ) - def test_encode_http_target(self, target, expected_target): - with PoolManager() as http: - url = "http://%s:%d%s" % (self.host, self.port, target) - r = http.request("GET", url) - assert r.data == expected_target - - @pytest.mark.skipif(not HAS_IPV6, reason="IPv6 is not supported on this system") class TestIPv6PoolManager(IPv6HTTPDummyServerTestCase): @classmethod