
Commit 11d9e63

add bleu for compute_metrics
1 parent 009ca76 commit 11d9e63

1 file changed: +28 −30 lines changed

metrics.py (+28 −30)
@@ -1,42 +1,40 @@
 import numpy as np
 import wandb
+from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu

+# TODO: Need to find a good way to measure accuracy of codes.
 def compute_metrics_text(tokenizer):
     def compute_metrics(eval_pred):
         predictions, labels = eval_pred
-        # Too many empty outputs from T5
-        no_empty_output = sum([1 for p in predictions if len(p) < 5 ])
+
+        predictions = np.where(predictions[0] != -100, predictions[0], tokenizer.pad_token_id)
+        decoded_preds = tokenizer.batch_decode(
+            predictions,
+            skip_special_tokens=True
+        )
+        labels = np.where(labels[0] != -100, labels[0], tokenizer.pad_token_id)
+        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+        table_data = []
+        ref_bleu = []
+        gen_bleu = []
+
+        for l, r in zip(decoded_preds, decoded_labels):
+            gen_bleu.append(l.split())
+            ref_bleu.append([r.split()])
+            table_data.append([l, r])
+
+        cc = SmoothingFunction()
+        score_bleu = corpus_bleu(ref_bleu, gen_bleu, weights=(0, 1, 0, 0), smoothing_function=cc.method4)
+        table = wandb.Table(data=table_data, columns=["Predictions", "Labels"])
+
         wandb.log({
-            "num_empty_t5_output": no_empty_output,
+            "table": table,
         })
-        total_acc = 0.0
-        data = []
-        n = len(predictions)
-        for p, l in zip(predictions, labels):
-            p = np.where(
-                p != -100,
-                p,
-                tokenizer.pad_token_id
-            )
-            decoded_preds = tokenizer.batch_decode(
-                p,
-                max_length=512,
-                skip_special_tokens=True
-            )
-            l = np.where(l != -100, l, tokenizer.pad_token_id)
-            decoded_labels = tokenizer.batch_decode(
-                l,
-                skip_special_tokens=True
-            )
-            acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))
-            total_acc += acc
-            data.append([decoded_preds, decoded_labels, acc])
+
+        acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))

-        columns=["T5 output", "Santa coder output", "Accuracy"]
-        example_table = wandb.Table(data=data, columns=columns)
-        wandb.log({"Example": example_table})
-        # acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))

-        return {'accuracy': total_acc / n}
+        return {'accuracy': acc, "bleu": score_bleu}

     return compute_metrics
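For readers skimming the diff: the new code builds gen_bleu as whitespace-tokenized hypotheses, ref_bleu as one reference list per hypothesis, and weights=(0, 1, 0, 0) scores bigram precision only, smoothed with NLTK's method4. A minimal standalone sketch of that scoring plus the exact-match accuracy, using made-up toy strings (nothing below comes from this repo):

import numpy as np
from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu

# Hypothetical decoded model outputs and references (toy data).
decoded_preds = ["return a + b", "return a - b"]
decoded_labels = ["return a + b", "return b - a"]

# Hypotheses are token lists; each hypothesis gets a list of reference token lists.
gen_bleu = [p.split() for p in decoded_preds]
ref_bleu = [[r.split()] for r in decoded_labels]

# Bigram-only BLEU with method4 smoothing, mirroring the diff's call.
cc = SmoothingFunction()
score_bleu = corpus_bleu(ref_bleu, gen_bleu, weights=(0, 1, 0, 0), smoothing_function=cc.method4)

# Exact-string-match accuracy, mirroring the diff's acc line.
acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))
print(score_bleu, acc)  # acc is 0.5 on this toy pair; BLEU falls in [0, 1]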

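The commit touches only metrics.py, so the consumer of compute_metrics_text is not shown. One plausible wiring, assuming a Hugging Face Seq2SeqTrainer with predict_with_generate (model, train_ds, eval_ds, and "out" are hypothetical placeholders, not names from this repo):

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="out",            # hypothetical output path
    predict_with_generate=True,  # eval_pred then carries generated token ids
    report_to=["wandb"],         # compute_metrics logs a wandb.Table at eval time
)
trainer = Seq2SeqTrainer(
    model=model,                 # placeholder: a seq2seq model defined elsewhere
    args=args,
    train_dataset=train_ds,      # placeholder datasets
    eval_dataset=eval_ds,
    tokenizer=tokenizer,         # same tokenizer passed to compute_metrics_text
    compute_metrics=compute_metrics_text(tokenizer),
)
metrics = trainer.evaluate()     # expected keys: 'eval_accuracy', 'eval_bleu'

Note that the committed code indexes predictions[0] and labels[0], so it assumes eval_pred arrives as tuples whose first element holds the token-id arrays.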