Skip to content

Commit cdbac96

Browse files
Add workspaces to OpenAPI spec (#634)
* Add workspaces to OpenAPI spec * feat(server): move dashboard endpoints under api v1 router * fix: broken test * chore: tidy up dashboard api naming * fix: ruff format pass
1 parent e121a81 commit cdbac96

10 files changed

+408
-56
lines changed

api/openapi.json

+352-12
Large diffs are not rendered by default.

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ build-backend = "poetry.core.masonry.api"
5050

5151
[tool.poetry.scripts]
5252
codegate = "codegate.cli:main"
53-
generate-openapi = "src.codegate.dashboard.dashboard:generate_openapi"
53+
generate-openapi = "src.codegate.server:generate_openapi"
5454

5555
[tool.black]
5656
line-length = 100
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,49 @@
11
import asyncio
2-
import json
32
from typing import AsyncGenerator, List, Optional
43

54
import requests
65
import structlog
7-
from fastapi import APIRouter, Depends, FastAPI
6+
from fastapi import APIRouter, Depends
87
from fastapi.responses import StreamingResponse
8+
from fastapi.routing import APIRoute
99

1010
from codegate import __version__
11-
from codegate.dashboard.post_processing import (
11+
from codegate.api.dashboard.post_processing import (
1212
parse_get_alert_conversation,
1313
parse_messages_in_conversations,
1414
)
15-
from codegate.dashboard.request_models import AlertConversation, Conversation
15+
from codegate.api.dashboard.request_models import AlertConversation, Conversation
1616
from codegate.db.connection import DbReader, alert_queue
1717

1818
logger = structlog.get_logger("codegate")
1919

20-
dashboard_router = APIRouter(tags=["Dashboard"])
20+
dashboard_router = APIRouter()
2121
db_reader = None
2222

23+
24+
def uniq_name(route: APIRoute):
25+
return f"v1_{route.name}"
26+
27+
2328
def get_db_reader():
2429
global db_reader
2530
if db_reader is None:
2631
db_reader = DbReader()
2732
return db_reader
2833

34+
2935
def fetch_latest_version() -> str:
3036
url = "https://api.github.com/repos/stacklok/codegate/releases/latest"
31-
headers = {
32-
"Accept": "application/vnd.github+json",
33-
"X-GitHub-Api-Version": "2022-11-28"
34-
}
37+
headers = {"Accept": "application/vnd.github+json", "X-GitHub-Api-Version": "2022-11-28"}
3538
response = requests.get(url, headers=headers, timeout=5)
3639
response.raise_for_status()
3740
data = response.json()
3841
return data.get("tag_name", "unknown")
3942

40-
@dashboard_router.get("/dashboard/messages")
43+
44+
@dashboard_router.get(
45+
"/dashboard/messages", tags=["Dashboard"], generate_unique_id_function=uniq_name
46+
)
4147
def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversation]:
4248
"""
4349
Get all the messages from the database and return them as a list of conversations.
@@ -47,7 +53,9 @@ def get_messages(db_reader: DbReader = Depends(get_db_reader)) -> List[Conversat
4753
return asyncio.run(parse_messages_in_conversations(prompts_outputs))
4854

4955

50-
@dashboard_router.get("/dashboard/alerts")
56+
@dashboard_router.get(
57+
"/dashboard/alerts", tags=["Dashboard"], generate_unique_id_function=uniq_name
58+
)
5159
def get_alerts(db_reader: DbReader = Depends(get_db_reader)) -> List[Optional[AlertConversation]]:
5260
"""
5361
Get all the messages from the database and return them as a list of conversations.
@@ -65,21 +73,26 @@ async def generate_sse_events() -> AsyncGenerator[str, None]:
6573
yield f"data: {message}\n\n"
6674

6775

68-
@dashboard_router.get("/dashboard/alerts_notification")
76+
@dashboard_router.get(
77+
"/dashboard/alerts_notification", tags=["Dashboard"], generate_unique_id_function=uniq_name
78+
)
6979
async def stream_sse():
7080
"""
7181
Send alerts event
7282
"""
7383
return StreamingResponse(generate_sse_events(), media_type="text/event-stream")
7484

75-
@dashboard_router.get("/dashboard/version")
85+
86+
@dashboard_router.get(
87+
"/dashboard/version", tags=["Dashboard"], generate_unique_id_function=uniq_name
88+
)
7689
def version_check():
7790
try:
7891
latest_version = fetch_latest_version()
7992

8093
# normalize the versions as github will return them with a 'v' prefix
81-
current_version = __version__.lstrip('v')
82-
latest_version_stripped = latest_version.lstrip('v')
94+
current_version = __version__.lstrip("v")
95+
latest_version_stripped = latest_version.lstrip("v")
8396

8497
is_latest: bool = latest_version_stripped == current_version
8598

@@ -95,28 +108,13 @@ def version_check():
95108
"current_version": __version__,
96109
"latest_version": "unknown",
97110
"is_latest": None,
98-
"error": "An error occurred while fetching the latest version"
111+
"error": "An error occurred while fetching the latest version",
99112
}
100113
except Exception as e:
101114
logger.error(f"Unexpected error: {str(e)}")
102115
return {
103116
"current_version": __version__,
104117
"latest_version": "unknown",
105118
"is_latest": None,
106-
"error": "An unexpected error occurred"
119+
"error": "An unexpected error occurred",
107120
}
108-
109-
110-
def generate_openapi():
111-
# Create a temporary FastAPI app instance
112-
app = FastAPI()
113-
114-
# Include your defined router
115-
app.include_router(dashboard_router)
116-
117-
# Generate OpenAPI JSON
118-
openapi_schema = app.openapi()
119-
120-
# Convert the schema to JSON string for easier handling or storage
121-
openapi_json = json.dumps(openapi_schema, indent=2)
122-
print(openapi_json)

src/codegate/dashboard/post_processing.py src/codegate/api/dashboard/post_processing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import structlog
88

9-
from codegate.dashboard.request_models import (
9+
from codegate.api.dashboard.request_models import (
1010
AlertConversation,
1111
ChatMessage,
1212
Conversation,

src/codegate/api/v1.py

+3
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
from codegate.api import v1_models
77
from codegate.db.connection import AlreadyExistsError
88
from codegate.workspaces.crud import WorkspaceCrud
9+
from codegate.api.dashboard.dashboard import dashboard_router
910

1011
v1 = APIRouter()
12+
v1.include_router(dashboard_router)
13+
1114
wscrud = WorkspaceCrud()
1215

1316

src/codegate/db/connection.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, sqlite_path: Optional[str] = None):
5454
)
5555
self._db_path = Path(sqlite_path).absolute()
5656
self._db_path.parent.mkdir(parents=True, exist_ok=True)
57-
logger.debug(f"Connecting to DB from path: {self._db_path}")
57+
# logger.debug(f"Connecting to DB from path: {self._db_path}")
5858
engine_dict = {
5959
"url": f"sqlite+aiosqlite:///{self._db_path}",
6060
"echo": False, # Set to False in production

src/codegate/server.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import json
12
import traceback
3+
from unittest.mock import Mock
24

35
import structlog
46
from fastapi import APIRouter, FastAPI, Request
@@ -8,7 +10,6 @@
810

911
from codegate import __description__, __version__
1012
from codegate.api.v1 import v1
11-
from codegate.dashboard.dashboard import dashboard_router
1213
from codegate.pipeline.factory import PipelineFactory
1314
from codegate.providers.anthropic.provider import AnthropicProvider
1415
from codegate.providers.llamacpp.provider import LlamaCppProvider
@@ -96,9 +97,19 @@ async def health_check():
9697
return {"status": "healthy"}
9798

9899
app.include_router(system_router)
99-
app.include_router(dashboard_router)
100100

101101
# CodeGate API
102102
app.include_router(v1, prefix="/api/v1", tags=["CodeGate API"])
103103

104104
return app
105+
106+
107+
def generate_openapi():
108+
app = init_app(Mock(spec=PipelineFactory))
109+
110+
# Generate OpenAPI JSON
111+
openapi_schema = app.openapi()
112+
113+
# Convert the schema to JSON string for easier handling or storage
114+
openapi_json = json.dumps(openapi_schema, indent=2)
115+
print(openapi_json)

tests/dashboard/test_post_processing.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
import pytest
66

7-
from codegate.dashboard.post_processing import (
7+
from codegate.api.dashboard.post_processing import (
88
_get_question_answer,
99
_group_partial_messages,
1010
_is_system_prompt,
1111
parse_output,
1212
parse_request,
1313
)
14-
from codegate.dashboard.request_models import (
14+
from codegate.api.dashboard.request_models import (
1515
PartialQuestions,
1616
)
1717
from codegate.db.models import GetPromptWithOutputsRow
@@ -162,10 +162,10 @@ async def test_parse_output(output_dict, expected_str):
162162
)
163163
async def test_get_question_answer(request_msg_list, output_msg_str, row):
164164
with patch(
165-
"codegate.dashboard.post_processing.parse_request", new_callable=AsyncMock
165+
"codegate.api.dashboard.post_processing.parse_request", new_callable=AsyncMock
166166
) as mock_parse_request:
167167
with patch(
168-
"codegate.dashboard.post_processing.parse_output", new_callable=AsyncMock
168+
"codegate.api.dashboard.post_processing.parse_output", new_callable=AsyncMock
169169
) as mock_parse_output:
170170
# Set return values for the mocks
171171
mock_parse_request.return_value = request_msg_list

tests/test_server.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ def test_health_check(test_client: TestClient) -> None:
8181
assert response.status_code == 200
8282
assert response.json() == {"status": "healthy"}
8383

84-
@patch("codegate.dashboard.dashboard.fetch_latest_version", return_value="foo")
84+
@patch("codegate.api.dashboard.dashboard.fetch_latest_version", return_value="foo")
8585
def test_version_endpoint(mock_fetch_latest_version, test_client: TestClient) -> None:
8686
"""Test the version endpoint."""
87-
response = test_client.get("/dashboard/version")
87+
response = test_client.get("/api/v1/dashboard/version")
8888
assert response.status_code == 200
8989

9090
response_data = response.json()
@@ -139,7 +139,7 @@ def test_dashboard_routes(mock_pipeline_factory) -> None:
139139
routes = [route.path for route in app.routes]
140140

141141
# Verify dashboard endpoints are included
142-
dashboard_routes = [route for route in routes if route.startswith("/dashboard")]
142+
dashboard_routes = [route for route in routes if route.startswith("/api/v1/dashboard")]
143143
assert len(dashboard_routes) > 0
144144

145145

0 commit comments

Comments
 (0)