import numpy as np
import wandb
+ from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu, sentence_bleu

+ # TODO: Find a good way to measure the accuracy of generated code.
def compute_metrics_text(tokenizer):
    def compute_metrics(eval_pred):
        predictions, labels = eval_pred
-         # Too many empty outputs from T5
-         no_empty_output = sum([1 for p in predictions if len(p) < 5])
+
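+         # Swap the -100 ignore index for the pad token id before decoding.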
+         predictions = np.where(predictions[0] != -100, predictions[0], tokenizer.pad_token_id)
+         decoded_preds = tokenizer.batch_decode(
+             predictions,
+             skip_special_tokens=True
+         )
+         labels = np.where(labels[0] != -100, labels[0], tokenizer.pad_token_id)
+         decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+
+         table_data = []
+         ref_bleu = []
+         gen_bleu = []
+
+         for l, r in zip(decoded_preds, decoded_labels):
+             gen_bleu.append(l.split())
+             ref_bleu.append([r.split()])
+             table_data.append([l, r])
+
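+         # Corpus-level BLEU with bigram-only weights (0, 1, 0, 0) and method4 smoothing.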
+         cc = SmoothingFunction()
+         score_bleu = corpus_bleu(ref_bleu, gen_bleu, weights=(0, 1, 0, 0), smoothing_function=cc.method4)
+         table = wandb.Table(data=table_data, columns=["Predictions", "Labels"])
+
        wandb.log({
-             "num_empty_t5_output": no_empty_output,
+             "table": table,
        })
-         total_acc = 0.0
-         data = []
-         n = len(predictions)
-         for p, l in zip(predictions, labels):
-             p = np.where(
-                 p != -100,
-                 p,
-                 tokenizer.pad_token_id
-             )
-             decoded_preds = tokenizer.batch_decode(
-                 p,
-                 max_length=512,
-                 skip_special_tokens=True
-             )
-             l = np.where(l != -100, l, tokenizer.pad_token_id)
-             decoded_labels = tokenizer.batch_decode(
-                 l,
-                 skip_special_tokens=True
-             )
-             acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))
-             total_acc += acc
-             data.append([decoded_preds, decoded_labels, acc])
+
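+         # Exact-match accuracy over the decoded strings.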
+         acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))

-         columns = ["T5 output", "Santa coder output", "Accuracy"]
-         example_table = wandb.Table(data=data, columns=columns)
-         wandb.log({"Example": example_table})
-         # acc = np.mean(np.array(decoded_preds) == np.array(decoded_labels))

-         return {'accuracy': total_acc / n}
+         return {'accuracy': acc, "bleu": score_bleu}

    return compute_metrics
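
# Usage sketch, assuming the Hugging Face Trainer API: the (predictions, labels)
# unpacking matches Trainer's compute_metrics hook, and the predictions[0] /
# labels[0] indexing assumes the trainer hands over tuples of token-id arrays.
# Wiring it up would look roughly like the lines below; the trainer, model, and
# dataset names are illustrative assumptions, not part of this change.
#
#     trainer = Seq2SeqTrainer(
#         model=model,
#         args=training_args,            # Seq2SeqTrainingArguments(predict_with_generate=True, ...)
#         train_dataset=train_dataset,
#         eval_dataset=eval_dataset,
#         tokenizer=tokenizer,
#         compute_metrics=compute_metrics_text(tokenizer),
#     )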