-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
379 lines (315 loc) · 12.8 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Union
from uuid import uuid4
from fastapi import FastAPI, File, HTTPException, Query, Request, Response, UploadFile, responses
from pydantic import BaseModel, Field, field_validator
from starlette.background import BackgroundTask
from typing_extensions import Annotated
import utils.config as config
from core.enums import ClassifierType
from core.extract_metrics import extract_notebook_metrics_from_ipynb_file
from core.model_store import ModelStore
from utils.logger import init_logger
# Module-level wiring: configure logging once at import time, then create the
# process-wide model store and the FastAPI application object.
init_logger()
logger = logging.getLogger(__name__)
# Shared store of trained models; the /models/ and /notebook/predict/
# endpoints read from it.
model_store = ModelStore()
# Markdown rendered on the auto-generated /docs page.
description = """
API to check your coding quality.
## Models
You will be able to select the model which decides if your code is comprehensive enough.
<br/>
Information about the training parameters of the models is provided in the /models/ endpoint.
## Notebooks
You will be able to upload your notebooks to extract their metrics and check if they are comprehensive enough.
* **Upload Your Notebook**: Upload your notebook and get its identifier (file name saved on the server).
* **Get Notebook Metrics**: See the metrics of your notebook.
* **Get Notebook Prediction**: See if your code is comprehensive or not based on the selected model (chosen from the /models/ API).
"""
app = FastAPI(
    title="Code Comprehension Service",
    description=description,
    summary="Check if your code is a good one.",
    version="0.0.1",
    contact={
        "name": "Masih Beigi Rizi",
        "email": "[email protected]",
    },
)
@app.middleware("http")
async def logger_middleware(request: Request, call_next):
    """Log each request/response pair together with the total handling time.

    The downstream response's body iterator is fully drained here, so a fresh
    Response carrying the buffered bytes is returned to the client. The log
    write itself runs as a BackgroundTask after the response is sent.
    """
    start_time = time.time()
    # Reading the body caches it on the Request object, so downstream
    # handlers can still consume it.
    # NOTE(review): full request and response bodies are logged verbatim —
    # they may be large or contain sensitive data; consider truncation.
    req_body = await request.body()
    response = await call_next(request)
    # Drain the (possibly streaming) body iterator into memory so it can be
    # both logged and re-sent.
    res_body = b""
    async for chunk in response.body_iterator:
        res_body += chunk
    req_info = (
        f"\npath: {request.url.path}\npath_params: {request.path_params}\nbody: {req_body}\ndict: {request.__dict__}"
    )
    res_info = f"\nbody: {res_body}\ndict: {response.__dict__}"
    # Deferred: executed by Starlette after the response has been delivered.
    task = BackgroundTask(
        logger.info, f"request_info: {req_info}\nresponse_info: {res_info} in {time.time() - start_time} seconds."
    )
    # Rebuild the response because the original iterator is now exhausted.
    return Response(
        content=res_body,
        status_code=response.status_code,
        headers=dict(response.headers),
        media_type=response.media_type,
        background=task,
    )
class ModelInfo(BaseModel):
    """Description of one trained model served by the API (see /models/)."""

    # Identifier clients pass as model_id to /notebook/predict/.
    id: str
    file_name: str
    # Classifier type; compared against the ClassifierType query filter
    # in /models/.
    classifier: str
    notebook_metrics_df_file_name: str
    notebook_scores_df_file_path: str
    notebook_metrics_filters: List[str]
    notebook_scores_filters: List[str]
    sort_by: str
    split_factor: float
    selection_ratio: float
    # True when the model was trained with a PT score feature (then
    # /notebook/predict/ accepts pt_score).
    include_pt: bool
    # Training/evaluation metrics of the model — schema not visible here.
    metrics: dict
@app.get(
    "/models/",
    summary="Get a list of available models.",
    response_model=List[ModelInfo],
)
async def get_models(
    classifier: Annotated[
        Union[ClassifierType, None],
        Query(
            max_length=50,
            description=f"Filter by type of model class can be one of {[model.value for model in ClassifierType]}.",
            example=ClassifierType.cat_boost,
        ),
    ] = None,
) -> List[ModelInfo]:
    """Return information about the currently active models.

    Args:
        classifier: Optional classifier type; when given, only models whose
            ``classifier`` field matches it are returned.

    Returns:
        A list containing information about active models.
    """
    active_models = model_store.active_models
    if classifier is not None:
        # The details dicts are only read below, so referencing them directly
        # is enough — the previous per-entry dict-comprehension copy was a
        # needless shallow clone of every inner dict.
        active_models = {
            model_id: details
            for model_id, details in active_models.items()
            if details["classifier"] == classifier
        }
    return [ModelInfo(id=id_, **model_detail) for id_, model_detail in active_models.items()]
class UploadNotebookResponse(BaseModel):
    """Response body of /notebook/upload/."""

    # Server-side file name (uuid4 hex prefix + original name); pass it as
    # notebook_filename in subsequent /notebook/metrics/ and
    # /notebook/predict/ requests.
    result: str = Field(examples=["c744d10c6a2d4ec49cd30f69b8301da3-14091946.ipynb"])
@app.post(
    "/notebook/upload/",
    summary="Upload a Jupyter notebook file.",
    response_model=UploadNotebookResponse,
)
def upload_notebook(
    file: Annotated[
        UploadFile,
        File(
            description="Selected jupyter notebook.",
            examples=["filename.ipynb"],
        ),
    ]
) -> UploadNotebookResponse:
    """Upload a Jupyter notebook file to the server.

    Args:
        file: The Jupyter notebook file to be uploaded.

    Returns:
        A response containing the generated filename on the server.
    """
    # SECURITY: the client controls ``file.filename`` — taking only
    # ``Path(...).name`` strips any directory components (e.g. "../../x")
    # so the write cannot escape NOTEBOOKS_FOLDER_PATH. ``filename`` may
    # also be None per the multipart spec, hence the fallback.
    safe_name = Path(file.filename or "notebook.ipynb").name
    filename = f"{uuid4().hex}-{safe_name}"
    notebooks_folder_path = Path(config.NOTEBOOKS_FOLDER_PATH)
    filepath = notebooks_folder_path / filename
    try:
        with open(filepath, "wb") as buffer:
            buffer.write(file.file.read())
    finally:
        # Always release the upload's spooled temp file, even on write errors
        # (previously it leaked if the write raised).
        file.file.close()
    return UploadNotebookResponse(result=filename)
class MetricsExtractionInfo(BaseModel):
    """Request body for /notebook/metrics/: which notebook to analyse and how."""

    # Name of a previously uploaded notebook, as returned by /notebook/upload/.
    notebook_filename: str = Field(examples=["uuid-filename.ipynb"])
    # Base code dataframe file name inside DATAFRAMES_FOLDER_PATH; the
    # configured default is used when the field is omitted.
    base_code_df_filename: Optional[str] = Field(
        default=Path(config.CODE_DF_FILE_PATH).name, examples=[Path(config.CODE_DF_FILE_PATH).name]
    )
    # Chunk size forwarded to the metrics extractor (default from config).
    chunk_size: Optional[int] = Field(default=config.CHUNK_SIZE, examples=[config.CHUNK_SIZE])

    @field_validator("notebook_filename", mode="before")
    def validate_filename(cls, v: str) -> str:
        """Reject file names that could escape the notebooks folder."""
        # Block path separators and parent-directory references outright.
        if "/" in v or "\\" in v or ".." in v:
            raise ValueError("Filename cannot contain path components or '..'. Rename your file and try again.")
        if not v.endswith(".ipynb"):
            raise ValueError("Filename must end with the '.ipynb' extension")
        # Defence in depth: joining the name onto the notebooks folder must
        # not change the resulting parent directory.
        expected_parent = Path(config.NOTEBOOKS_FOLDER_PATH)
        file_path = expected_parent / v
        if file_path.parent != expected_parent:
            raise ValueError("Filename must not change the parent folder path.")
        return v
class MetricsExtractionResponse(BaseModel):
    """Response body of /notebook/metrics/: one numeric value per metric.

    Keys are metric acronyms produced by the extractor (e.g. LOC, CyC);
    their exact semantics are defined in core.extract_metrics, not here.
    """

    # Mapping of metric name -> extracted value for the single notebook row.
    metrics: Dict[str, Union[int, float]] = Field(
        examples=[
            {
                "LOC": 78,
                "BLC": 0,
                "UDF": 3,
                "I": 6,
                "EH": 0,
                "NVD": 3,
                "NEC": 1,
                "S": 17,
                "P": 82,
                "OPRND": 193,
                "OPRATOR": 93,
                "UOPRND": 185,
                "UOPRATOR": 29,
                "ID": 428,
                "LOCom": 16,
                "EAP": 2.3896513096877,
                "CW": 114,
                "ALID": 6.9,
                "MeanLCC": 7.8,
                "ALLC": 38.947288006111535,
                "KLCID": 4.522916666666666,
                "CyC": 1.2,
                "MLID": 25,
                "NBD": 10.363636363636363,
                "AID": 118,
                "CC": 10,
                "MC": 11,
                "MeanWMC": 19.90909090909091,
                "MeanLMC": 1.3636363636363635,
                "H1": 0,
                "H2": 3,
                "H3": 1,
                "MW": 219,
                "LMC": 15,
            }
        ]
    )
@app.post(
    "/notebook/metrics/",
    summary="Extract metrics from a Jupyter notebook file.",
    response_model=MetricsExtractionResponse,
)
def extract_metrics(info: MetricsExtractionInfo) -> MetricsExtractionResponse:
    """Extract code metrics from a previously uploaded Jupyter notebook.

    Args:
        info: Notebook file name and extraction parameters.

    Returns:
        A response containing the extracted metrics.

    Raises:
        HTTPException: 404 when the notebook or code-df file is missing, or
            the notebook yields no metrics; 400 when extraction itself fails.
    """
    file_path = Path(config.NOTEBOOKS_FOLDER_PATH) / info.notebook_filename
    if not file_path.is_file():
        raise HTTPException(
            status_code=404,
            detail="Specified notebook does not exist. Please upload the file again (/notebook/upload/ API) or make sure you have the correct file name.",
        )
    base_code_df_file_path = Path(config.DATAFRAMES_FOLDER_PATH) / info.base_code_df_filename
    if not base_code_df_file_path.is_file():
        raise HTTPException(
            status_code=404,
            detail="Specified code df does not exist. You can exclude this field from your request body to use the default code df.",
        )
    try:
        extracted_notebook_metrics_df = extract_notebook_metrics_from_ipynb_file(
            file_path=str(file_path.resolve()),
            base_code_df_file_path=str(base_code_df_file_path.resolve()),
            chunk_size=info.chunk_size,
        )
    except Exception as exc:
        logger.error(f"Exception on extracting notebook metrics: {exc}")
        # ``from exc`` keeps the original traceback chained for debugging.
        raise HTTPException(
            status_code=400,
            detail="Could not extract the metrics from the provided notebook file. Please upload a different notebook file and try again or contact the admin (info on top of the page).",
        ) from exc
    if len(extracted_notebook_metrics_df) != 1:
        raise HTTPException(
            status_code=404,
            detail="Notebook file is empty (No metrics were extracted). Please upload a different notebook file and try again or contact the admin (info on top of the page).",
        )
    # kernel_id is internal bookkeeping from the extractor, not a metric.
    extracted_notebook_metrics_df.drop(["kernel_id"], axis=1, inplace=True)
    # .iloc[0].to_dict() reads the single row regardless of its index label;
    # the previous .iloc[[0]].to_dict(orient="index")[0] assumed the label
    # was literally 0.
    return MetricsExtractionResponse(metrics=extracted_notebook_metrics_df.iloc[0].to_dict())
class PredictionInfo(MetricsExtractionInfo):
    """Request body for /notebook/predict/: extraction info plus model choice."""

    # The example uses next(iter(...), ...) so an empty model store no longer
    # crashes module import (list(model_store.active_models.keys())[0] raised
    # IndexError when no models were active).
    model_id: str = Field(
        examples=[next(iter(model_store.active_models), "model-id")],
        description="Use ids in the response of models/ api.",
    )
    # Only valid for models whose include_pt field is True.
    pt_score: Optional[int] = Field(
        default=None, examples=[10], description="Only set when model has value True for its include_pt field."
    )
class PredictionResponse(MetricsExtractionResponse):
    """Response body of /notebook/predict/: metrics plus the model's verdict."""

    # Binary classifier output: 0 or 1 (see /notebook/predict/).
    prediction: int = Field(examples=[0, 1])
@app.post(
    "/notebook/predict/",
    summary="Make predictions using a model on a notebook's metrics.",
    response_model=PredictionResponse,
)
def predict(info: PredictionInfo) -> PredictionResponse:
    """Predict code comprehensiveness from a notebook's extracted metrics.

    Args:
        info: Notebook file name, extraction parameters, model id, and the
            optional PT score.

    Returns:
        A response containing the metrics and the 0/1 prediction result.

    Raises:
        HTTPException: 404 for a missing notebook, code-df, or model, or an
            empty notebook; 400 for an extraction failure or a pt_score sent
            to a model that does not accept one.
    """
    file_path = Path(config.NOTEBOOKS_FOLDER_PATH) / info.notebook_filename
    if not file_path.is_file():
        raise HTTPException(
            status_code=404,
            detail="Specified notebook does not exist. Please upload the file again (/notebook/upload/ API) or make sure you have the correct file name.",
        )
    base_code_df_file_path = Path(config.DATAFRAMES_FOLDER_PATH) / info.base_code_df_filename
    if not base_code_df_file_path.is_file():
        raise HTTPException(
            status_code=404,
            detail="Specified code df does not exist. You can exclude this field from your request body to use the default code df.",
        )
    try:
        extracted_notebook_metrics_df = extract_notebook_metrics_from_ipynb_file(
            file_path=str(file_path.resolve()),
            base_code_df_file_path=str(base_code_df_file_path.resolve()),
            chunk_size=info.chunk_size,
        )
    except Exception as exc:
        logger.error(f"Exception on extracting notebook metrics: {exc}")
        # ``from exc`` keeps the original traceback chained for debugging.
        raise HTTPException(
            status_code=400,
            detail="Could not extract the metrics from the provided notebook file. Please upload a different notebook file and try again or contact the admin (info on top of the page).",
        ) from exc
    if len(extracted_notebook_metrics_df) != 1:
        raise HTTPException(
            status_code=404,
            detail="Notebook file is empty (No metrics were extracted). Please upload a different notebook file and try again or contact the admin (info on top of the page).",
        )
    # kernel_id is internal bookkeeping from the extractor, not a feature.
    extracted_notebook_metrics_df.drop(["kernel_id"], axis=1, inplace=True)
    classifier = model_store.get_model(info.model_id)
    if classifier is None:
        raise HTTPException(
            status_code=404,
            detail="Specified model id does not exist. Checkout the /models/ endpoint documentation to get the right id for your need.",
        )
    # TODO: standardize column names
    extracted_notebook_metrics_df.rename(
        columns={
            "ALLC": "ALLCL",
        },
        inplace=True,
    )
    if info.pt_score is not None:
        if not model_store.get_model_info(info.model_id).get("include_pt"):
            # BUGFIX: this previously raised a bare ``Exception`` with keyword
            # arguments — a TypeError at runtime (Exception takes no kwargs),
            # surfacing as a 500. HTTPException returns the intended 400.
            raise HTTPException(
                status_code=400,
                detail="Specified model does not take PT score. Please remove the pt_score field from your request body or use a different model.",
            )
        extracted_notebook_metrics_df["PT"] = info.pt_score
    result = classifier.predict(x=extracted_notebook_metrics_df)
    return PredictionResponse(
        # Index-label-agnostic single-row read (see /notebook/metrics/).
        metrics=extracted_notebook_metrics_df.iloc[0].to_dict(),
        prediction=int(result[0]),
    )
@app.get("/", include_in_schema=False)
async def docs_redirect():
    """Send bare root requests to the interactive Swagger UI."""
    docs_url = "/docs"
    return responses.RedirectResponse(url=docs_url)
if __name__ == "__main__":
    import uvicorn

    # Development entry point: binds all interfaces with auto-reload enabled.
    # NOTE(review): reload=True and 0.0.0.0 are dev settings — use a proper
    # ASGI server configuration for production.
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)