-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathemotion_detection.py
35 lines (28 loc) · 1.24 KB
/
emotion_detection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from transformers import pipeline, set_seed
import time
import torch
# NOTE: gated/private models need an access token passed to `pipeline(...)`.
# Older transformers releases call that kwarg `use_auth_token=...`; newer ones
# name it `token=...` — check the installed version before adding it.

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Emotion classifier shared by the functions below.
# - return_all_scores=True: ask for a score for every emotion label, not only
#   the top one; detect_emotion picks the maximum itself.
#   NOTE(review): newer transformers releases deprecate this flag in favour of
#   `top_k=None`, which may change the output nesting — verify before switching.
# - batch_size=1: inputs are fed to the model one at a time.
classifier = pipeline(
    "text-classification",
    model='j-hartmann/emotion-english-distilroberta-base',
    return_all_scores=True,
    device=device,
    batch_size=1,
)

# Fix the random seed for reproducibility.
set_seed(42)
def predict_emotion(input_text):
    """Run the module-level ``classifier`` pipeline on *input_text*.

    The pipeline output is handed back untouched. Presumably it is a list that
    holds one list of ``{"label": ..., "score": ...}`` dicts per input, since
    ``detect_emotion`` indexes it that way. Verify if the pipeline
    configuration changes.
    """
    return classifier(input_text)
def _top_emotion(prediction):
    """Return ``(label, score)`` for the highest-scoring entry in *prediction*.

    Two shapes are accepted:

    - the nested form ``[[{"label": ..., "score": ...}, ...]]``, where only the
      first inner list is used (as before);
    - a flat ``[{"label": ..., "score": ...}, ...]`` list.

    This lets the caller keep working if the pipeline's output nesting changes.
    An empty *prediction* raises ``IndexError``. An empty list of scores raises
    ``ValueError`` (from ``max``).
    """
    first = prediction[0]
    # A dict here means we were handed the flat list of per-label scores.
    scores = prediction if isinstance(first, dict) else first
    # max() keeps the first entry when scores tie.
    best = max(scores, key=lambda entry: entry["score"])
    return best["label"], best["score"]


def detect_emotion(input_text):
    """Classify *input_text* and return its most likely emotion.

    As a side effect, prints the input and the winning label/score to stdout.

    Returns:
        A ``(label, score)`` tuple for the highest-scoring emotion.
    """
    print(f"Input: {input_text}")
    prediction = predict_emotion(input_text)
    max_score_label, max_score = _top_emotion(prediction)
    print(f"Label with maximum score: {max_score_label} (Score: {max_score:.4f})")
    return max_score_label, max_score