Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement streaming callables #225

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions samples/https_flask/functions/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,15 @@ def world():
@https_fn.on_request()
def httpsflaskexample(request):
return entrypoint(app, request)


@https_fn.on_call()
def callableexample(request: https_fn.CallableRequest):
return request.data


@https_fn.on_call()
def streamingcallable(request: https_fn.CallableRequest):
yield "Hello,"
yield "world!"
return request.data
2 changes: 1 addition & 1 deletion src/firebase_functions/firestore_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _firestore_endpoint_handler(
func(database_event_with_auth_context)
else:
# mypy cannot infer that the event type is correct, hence the cast
_typing.cast(_C1 | _C2, func)(database_event)
_typing.cast(_C1 | _C2, func)(database_event) # type: ignore


@_util.copy_func_kwargs(FirestoreOptions)
Expand Down
49 changes: 48 additions & 1 deletion src/firebase_functions/https_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import typing_extensions as _typing_extensions
import enum as _enum
import json as _json
import inspect as _inspect
import firebase_functions.private.util as _util
import firebase_functions.core as _core
from functions_framework import logging as _logging
Expand Down Expand Up @@ -352,6 +353,22 @@ class CallableRequest(_typing.Generic[_core.T]):
_C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any]


class _IterWithReturn:
""" Utility class to capture return statements from a generator """

def __init__(self, iterable):
self.iterable = iterable

def __iter__(self):
try:
self.value = yield from self.iterable
except RuntimeError as e:
if isinstance(e.__cause__, StopIteration):
self.value = e.__cause__.value
else:
raise


def _on_call_handler(func: _C2, request: Request,
enforce_app_check: bool) -> Response:
try:
Expand Down Expand Up @@ -401,7 +418,19 @@ def _on_call_handler(func: _C2, request: Request,
"Firebase-Instance-ID-Token"),
)
result = _core._with_init(func)(context)
return _jsonify(result=result)
if not _inspect.isgenerator(result):
return _jsonify(result=result)

if request.headers.get("Accept") != "text/event-stream":
vals = _IterWithReturn(result)
# Consume and drop yielded results
list(vals)
return _jsonify(result=vals.value)

else:
return Response(_sse_encode_generator(result),
content_type="text/event-stream")

# Disable broad exceptions lint since we want to handle all exceptions here
# and wrap as an HttpsError.
# pylint: disable=broad-except
Expand All @@ -413,6 +442,24 @@ def _on_call_handler(func: _C2, request: Request,
return _make_response(_jsonify(error=err._as_dict()), status)


def _sse_encode_generator(gen: _typing.Generator):
with_return = _IterWithReturn(gen)
try:
for chunk in with_return:
data = _json.dumps(obj={"message": chunk})
yield f"data: {data}\n\n"
result = _json.dumps({"result": with_return.value})
yield f"data: {result}\n\n"
# pylint: disable=broad-except
except Exception as err:
if not isinstance(err, HttpsError):
_logging.error("Unhandled error: %s", err)
err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL")
json = _json.dumps(obj={"error": err._as_dict()})
yield f"error: {json}\n\n"
yield "END"


@_util.copy_func_kwargs(HttpsOptions)
def on_request(**kwargs) -> _typing.Callable[[_C1], _C1]:
"""
Expand Down
188 changes: 180 additions & 8 deletions tests/test_https_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import unittest
from unittest.mock import Mock
from flask import Flask, Request
from werkzeug.test import EnvironBuilder

Expand All @@ -25,7 +24,9 @@ def init():
nonlocal hello
hello = "world"

func = Mock(__name__="example_func")
@https_fn.on_request()
def func(_):
pass

with app.test_request_context("/"):
environ = EnvironBuilder(
Expand All @@ -37,9 +38,8 @@ def init():
},
).get_environ()
request = Request(environ)
decorated_func = https_fn.on_request()(func)

decorated_func(request)
func(request)

self.assertEqual(hello, "world")

Expand All @@ -53,7 +53,9 @@ def init():
nonlocal hello
hello = "world"

func = Mock(__name__="example_func")
@https_fn.on_call()
def func(_):
pass

with app.test_request_context("/"):
environ = EnvironBuilder(
Expand All @@ -65,8 +67,178 @@ def init():
},
).get_environ()
request = Request(environ)
decorated_func = https_fn.on_call()(func)

decorated_func(request)
func(request)

self.assertEqual("world", hello)

def test_callable_encoding(self):
app = Flask(__name__)

@https_fn.on_call()
def add(req: https_fn.CallableRequest[int]):
return req.data + 1

with app.test_request_context("/"):
environ = EnvironBuilder(method="POST", json={
"data": 1
}).get_environ()
request = Request(environ)

response = add(request)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_json(), {"result": 2})

def test_callable_errors(self):
app = Flask(__name__)

@https_fn.on_call()
def throw_generic_error(req):
# pylint: disable=broad-exception-raised
raise Exception("Invalid type")

@https_fn.on_call()
def throw_access_denied(req):
raise https_fn.HttpsError(
https_fn.FunctionsErrorCode.PERMISSION_DENIED,
"Permission is denied")

with app.test_request_context("/"):
environ = EnvironBuilder(method="POST", json={
"data": None
}).get_environ()
request = Request(environ)

response = throw_generic_error(request)
self.assertEqual(response.status_code, 500)
self.assertEqual(
response.get_json(),
{"error": {
"message": "INTERNAL",
"status": "INTERNAL"
}})

response = throw_access_denied(request)
self.assertEqual(response.status_code, 403)
self.assertEqual(
response.get_json(), {
"error": {
"message": "Permission is denied",
"status": "PERMISSION_DENIED"
}
})

def test_yielding_without_streaming(self):
app = Flask(__name__)

@https_fn.on_call()
def yielder(req: https_fn.CallableRequest[int]):
yield from range(req.data)
return "OK"

@https_fn.on_call()
def yield_thrower(req: https_fn.CallableRequest[int]):
yield from range(req.data)
raise https_fn.HttpsError(
https_fn.FunctionsErrorCode.PERMISSION_DENIED,
"Can't read anymore")

@https_fn.on_call()
def legacy_yielder(req: https_fn.CallableRequest[int]):
yield from range(req.data)
# Prior to Python 3.3, this was the way "return" was handled
# Python 3.5 made this messy however because it converts
# raised StopIteration into a RuntimeError
# pylint: disable=stop-iteration-return
raise StopIteration("OK")

with app.test_request_context("/"):
environ = EnvironBuilder(method="POST", json={
"data": 5
}).get_environ()

request = Request(environ)
response = yielder(request)

self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_json(), {"result": "OK"})

with app.test_request_context("/"):
environ = EnvironBuilder(method="POST", json={
"data": 5
}).get_environ()

request = Request(environ)
response = legacy_yielder(request)

self.assertEqual(response.status_code, 200)
self.assertEqual(response.get_json(), {"result": "OK"})

with app.test_request_context("/"):
environ = EnvironBuilder(method="POST", json={
"data": 3
}).get_environ()

request = Request(environ)
response = yield_thrower(request)

self.assertEqual(response.status_code, 403)
self.assertEqual(
response.get_json(), {
"error": {
"message": "Can't read anymore",
"status": "PERMISSION_DENIED"
}
})

def test_yielding_with_streaming(self):
app = Flask(__name__)

@https_fn.on_call()
def yielder(req: https_fn.CallableRequest[int]):
yield from range(req.data)
return "OK"

@https_fn.on_call()
def yield_thrower(req: https_fn.CallableRequest[int]):
yield from range(req.data)
raise https_fn.HttpsError(https_fn.FunctionsErrorCode.INTERNAL,
"Throwing")

with app.test_request_context("/"):
environ = EnvironBuilder(method="POST",
json={
"data": 2
},
headers={
"accept": "text/event-stream"
}).get_environ()

request = Request(environ)
response = yielder(request)

self.assertEqual(response.status_code, 200)
chunks = list(response.response)
self.assertEqual(chunks, [
'data: {"message": 0}\n\n', 'data: {"message": 1}\n\n',
'data: {"result": "OK"}\n\n', "END"
])

with app.test_request_context("/"):
environ = EnvironBuilder(method="POST",
json={
"data": 2
},
headers={
"accept": "text/event-stream"
}).get_environ()

request = Request(environ)
response = yield_thrower(request)

self.assertEqual(response.status_code, 200)
chunks = list(response.response)
self.assertEqual(chunks, [
'data: {"message": 0}\n\n', 'data: {"message": 1}\n\n',
'error: {"error": {"status": "INTERNAL", "message": "Throwing"}}\n\n',
"END"
])
Loading