1
1
#!/usr/bin/env python3
2
2
import warnings
3
3
from collections import namedtuple
4
- from typing import Any , Callable , Dict , List , NamedTuple , Optional , Tuple , Union
4
+ from typing import (
5
+ Any ,
6
+ Callable ,
7
+ Dict ,
8
+ Generic ,
9
+ List ,
10
+ NamedTuple ,
11
+ Optional ,
12
+ Tuple ,
13
+ TypeVar ,
14
+ Union ,
15
+ cast ,
16
+ )
5
17
6
18
from torch import Tensor
7
19
15
27
16
28
ORIGINAL_KEY = "Original"
17
29
30
+ MetricResultType = TypeVar (
31
+ "MetricResultType" , float , Tensor , Tuple [Union [float , Tensor ], ...]
32
+ )
33
+
18
34
19
35
class AttackInfo (NamedTuple ):
20
36
attack_fn : Union [Perturbation , Callable ]
@@ -33,7 +49,7 @@ def agg_metric(inp):
33
49
return inp
34
50
35
51
36
- class AttackComparator :
52
+ class AttackComparator ( Generic [ MetricResultType ]) :
37
53
r"""
38
54
Allows measuring model robustness for a given attack or set of attacks. This class
39
55
can be used with any metric(s) as well as any set of attacks, either based on
@@ -44,7 +60,7 @@ class AttackComparator:
44
60
def __init__ (
45
61
self ,
46
62
forward_func : Callable ,
47
- metric : Callable [..., Union [ float , Tensor , Tuple [ Union [ float , Tensor ], ...]] ],
63
+ metric : Callable [..., MetricResultType ],
48
64
preproc_fn : Callable = None ,
49
65
) -> None :
50
66
r"""
@@ -74,10 +90,10 @@ def model_metric(model_out: Tensor, **kwargs: Any)
74
90
additional_forward_args provided to evaluate.
75
91
"""
76
92
self .forward_func = forward_func
77
- self .metric = metric
93
+ self .metric : Callable = metric
78
94
self .preproc_fn = preproc_fn
79
- self .attacks = {}
80
- self .summary_results = {}
95
+ self .attacks : Dict [ str , AttackInfo ] = {}
96
+ self .summary_results : Dict [ str , Summarizer ] = {}
81
97
self .metric_aggregator = agg_metric
82
98
self .batch_stats = [Mean , Min , Max ]
83
99
self .aggregate_stats = [Mean ]
@@ -148,7 +164,7 @@ def add_attack(
148
164
149
165
def _format_summary (
150
166
self , summary : Union [Dict , List [Dict ]]
151
- ) -> Dict [str , Union [ float , Tuple [ float , ...]] ]:
167
+ ) -> Dict [str , MetricResultType ]:
152
168
r"""
153
169
This method reformats a given summary; particularly for tuples,
154
170
the Summarizer's summary format is a list of dictionaries,
@@ -159,12 +175,12 @@ def _format_summary(
159
175
if isinstance (summary , dict ):
160
176
return summary
161
177
else :
162
- summary_dict = {}
178
+ summary_dict : Dict [ str , Tuple ] = {}
163
179
for key in summary [0 ]:
164
180
summary_dict [key ] = tuple (s [key ] for s in summary )
165
181
if self .out_format :
166
182
summary_dict [key ] = self .out_format (* summary_dict [key ])
167
- return summary_dict
183
+ return summary_dict # type: ignore
168
184
169
185
def _update_out_format (
170
186
self , out_metric : Union [float , Tensor , Tuple [Union [float , Tensor ], ...]]
@@ -174,7 +190,9 @@ def _update_out_format(
174
190
and isinstance (out_metric , tuple )
175
191
and hasattr (out_metric , "_fields" )
176
192
):
177
- self .out_format = namedtuple (type (out_metric ).__name__ , out_metric ._fields )
193
+ self .out_format = namedtuple ( # type: ignore
194
+ type (out_metric ).__name__ , cast (NamedTuple , out_metric )._fields
195
+ )
178
196
179
197
def _evaluate_batch (
180
198
self ,
@@ -212,13 +230,10 @@ def _evaluate_batch(
212
230
def evaluate (
213
231
self ,
214
232
inputs : Any ,
215
- additional_forward_args : Optional [ Tuple ] = None ,
233
+ additional_forward_args : Any = None ,
216
234
perturbations_per_eval : int = 1 ,
217
235
** kwargs ,
218
- ) -> Dict [
219
- str ,
220
- Union [Tensor , Tuple [Tensor , ...], Dict [str , Union [Tensor , Tuple [Tensor , ...]]]],
221
- ]:
236
+ ) -> Dict [str , Union [MetricResultType , Dict [str , MetricResultType ]]]:
222
237
r"""
223
238
Evaluate model and attack performance on provided inputs
224
239
@@ -385,45 +400,44 @@ def _check_and_evaluate(input_list, key_list):
385
400
386
401
def _parse_and_update_results (
387
402
self , batch_summarizers : Dict [str , Summarizer ]
388
- ) -> Dict [
389
- str , Union [float , Tuple [float , ...], Dict [str , Union [float , Tuple [float , ...]]]]
390
- ]:
391
- results = {
392
- ORIGINAL_KEY : self ._format_summary (batch_summarizers [ORIGINAL_KEY ].summary )[
393
- "mean"
394
- ]
403
+ ) -> Dict [str , Union [MetricResultType , Dict [str , MetricResultType ]]]:
404
+ results : Dict [str , Union [MetricResultType , Dict [str , MetricResultType ]]] = {
405
+ ORIGINAL_KEY : self ._format_summary (
406
+ cast (Union [Dict , List ], batch_summarizers [ORIGINAL_KEY ].summary )
407
+ )["mean" ]
395
408
}
396
409
self .summary_results [ORIGINAL_KEY ].update (
397
410
self .metric_aggregator (results [ORIGINAL_KEY ])
398
411
)
399
412
for attack_key in self .attacks :
400
413
attack = self .attacks [attack_key ]
401
- results [ attack . name ] = self ._format_summary (
402
- batch_summarizers [attack .name ].summary
414
+ attack_results = self ._format_summary (
415
+ cast ( Union [ Dict , List ], batch_summarizers [attack .name ].summary )
403
416
)
417
+ results [attack .name ] = attack_results
404
418
405
- if len (results [ attack . name ] ) == 1 :
406
- key = next (iter (results [ attack . name ] ))
419
+ if len (attack_results ) == 1 :
420
+ key = next (iter (attack_results ))
407
421
if attack .name not in self .summary_results :
408
422
self .summary_results [attack .name ] = Summarizer (
409
423
[stat () for stat in self .aggregate_stats ]
410
424
)
411
425
self .summary_results [attack .name ].update (
412
- self .metric_aggregator (results [ attack . name ] [key ])
426
+ self .metric_aggregator (attack_results [key ])
413
427
)
414
428
else :
415
- for key in results [ attack . name ] :
429
+ for key in attack_results :
416
430
summary_key = f"{ attack .name } { key .title ()} Attempt"
417
431
if summary_key not in self .summary_results :
418
432
self .summary_results [summary_key ] = Summarizer (
419
433
[stat () for stat in self .aggregate_stats ]
420
434
)
421
435
self .summary_results [summary_key ].update (
422
- self .metric_aggregator (results [ attack . name ] [key ])
436
+ self .metric_aggregator (attack_results [key ])
423
437
)
424
438
return results
425
439
426
- def summary (self ) -> Dict [str , Dict [str , Union [ Tensor , Tuple [ Tensor , ...]] ]]:
440
+ def summary (self ) -> Dict [str , Dict [str , MetricResultType ]]:
427
441
r"""
428
442
Returns average results over all previous batches evaluated.
429
443
@@ -440,7 +454,9 @@ def summary(self) -> Dict[str, Dict[str, Union[Tensor, Tuple[Tensor, ...]]]]:
440
454
per batch.
441
455
"""
442
456
return {
443
- key : self ._format_summary (self .summary_results [key ].summary )
457
+ key : self ._format_summary (
458
+ cast (Union [Dict , List ], self .summary_results [key ].summary )
459
+ )
444
460
for key in self .summary_results
445
461
}
446
462
0 commit comments