diff --git a/samples/https_flask/functions/main.py b/samples/https_flask/functions/main.py index 09738d8..eb2b736 100644 --- a/samples/https_flask/functions/main.py +++ b/samples/https_flask/functions/main.py @@ -23,3 +23,15 @@ def world(): @https_fn.on_request() def httpsflaskexample(request): return entrypoint(app, request) + + +@https_fn.on_call() +def callableexample(request: https_fn.CallableRequest): + return request.data + + +@https_fn.on_call() +def streamingcallable(request: https_fn.CallableRequest): + yield "Hello," + yield "world!" + return request.data diff --git a/src/firebase_functions/firestore_fn.py b/src/firebase_functions/firestore_fn.py index a9d4f2a..6292d2f 100644 --- a/src/firebase_functions/firestore_fn.py +++ b/src/firebase_functions/firestore_fn.py @@ -219,7 +219,7 @@ def _firestore_endpoint_handler( func(database_event_with_auth_context) else: # mypy cannot infer that the event type is correct, hence the cast - _typing.cast(_C1 | _C2, func)(database_event) + _typing.cast(_C1 | _C2, func)(database_event) # type: ignore @_util.copy_func_kwargs(FirestoreOptions) diff --git a/src/firebase_functions/https_fn.py b/src/firebase_functions/https_fn.py index 10749e9..df263ce 100644 --- a/src/firebase_functions/https_fn.py +++ b/src/firebase_functions/https_fn.py @@ -21,6 +21,7 @@ import typing_extensions as _typing_extensions import enum as _enum import json as _json +import inspect as _inspect import firebase_functions.private.util as _util import firebase_functions.core as _core from functions_framework import logging as _logging @@ -352,6 +353,22 @@ class CallableRequest(_typing.Generic[_core.T]): _C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] +class _IterWithReturn: + """ Utility class to capture return statements from a generator """ + + def __init__(self, iterable): + self.iterable = iterable + + def __iter__(self): + try: + self.value = yield from self.iterable + except RuntimeError as e: + if isinstance(e.__cause__, StopIteration): + self.value = e.__cause__.value + else: + raise + + def _on_call_handler(func: _C2, request: Request, enforce_app_check: bool) -> Response: try: @@ -401,7 +418,19 @@ def _on_call_handler(func: _C2, request: Request, "Firebase-Instance-ID-Token"), ) result = _core._with_init(func)(context) - return _jsonify(result=result) + if not _inspect.isgenerator(result): + return _jsonify(result=result) + + if request.headers.get("Accept") != "text/event-stream": + vals = _IterWithReturn(result) + # Consume and drop yielded results + list(vals) + return _jsonify(result=vals.value) + + else: + return Response(_sse_encode_generator(result), + content_type="text/event-stream") + # Disable broad exceptions lint since we want to handle all exceptions here # and wrap as an HttpsError. # pylint: disable=broad-except @@ -413,6 +442,24 @@ def _on_call_handler(func: _C2, request: Request, return _make_response(_jsonify(error=err._as_dict()), status) +def _sse_encode_generator(gen: _typing.Generator): + with_return = _IterWithReturn(gen) + try: + for chunk in with_return: + data = _json.dumps(obj={"message": chunk}) + yield f"data: {data}\n\n" + result = _json.dumps({"result": with_return.value}) + yield f"data: {result}\n\n" + # pylint: disable=broad-except + except Exception as err: + if not isinstance(err, HttpsError): + _logging.error("Unhandled error: %s", err) + err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL") + json = _json.dumps(obj={"error": err._as_dict()}) + yield f"error: {json}\n\n" + yield "END" + + @_util.copy_func_kwargs(HttpsOptions) def on_request(**kwargs) -> _typing.Callable[[_C1], _C1]: """ diff --git a/tests/test_https_fn.py b/tests/test_https_fn.py index e128b39..9df5acd 100644 --- a/tests/test_https_fn.py +++ b/tests/test_https_fn.py @@ -3,7 +3,6 @@ """ import unittest -from unittest.mock import Mock from flask import Flask, Request from werkzeug.test import EnvironBuilder @@ -25,7 +24,9 @@ def init(): nonlocal hello hello = "world" - func = Mock(__name__="example_func") + @https_fn.on_request() + def func(_): + pass with app.test_request_context("/"): environ = EnvironBuilder( @@ -37,9 +38,8 @@ def init(): }, ).get_environ() request = Request(environ) - decorated_func = https_fn.on_request()(func) - decorated_func(request) + func(request) self.assertEqual(hello, "world") @@ -53,7 +53,9 @@ def init(): nonlocal hello hello = "world" - func = Mock(__name__="example_func") + @https_fn.on_call() + def func(_): + pass with app.test_request_context("/"): environ = EnvironBuilder( @@ -65,8 +67,178 @@ def init(): }, ).get_environ() request = Request(environ) - decorated_func = https_fn.on_call()(func) - - decorated_func(request) + func(request) self.assertEqual("world", hello) + + def test_callable_encoding(self): + app = Flask(__name__) + + @https_fn.on_call() + def add(req: https_fn.CallableRequest[int]): + return req.data + 1 + + with app.test_request_context("/"): + environ = EnvironBuilder(method="POST", json={ + "data": 1 + }).get_environ() + request = Request(environ) + + response = add(request) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_json(), {"result": 2}) + + def test_callable_errors(self): + app = Flask(__name__) + + @https_fn.on_call() + def throw_generic_error(req): + # pylint: disable=broad-exception-raised + raise Exception("Invalid type") + + @https_fn.on_call() + def throw_access_denied(req): + raise https_fn.HttpsError( + https_fn.FunctionsErrorCode.PERMISSION_DENIED, + "Permission is denied") + + with app.test_request_context("/"): + environ = EnvironBuilder(method="POST", json={ + "data": None + }).get_environ() + request = Request(environ) + + response = throw_generic_error(request) + self.assertEqual(response.status_code, 500) + self.assertEqual( + response.get_json(), + {"error": { + "message": "INTERNAL", + "status": "INTERNAL" + }}) + + response = throw_access_denied(request) + self.assertEqual(response.status_code, 403) + self.assertEqual( + response.get_json(), { + "error": { + "message": "Permission is denied", + "status": "PERMISSION_DENIED" + } + }) + + def test_yielding_without_streaming(self): + app = Flask(__name__) + + @https_fn.on_call() + def yielder(req: https_fn.CallableRequest[int]): + yield from range(req.data) + return "OK" + + @https_fn.on_call() + def yield_thrower(req: https_fn.CallableRequest[int]): + yield from range(req.data) + raise https_fn.HttpsError( + https_fn.FunctionsErrorCode.PERMISSION_DENIED, + "Can't read anymore") + + @https_fn.on_call() + def legacy_yielder(req: https_fn.CallableRequest[int]): + yield from range(req.data) + # Prior to Python 3.3, this was the way "return" was handled + # Python 3.5 made this messy however because it converts + # raised StopIteration into a RuntimeError + # pylint: disable=stop-iteration-return + raise StopIteration("OK") + + with app.test_request_context("/"): + environ = EnvironBuilder(method="POST", json={ + "data": 5 + }).get_environ() + + request = Request(environ) + response = yielder(request) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_json(), {"result": "OK"}) + + with app.test_request_context("/"): + environ = EnvironBuilder(method="POST", json={ + "data": 5 + }).get_environ() + + request = Request(environ) + response = legacy_yielder(request) + + self.assertEqual(response.status_code, 200) + self.assertEqual(response.get_json(), {"result": "OK"}) + + with app.test_request_context("/"): + environ = EnvironBuilder(method="POST", json={ + "data": 3 + }).get_environ() + + request = Request(environ) + response = yield_thrower(request) + + self.assertEqual(response.status_code, 403) + self.assertEqual( + response.get_json(), { + "error": { + "message": "Can't read anymore", + "status": "PERMISSION_DENIED" + } + }) + + def test_yielding_with_streaming(self): + app = Flask(__name__) + + @https_fn.on_call() + def yielder(req: https_fn.CallableRequest[int]): + yield from range(req.data) + return "OK" + + @https_fn.on_call() + def yield_thrower(req: https_fn.CallableRequest[int]): + yield from range(req.data) + raise https_fn.HttpsError(https_fn.FunctionsErrorCode.INTERNAL, + "Throwing") + + with app.test_request_context("/"): + environ = EnvironBuilder(method="POST", + json={ + "data": 2 + }, + headers={ + "accept": "text/event-stream" + }).get_environ() + + request = Request(environ) + response = yielder(request) + + self.assertEqual(response.status_code, 200) + chunks = list(response.response) + self.assertEqual(chunks, [ + 'data: {"message": 0}\n\n', 'data: {"message": 1}\n\n', + 'data: {"result": "OK"}\n\n', "END" + ]) + + with app.test_request_context("/"): + environ = EnvironBuilder(method="POST", + json={ + "data": 2 + }, + headers={ + "accept": "text/event-stream" + }).get_environ() + + request = Request(environ) + response = yield_thrower(request) + + self.assertEqual(response.status_code, 200) + chunks = list(response.response) + self.assertEqual(chunks, [ + 'data: {"message": 0}\n\n', 'data: {"message": 1}\n\n', + 'error: {"error": {"status": "INTERNAL", "message": "Throwing"}}\n\n', + "END" + ])