Skip to content

Commit

Permalink
refactor: Make api handler work only in class methods
Browse files Browse the repository at this point in the history
  • Loading branch information
seedspirit committed Feb 6, 2025
1 parent 1f99b26 commit 505c367
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 50 deletions.
14 changes: 6 additions & 8 deletions src/ai/backend/common/api_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ def no_content(cls, status_code: int) -> Self:
def to_json(self) -> Optional[JSONDict]:
return self._data.model_dump(mode="json") if self._data else None

@property
def status_code(self) -> int:
return self._status_code


_ParamType: TypeAlias = BodyParam | QueryParam | PathParam | HeaderParam | MiddlewareParam

Expand Down Expand Up @@ -252,13 +256,14 @@ async def _parse_and_execute_handler(

return web.json_response(
response.to_json,
status=response._status_code,
status=response.status_code,
)


def api_handler(handler: BaseHandler) -> ParsedRequestHandler:
"""
This decorator processes HTTP request parameters using Pydantic models.
NOTICE: API hander methods must be classmethod. It handlers are not class methods it will not work as intended
1. Request Body:
@api_handler
Expand Down Expand Up @@ -330,13 +335,6 @@ async def handler(

@functools.wraps(handler)
async def wrapped(first_arg: Any, *args, **kwargs) -> web.Response:
if isinstance(first_arg, web.Request):
return await _parse_and_execute_handler(
request=first_arg, handler=handler, signature=original_signature
)

# If handler is method defined in class
# Remove 'self' in parameters
instance = first_arg
sanitized_signature = original_signature.replace(
parameters=list(original_signature.parameters.values())[1:]
Expand Down
116 changes: 74 additions & 42 deletions tests/common/test_api_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,21 @@ class TestMessageResponse(BaseResponseModel):
message: str


@pytest.mark.asyncio
async def test_empty_parameter(aiohttp_client):
class TestMessageHandler:
@api_handler
async def handler() -> APIResponse:
async def handle_message(self) -> APIResponse:
return APIResponse.build(
status_code=200, response_model=TestMessageResponse(message="test")
)


@pytest.mark.asyncio
async def test_empty_parameter(aiohttp_client):
handler = TestMessageHandler()
app = web.Application()
app.router.add_route("GET", "/test", handler)
app.router.add_route("GET", "/test", handler.handle_message)

client = await aiohttp_client(app)

resp = await client.get("/test")

assert resp.status == 200
Expand All @@ -149,21 +151,23 @@ class TestPostUserResponse(BaseResponseModel):
age: int


@pytest.mark.asyncio
async def test_body_parameter(aiohttp_client):
class TestPostUserHandler:
@api_handler
async def handler(user: BodyParam[TestPostUserModel]) -> APIResponse:
async def handle_user(self, user: BodyParam[TestPostUserModel]) -> APIResponse:
parsed_user = user.parsed
return APIResponse.build(
status_code=200,
response_model=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age),
)


@pytest.mark.asyncio
async def test_body_parameter(aiohttp_client):
handler = TestPostUserHandler()
app = web.Application()
app.router.add_route("POST", "/test", handler)
app.router.add_route("POST", "/test", handler.handle_user)

client = await aiohttp_client(app)

test_data = {"name": "John", "age": 30}
resp = await client.post("/test", json=test_data)

Expand All @@ -183,10 +187,9 @@ class TestSearchQueryResponse(BaseResponseModel):
page: Optional[int] = Field(default=1)


@pytest.mark.asyncio
async def test_query_parameter(aiohttp_client):
class TestSearchQueryHandler:
@api_handler
async def handler(query: QueryParam[TestSearchQueryModel]) -> APIResponse:
async def handle_search(self, query: QueryParam[TestSearchQueryModel]) -> APIResponse:
parsed_query = query.parsed
return APIResponse.build(
status_code=200,
Expand All @@ -195,8 +198,12 @@ async def handler(query: QueryParam[TestSearchQueryModel]) -> APIResponse:
),
)


@pytest.mark.asyncio
async def test_query_parameter(aiohttp_client):
handler = TestSearchQueryHandler()
app = web.Application()
app.router.add_get("/test", handler)
app.router.add_get("/test", handler.handle_search)

client = await aiohttp_client(app)
resp = await client.get("/test?search=test&page=2")
Expand All @@ -215,18 +222,21 @@ class TestAuthHeaderResponse(BaseResponseModel):
authorization: str


@pytest.mark.asyncio
async def test_header_parameter(aiohttp_client):
class TestAuthHeaderHandler:
@api_handler
async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> APIResponse:
async def handle_auth(self, headers: HeaderParam[TestAuthHeaderModel]) -> APIResponse:
parsed_headers = headers.parsed
return APIResponse.build(
status_code=200,
response_model=TestAuthHeaderResponse(authorization=parsed_headers.authorization),
)


@pytest.mark.asyncio
async def test_header_parameter(aiohttp_client):
handler = TestAuthHeaderHandler()
app = web.Application()
app.router.add_get("/test", handler)
app.router.add_get("/test", handler.handle_auth)

client = await aiohttp_client(app)
headers = {"Authorization": "Bearer token123"}
Expand All @@ -245,17 +255,20 @@ class TestUserPathResponse(BaseResponseModel):
user_id: str


@pytest.mark.asyncio
async def test_path_parameter(aiohttp_client):
class TestUserPathHandler:
@api_handler
async def handler(path: PathParam[TestUserPathModel]) -> APIResponse:
async def handle_path(self, path: PathParam[TestUserPathModel]) -> APIResponse:
parsed_path = path.parsed
return APIResponse.build(
status_code=200, response_model=TestUserPathResponse(user_id=parsed_path.user_id)
)


@pytest.mark.asyncio
async def test_path_parameter(aiohttp_client):
handler = TestUserPathHandler()
app = web.Application()
app.router.add_get("/test/{user_id}", handler)
app.router.add_get("/test/{user_id}", handler.handle_path)

client = await aiohttp_client(app)
resp = await client.get("/test/123")
Expand All @@ -277,22 +290,26 @@ class TestAuthResponse(BaseResponseModel):
is_authorized: bool = Field(default=False)


@pytest.mark.asyncio
async def test_middleware_parameter(aiohttp_client):
class TestAuthHandler:
@api_handler
async def handler(auth: TestAuthInfo) -> APIResponse:
async def handle_middleware_auth(self, auth: TestAuthInfo) -> APIResponse:
return APIResponse.build(
status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized)
)


@pytest.mark.asyncio
async def test_middleware_parameter(aiohttp_client):
handler = TestAuthHandler()

@web.middleware
async def auth_middleware(request, handler):
request["is_authorized"] = True
return await handler(request)

app = web.Application()
app.middlewares.append(auth_middleware)
app.router.add_get("/test", handler)
app.router.add_get("/test", handler.handle_middleware_auth)
client = await aiohttp_client(app)

resp = await client.get("/test")
Expand All @@ -302,22 +319,26 @@ async def auth_middleware(request, handler):
assert data["is_authorized"]


@pytest.mark.asyncio
async def test_middleware_parameter_invalid_type(aiohttp_client):
class TestInvalidAuthHandler:
@api_handler
async def handler(auth: TestAuthInfo) -> APIResponse:
async def handle_invalid_auth(self, auth: TestAuthInfo) -> APIResponse:
return APIResponse.build(
status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized)
)


@pytest.mark.asyncio
async def test_middleware_parameter_invalid_type(aiohttp_client):
handler = TestInvalidAuthHandler()

@web.middleware
async def broken_auth_middleware(request, handler):
request["is_authorized"] = "not_a_boolean"
return await handler(request)

app = web.Application()
app.middlewares.append(broken_auth_middleware)
app.router.add_get("/test", handler)
app.router.add_get("/test", handler.handle_invalid_auth)
client = await aiohttp_client(app)

resp = await client.get("/test")
Expand Down Expand Up @@ -350,10 +371,10 @@ class TestCombinedResponse(BaseResponseModel):
is_authorized: bool


@pytest.mark.asyncio
async def test_multiple_parameters(aiohttp_client):
class TestMultipleParamsHandler:
@api_handler
async def handler(
async def handle_multiple(
self,
body: BodyParam[TestCreateUserModel],
auth: TestMiddlewareModel,
query: QueryParam[TestSearchParamModel],
Expand All @@ -370,14 +391,19 @@ async def handler(
),
)


@pytest.mark.asyncio
async def test_multiple_parameters(aiohttp_client):
handler = TestMultipleParamsHandler()

@web.middleware
async def auth_middleware(request, handler):
request["is_authorized"] = True
return await handler(request)

app = web.Application()
app.middlewares.append(auth_middleware)
app.router.add_post("/test", handler)
app.router.add_post("/test", handler.handle_multiple)

client = await aiohttp_client(app)
test_data = {"user_name": "John"}
Expand All @@ -400,18 +426,21 @@ class TestRegisterUserResponse(BaseResponseModel):
age: int


@pytest.mark.asyncio
async def test_invalid_body(aiohttp_client):
class TestRegisterUserHandler:
@api_handler
async def handler(user: BodyParam[TestRegisterUserModel]) -> APIResponse:
async def handle_register(self, user: BodyParam[TestRegisterUserModel]) -> APIResponse:
test_user = user.parsed
return APIResponse.build(
status_code=200,
response_model=TestRegisterUserResponse(name=test_user.name, age=test_user.age),
)


@pytest.mark.asyncio
async def test_invalid_body(aiohttp_client):
handler = TestRegisterUserHandler()
app = web.Application()
app.router.add_post("/test", handler)
app.router.add_post("/test", handler.handle_register)
client = await aiohttp_client(app)

test_data = {"name": "John"} # age field missing
Expand All @@ -429,10 +458,9 @@ class TestProductSearchResponse(BaseResponseModel):
page: Optional[int] = Field(default=1)


@pytest.mark.asyncio
async def test_invalid_query_parameter(aiohttp_client):
class TestProductSearchHandler:
@api_handler
async def handler(query: QueryParam[TestProductSearchModel]) -> APIResponse:
async def handle_product_search(self, query: QueryParam[TestProductSearchModel]) -> APIResponse:
parsed_query = query.parsed
return APIResponse.build(
status_code=200,
Expand All @@ -441,8 +469,12 @@ async def handler(query: QueryParam[TestProductSearchModel]) -> APIResponse:
),
)


@pytest.mark.asyncio
async def test_invalid_query_parameter(aiohttp_client):
handler = TestProductSearchHandler()
app = web.Application()
app.router.add_get("/test", handler)
app.router.add_get("/test", handler.handle_product_search)
client = await aiohttp_client(app)
error_response = await client.get("/test") # request with no query parameter
assert error_response.status == 400 # InvalidAPIParameters Error raised

0 comments on commit 505c367

Please sign in to comment.