From 505c367a8b44757ab6ed0d035b1c494ed1de6f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 6 Feb 2025 19:00:44 +0900 Subject: [PATCH] refactor: Make api handler work only in class methods --- src/ai/backend/common/api_handlers.py | 14 ++-- tests/common/test_api_handlers.py | 116 ++++++++++++++++---------- 2 files changed, 80 insertions(+), 50 deletions(-) diff --git a/src/ai/backend/common/api_handlers.py b/src/ai/backend/common/api_handlers.py index 01970d8a553..1527b67b755 100644 --- a/src/ai/backend/common/api_handlers.py +++ b/src/ai/backend/common/api_handlers.py @@ -150,6 +150,10 @@ def no_content(cls, status_code: int) -> Self: def to_json(self) -> Optional[JSONDict]: return self._data.model_dump(mode="json") if self._data else None + @property + def status_code(self) -> int: + return self._status_code + _ParamType: TypeAlias = BodyParam | QueryParam | PathParam | HeaderParam | MiddlewareParam @@ -252,13 +256,14 @@ async def _parse_and_execute_handler( return web.json_response( response.to_json, - status=response._status_code, + status=response.status_code, ) def api_handler(handler: BaseHandler) -> ParsedRequestHandler: """ This decorator processes HTTP request parameters using Pydantic models. + NOTICE: API hander methods must be classmethod. It handlers are not class methods it will not work as intended 1. Request Body: @api_handler @@ -330,13 +335,6 @@ async def handler( @functools.wraps(handler) async def wrapped(first_arg: Any, *args, **kwargs) -> web.Response: - if isinstance(first_arg, web.Request): - return await _parse_and_execute_handler( - request=first_arg, handler=handler, signature=original_signature - ) - - # If handler is method defined in class - # Remove 'self' in parameters instance = first_arg sanitized_signature = original_signature.replace( parameters=list(original_signature.parameters.values())[1:] diff --git a/tests/common/test_api_handlers.py b/tests/common/test_api_handlers.py index 45877c091a1..dc246a74304 100644 --- a/tests/common/test_api_handlers.py +++ b/tests/common/test_api_handlers.py @@ -119,19 +119,21 @@ class TestMessageResponse(BaseResponseModel): message: str -@pytest.mark.asyncio -async def test_empty_parameter(aiohttp_client): +class TestMessageHandler: @api_handler - async def handler() -> APIResponse: + async def handle_message(self) -> APIResponse: return APIResponse.build( status_code=200, response_model=TestMessageResponse(message="test") ) + +@pytest.mark.asyncio +async def test_empty_parameter(aiohttp_client): + handler = TestMessageHandler() app = web.Application() - app.router.add_route("GET", "/test", handler) + app.router.add_route("GET", "/test", handler.handle_message) client = await aiohttp_client(app) - resp = await client.get("/test") assert resp.status == 200 @@ -149,21 +151,23 @@ class TestPostUserResponse(BaseResponseModel): age: int -@pytest.mark.asyncio -async def test_body_parameter(aiohttp_client): +class TestPostUserHandler: @api_handler - async def handler(user: BodyParam[TestPostUserModel]) -> APIResponse: + async def handle_user(self, user: BodyParam[TestPostUserModel]) -> APIResponse: parsed_user = user.parsed return APIResponse.build( status_code=200, response_model=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age), ) + +@pytest.mark.asyncio +async def test_body_parameter(aiohttp_client): + handler = TestPostUserHandler() app = web.Application() - app.router.add_route("POST", "/test", handler) + app.router.add_route("POST", "/test", handler.handle_user) client = await aiohttp_client(app) - test_data = {"name": "John", "age": 30} resp = await client.post("/test", json=test_data) @@ -183,10 +187,9 @@ class TestSearchQueryResponse(BaseResponseModel): page: Optional[int] = Field(default=1) -@pytest.mark.asyncio -async def test_query_parameter(aiohttp_client): +class TestSearchQueryHandler: @api_handler - async def handler(query: QueryParam[TestSearchQueryModel]) -> APIResponse: + async def handle_search(self, query: QueryParam[TestSearchQueryModel]) -> APIResponse: parsed_query = query.parsed return APIResponse.build( status_code=200, @@ -195,8 +198,12 @@ async def handler(query: QueryParam[TestSearchQueryModel]) -> APIResponse: ), ) + +@pytest.mark.asyncio +async def test_query_parameter(aiohttp_client): + handler = TestSearchQueryHandler() app = web.Application() - app.router.add_get("/test", handler) + app.router.add_get("/test", handler.handle_search) client = await aiohttp_client(app) resp = await client.get("/test?search=test&page=2") @@ -215,18 +222,21 @@ class TestAuthHeaderResponse(BaseResponseModel): authorization: str -@pytest.mark.asyncio -async def test_header_parameter(aiohttp_client): +class TestAuthHeaderHandler: @api_handler - async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> APIResponse: + async def handle_auth(self, headers: HeaderParam[TestAuthHeaderModel]) -> APIResponse: parsed_headers = headers.parsed return APIResponse.build( status_code=200, response_model=TestAuthHeaderResponse(authorization=parsed_headers.authorization), ) + +@pytest.mark.asyncio +async def test_header_parameter(aiohttp_client): + handler = TestAuthHeaderHandler() app = web.Application() - app.router.add_get("/test", handler) + app.router.add_get("/test", handler.handle_auth) client = await aiohttp_client(app) headers = {"Authorization": "Bearer token123"} @@ -245,17 +255,20 @@ class TestUserPathResponse(BaseResponseModel): user_id: str -@pytest.mark.asyncio -async def test_path_parameter(aiohttp_client): +class TestUserPathHandler: @api_handler - async def handler(path: PathParam[TestUserPathModel]) -> APIResponse: + async def handle_path(self, path: PathParam[TestUserPathModel]) -> APIResponse: parsed_path = path.parsed return APIResponse.build( status_code=200, response_model=TestUserPathResponse(user_id=parsed_path.user_id) ) + +@pytest.mark.asyncio +async def test_path_parameter(aiohttp_client): + handler = TestUserPathHandler() app = web.Application() - app.router.add_get("/test/{user_id}", handler) + app.router.add_get("/test/{user_id}", handler.handle_path) client = await aiohttp_client(app) resp = await client.get("/test/123") @@ -277,14 +290,18 @@ class TestAuthResponse(BaseResponseModel): is_authorized: bool = Field(default=False) -@pytest.mark.asyncio -async def test_middleware_parameter(aiohttp_client): +class TestAuthHandler: @api_handler - async def handler(auth: TestAuthInfo) -> APIResponse: + async def handle_middleware_auth(self, auth: TestAuthInfo) -> APIResponse: return APIResponse.build( status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized) ) + +@pytest.mark.asyncio +async def test_middleware_parameter(aiohttp_client): + handler = TestAuthHandler() + @web.middleware async def auth_middleware(request, handler): request["is_authorized"] = True @@ -292,7 +309,7 @@ async def auth_middleware(request, handler): app = web.Application() app.middlewares.append(auth_middleware) - app.router.add_get("/test", handler) + app.router.add_get("/test", handler.handle_middleware_auth) client = await aiohttp_client(app) resp = await client.get("/test") @@ -302,14 +319,18 @@ async def auth_middleware(request, handler): assert data["is_authorized"] -@pytest.mark.asyncio -async def test_middleware_parameter_invalid_type(aiohttp_client): +class TestInvalidAuthHandler: @api_handler - async def handler(auth: TestAuthInfo) -> APIResponse: + async def handle_invalid_auth(self, auth: TestAuthInfo) -> APIResponse: return APIResponse.build( status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized) ) + +@pytest.mark.asyncio +async def test_middleware_parameter_invalid_type(aiohttp_client): + handler = TestInvalidAuthHandler() + @web.middleware async def broken_auth_middleware(request, handler): request["is_authorized"] = "not_a_boolean" @@ -317,7 +338,7 @@ async def broken_auth_middleware(request, handler): app = web.Application() app.middlewares.append(broken_auth_middleware) - app.router.add_get("/test", handler) + app.router.add_get("/test", handler.handle_invalid_auth) client = await aiohttp_client(app) resp = await client.get("/test") @@ -350,10 +371,10 @@ class TestCombinedResponse(BaseResponseModel): is_authorized: bool -@pytest.mark.asyncio -async def test_multiple_parameters(aiohttp_client): +class TestMultipleParamsHandler: @api_handler - async def handler( + async def handle_multiple( + self, body: BodyParam[TestCreateUserModel], auth: TestMiddlewareModel, query: QueryParam[TestSearchParamModel], @@ -370,6 +391,11 @@ async def handler( ), ) + +@pytest.mark.asyncio +async def test_multiple_parameters(aiohttp_client): + handler = TestMultipleParamsHandler() + @web.middleware async def auth_middleware(request, handler): request["is_authorized"] = True @@ -377,7 +403,7 @@ async def auth_middleware(request, handler): app = web.Application() app.middlewares.append(auth_middleware) - app.router.add_post("/test", handler) + app.router.add_post("/test", handler.handle_multiple) client = await aiohttp_client(app) test_data = {"user_name": "John"} @@ -400,18 +426,21 @@ class TestRegisterUserResponse(BaseResponseModel): age: int -@pytest.mark.asyncio -async def test_invalid_body(aiohttp_client): +class TestRegisterUserHandler: @api_handler - async def handler(user: BodyParam[TestRegisterUserModel]) -> APIResponse: + async def handle_register(self, user: BodyParam[TestRegisterUserModel]) -> APIResponse: test_user = user.parsed return APIResponse.build( status_code=200, response_model=TestRegisterUserResponse(name=test_user.name, age=test_user.age), ) + +@pytest.mark.asyncio +async def test_invalid_body(aiohttp_client): + handler = TestRegisterUserHandler() app = web.Application() - app.router.add_post("/test", handler) + app.router.add_post("/test", handler.handle_register) client = await aiohttp_client(app) test_data = {"name": "John"} # age field missing @@ -429,10 +458,9 @@ class TestProductSearchResponse(BaseResponseModel): page: Optional[int] = Field(default=1) -@pytest.mark.asyncio -async def test_invalid_query_parameter(aiohttp_client): +class TestProductSearchHandler: @api_handler - async def handler(query: QueryParam[TestProductSearchModel]) -> APIResponse: + async def handle_product_search(self, query: QueryParam[TestProductSearchModel]) -> APIResponse: parsed_query = query.parsed return APIResponse.build( status_code=200, @@ -441,8 +469,12 @@ async def handler(query: QueryParam[TestProductSearchModel]) -> APIResponse: ), ) + +@pytest.mark.asyncio +async def test_invalid_query_parameter(aiohttp_client): + handler = TestProductSearchHandler() app = web.Application() - app.router.add_get("/test", handler) + app.router.add_get("/test", handler.handle_product_search) client = await aiohttp_client(app) error_response = await client.get("/test") # request with no query parameter assert error_response.status == 400 # InvalidAPIParameters Error raised