serve.py
import torch
import logging
import configparser
from fastapi import FastAPI
from pydantic import BaseModel
from rich.logging import RichHandler
from fastapi.responses import JSONResponse
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# Read the serving configuration from config.ini
config = configparser.ConfigParser()
config.read('config.ini')
model_path = config.get('serve', 'model_path')

# Log to the console via Rich
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RichHandler()
logger.addHandler(handler)

# Load the fine-tuned DistilBERT model and tokenizer onto the GPU if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f'Using Device: {device}')
logger.info(f'Model Path: {model_path}')
model = DistilBertForSequenceClassification.from_pretrained(model_path).to(device)
tokenizer = DistilBertTokenizer.from_pretrained(model_path)

app = FastAPI()


class ClaimInfo(BaseModel):
    """Request body for the prediction endpoint."""
    claim_text: str


def get_veracity(text):
    """
    Analyzes the veracity of the provided text using the pretrained model.

    Args:
        text (str): The text to be analyzed for veracity.

    Returns:
        int: The predicted label index, representing the model's classification
            of the input text.
    """
    logger.info('Tokenizing text..')
    inputs = tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device)
    with torch.no_grad():
        logger.info('Running inference..')
        outputs = model(**inputs)
    logger.info('Applying softmax..')
    probs = outputs[0].softmax(1)
    logger.info('Taking argmax..')
    pred_label_idx = probs.argmax()
    logger.info('Returning value..')
    return pred_label_idx.item()


@app.post("/claim/v1/predict")
async def predict_veracity(claim_info: ClaimInfo):
    try:
        claim_text = claim_info.claim_text
        logger.info('Checking veracity..')
        veracity = get_veracity(claim_text)
        return JSONResponse(
            status_code=200,
            content={
                'message': 'success',
                'veracity': veracity
            }
        )
    except Exception:
        logger.error('Error occurred while processing', exc_info=True)
        return JSONResponse(
            status_code=500,
            content={
                'message': 'Error occurred while detecting veracity',
                'veracity': -1
            }
        )
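

A minimal sketch of how the prediction endpoint can be exercised, assuming the app is started locally with uvicorn (for example, uvicorn serve:app --port 8000), that config.ini points at a valid fine-tuned checkpoint, and that the requests package is installed (it is not a dependency of serve.py itself). The claim text, host, and port below are placeholders for illustration only.

# Hypothetical client sketch: assumes the server above is running at localhost:8000
import requests

response = requests.post(
    "http://localhost:8000/claim/v1/predict",
    json={"claim_text": "The Eiffel Tower is located in Berlin."},
)
print(response.status_code)  # 200 on success, 500 if inference failed
print(response.json())       # e.g. {'message': 'success', 'veracity': <predicted label index>}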