Skip to content
This repository was archived by the owner on Apr 14, 2022. It is now read-only.

Add support for async files as request body #217

Merged
merged 1 commit into from
May 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
setup(
version=version,
cmdclass={
"build_py": unasync.cmdclass_build_py(rules=[unasync.Rule("/ahip/", "/hip/")])
"build_py": unasync.cmdclass_build_py(
rules=[
unasync.Rule(
"/ahip/",
"/hip/",
additional_replacements={
"anext": "next",
"await_if_coro": "return_non_coro",
},
)
]
)
},
)
2 changes: 1 addition & 1 deletion src/ahip/_backends/anyio_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def start_tls(self, server_hostname, ssl_context: SSLContext):
def getpeercert(self, binary_form=False):
return self._stream.getpeercert(binary_form=binary_form)

async def receive_some(self):
async def receive_some(self, read_timeout):
return await self._stream.receive_some(BUFSIZE)

async def send_and_receive_for_a_while(
Expand Down
147 changes: 74 additions & 73 deletions src/ahip/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import collections
import datetime
import itertools
import socket
import warnings

Expand All @@ -35,6 +34,7 @@
)
from .packages import six
from .util import ssl_ as ssl_util
from .util.unasync import await_if_coro, anext, ASYNC_MODE
from ._backends._common import LoopAbort
from ._backends._loader import load_backend, normalize_backend

Expand All @@ -44,24 +44,6 @@
ssl = None


def is_async_mode():
"""Tests if we're in the async part of the code or not"""

async def f():
"""Unasync transforms async functions in sync functions"""
return None

obj = f()
if obj is None:
return False
else:
obj.close() # prevent unawaited coroutine warning
return True


_ASYNC_MODE = is_async_mode()


# When it comes time to update this value as a part of regular maintenance
# (ie test_recent_date is failing) update it to ~6 months before the current date.
RECENT_DATE = datetime.date(2019, 1, 1)
Expand Down Expand Up @@ -106,17 +88,16 @@ def _stringify_headers(headers):
yield (name, value)


def _read_readable(readable):
async def _read_readable(readable):
# TODO: reconsider this block size
blocksize = 8192
while True:
datablock = readable.read(blocksize)
datablock = await await_if_coro(readable.read(blocksize))
if not datablock:
break
yield datablock


# XX this should return an async iterator
def _make_body_iterable(body):
"""
This function turns all possible body types that Hip supports into an
Expand All @@ -134,63 +115,83 @@ def _make_body_iterable(body):
is deliberate: users must make choices about the encoding of the data they
use.
"""
if body is None:
return []
elif isinstance(body, bytes):
return [body]
elif hasattr(body, "read"):
return _read_readable(body)
elif isinstance(body, collections.Iterable) and not isinstance(body, six.text_type):
return body
else:
raise InvalidBodyError("Unacceptable body type: %s" % type(body))

async def generator():
if body is None:
return
elif isinstance(body, bytes):
yield body
elif hasattr(body, "read"):
async for chunk in _read_readable(body):
yield chunk
elif isinstance(body, collections.Iterable) and not isinstance(
body, six.text_type
):
for chunk in body:
yield chunk
else:
raise InvalidBodyError("Unacceptable body type: %s" % type(body))

return generator().__aiter__()


# XX this should return an async iterator
def _request_bytes_iterable(request, state_machine):
"""
An iterable that serialises a set of bytes for the body.
"""

def all_pieces_iter():
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not super happy with this function, there's probably a simpler way to do this rather than a direct translation.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly. Maybe @njsmith know how to improve this? In any case, this isn't blocking, so I'll merge this pull request.

h11_request = h11.Request(
method=request.method,
target=request.target,
headers=_stringify_headers(request.headers.items()),
async def generator():
h11_request = h11.Request(
method=request.method,
target=request.target,
headers=_stringify_headers(request.headers.items()),
)
yield state_machine.send(h11_request)

async for chunk in _make_body_iterable(request.body):
yield state_machine.send(h11.Data(data=chunk))

yield state_machine.send(h11.EndOfMessage())

return generator().__aiter__()

async def generator():

# Try to combine the header bytes + (first set of body bytes or end of
# message bytes) into one packet.
# As long as all_pieces_iter() yields at least two messages, this should
# never raise StopIteration.
remaining_pieces = all_pieces_iter()
first_packet_bytes = (await anext(remaining_pieces)) + (
await anext(remaining_pieces)
)
yield state_machine.send(h11_request)

for chunk in _make_body_iterable(request.body):
yield state_machine.send(h11.Data(data=chunk))

yield state_machine.send(h11.EndOfMessage())

# Try to combine the header bytes + (first set of body bytes or end of
# message bytes) into one packet.
# As long as all_pieces_iter() yields at least two messages, this should
# never raise StopIteration.
remaining_pieces = all_pieces_iter()
first_packet_bytes = next(remaining_pieces) + next(remaining_pieces)
all_pieces_combined_iter = itertools.chain([first_packet_bytes], remaining_pieces)

# We filter out any empty strings, because we don't want to call
# send(b""). You might think this is a no-op, so it shouldn't matter
# either way. But this isn't true. For example, if we're sending a request
# with Content-Length framing, we could have this sequence:
#
# - We send the last Data event.
# - The peer immediately sends its response and closes the socket.
# - We attempt to send the EndOfMessage event, which (b/c this request has
# Content-Length framing) is encoded as b"".
# - We call send(b"").
# - This triggers the kernel / SSL layer to discover that the socket is
# closed, so it raises an exception.
#
# It's easier to fix this once here instead of worrying about it in all
# the different backends.
for piece in all_pieces_combined_iter:
if piece:
yield piece

async def all_pieces_combined_iter():
yield first_packet_bytes
async for piece in remaining_pieces:
yield piece

# We filter out any empty strings, because we don't want to call
# send(b""). You might think this is a no-op, so it shouldn't matter
# either way. But this isn't true. For example, if we're sending a request
# with Content-Length framing, we could have this sequence:
#
# - We send the last Data event.
# - The peer immediately sends its response and closes the socket.
# - We attempt to send the EndOfMessage event, which (b/c this request has
# Content-Length framing) is encoded as b"".
# - We call send(b"").
# - This triggers the kernel / SSL layer to discover that the socket is
# closed, so it raises an exception.
#
# It's easier to fix this once here instead of worrying about it in all
# the different backends.
async for piece in all_pieces_combined_iter():
if piece:
yield piece

return generator().__aiter__()


def _response_from_h11(h11_response, body_object):
Expand Down Expand Up @@ -259,8 +260,8 @@ async def _start_http_request(request, state_machine, sock, read_timeout=None):

async def produce_bytes():
try:
return next(request_bytes_iterable)
except StopIteration:
return await anext(request_bytes_iterable)
except StopAsyncIteration:
# We successfully sent the whole body!
context["send_aborted"] = False
return None
Expand Down Expand Up @@ -346,7 +347,7 @@ def __init__(
):
self.is_verified = False
self.read_timeout = None
self._backend = load_backend(normalize_backend(backend, _ASYNC_MODE))
self._backend = load_backend(normalize_backend(backend, ASYNC_MODE))
self._host = host
self._port = port
self._socket_options = (
Expand Down
2 changes: 1 addition & 1 deletion src/ahip/connectionpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ async def urlopen(

# Rewind body position, if needed. Record current position
# for future rewinds in the event of a redirect/retry.
body_pos = set_file_position(body, body_pos)
body_pos = await set_file_position(body, body_pos)

if body is not None:
_add_transport_headers(headers)
Expand Down
2 changes: 1 addition & 1 deletion src/ahip/poolmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ async def urlopen(self, method, url, redirect=True, **kw):
# for future rewinds in the event of a redirect/retry.
body = kw.get("body")
body_pos = kw.get("body_pos")
kw["body_pos"] = set_file_position(body, body_pos)
kw["body_pos"] = await set_file_position(body, body_pos)

if "headers" not in kw:
kw["headers"] = self.headers.copy()
Expand Down
11 changes: 6 additions & 5 deletions src/ahip/util/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from base64 import b64encode

from .unasync import await_if_coro
from ..packages.six import b, integer_types
from ..exceptions import UnrewindableBodyError

Expand Down Expand Up @@ -87,16 +88,16 @@ def make_headers(
return headers


def set_file_position(body, pos):
async def set_file_position(body, pos):
"""
If a position is provided, move file to that point.
Otherwise, we'll attempt to record a position for future use.
"""
if pos is not None:
rewind_body(body, pos)
await rewind_body(body, pos)
elif getattr(body, "tell", None) is not None:
try:
pos = body.tell()
pos = await await_if_coro(body.tell())
except (IOError, OSError):
# This differentiates from None, allowing us to catch
# a failed `tell()` later when trying to rewind the body.
Expand All @@ -105,7 +106,7 @@ def set_file_position(body, pos):
return pos


def rewind_body(body, body_pos):
async def rewind_body(body, body_pos):
"""
Attempt to rewind body to a certain position.
Primarily used for request redirects and retries.
Expand All @@ -119,7 +120,7 @@ def rewind_body(body, body_pos):
body_seek = getattr(body, "seek", None)
if body_seek is not None and isinstance(body_pos, integer_types):
try:
body_seek(body_pos)
await await_if_coro(body_seek(body_pos))
except (IOError, OSError):
raise UnrewindableBodyError(
"An error occurred when rewinding request body for redirect/retry."
Expand Down
40 changes: 40 additions & 0 deletions src/ahip/util/unasync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Set of utility functions for unasync that transform into sync counterparts cleanly"""

import inspect

# Keep a reference to the builtin ``next`` before any name in this module can
# shadow it. NOTE(review): in the unasync-generated sync build, the ``anext``
# helper below is renamed to ``next`` (per the additional_replacements rules in
# setup.py), so the builtin must be saved here to be restorable — confirm
# against the build configuration.
_original_next = next


def is_async_mode():
    """Return True when running as the async variant of this module.

    unasync rewrites ``async def`` into plain ``def`` when generating the
    sync package, so calling the probe below yields a coroutine object in
    the async build and plain ``None`` in the sync build.
    """

    async def probe():
        """Rewritten by unasync into a synchronous function."""
        return None

    marker = probe()
    if marker is None:
        # Sync build: probe() executed immediately and returned None.
        return False
    # Async build: probe() produced a coroutine; close it so the interpreter
    # does not emit an "unawaited coroutine" warning.
    marker.close()
    return True


# Evaluated once at import time: True in this async module, expected to be
# False in the unasync-generated sync copy where the probe coroutine has been
# rewritten into a plain function.
ASYNC_MODE = is_async_mode()


async def anext(x):
    """Advance the async iterator *x* by one step and return the value.

    Async counterpart of the builtin ``next``; presumably unasync rewrites
    uses of this name to ``next`` in the generated sync code (see setup.py).
    Raises StopAsyncIteration when *x* is exhausted.
    """
    step = x.__anext__()
    return await step


async def await_if_coro(x):
    """Return *x*, awaiting it first if it is a coroutine object.

    Lets callers treat values from both sync and async APIs uniformly:
    a coroutine is awaited and its result returned; anything else is
    passed through unchanged.
    """
    return (await x) if inspect.iscoroutine(x) else x


# In this async module, rebinding ``next`` to the saved builtin is effectively
# a no-op. NOTE(review): in the unasync-generated sync build the ``anext``
# function above is renamed to ``next`` (see additional_replacements in
# setup.py), so this line restores the builtin afterwards — confirm against
# the build rules.
next = _original_next


def return_non_coro(x):
    """Identity function: return *x* unchanged.

    Sync counterpart substituted for ``await_if_coro`` by unasync (per the
    additional_replacements in setup.py) — in sync code there are no
    coroutines to await, so the value is simply passed through.
    """
    return x
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# We support Python 3.6+ for async code
if sys.version_info[:2] < (3, 6):
collect_ignore_glob = ["async/*.py", "with_dummyserver/async/*.py"]
collect_ignore_glob = ["async/*.py", "with_dummyserver/async*/*.py"]

# The Python 3.8+ default loop on Windows breaks Tornado
@pytest.fixture(scope="session", autouse=True)
Expand Down
Loading