Skip to content

Commit 268035f

Browse files
vivekmigfacebook-github-bot
authored andcommitted
Type Fixes in Robust Metrics (#707)
Summary: Fixes Mypy type checking issues with attack metrics to resolve CircleCI issues Pull Request resolved: #707 Reviewed By: NarineK Differential Revision: D29552315 Pulled By: vivekmig fbshipit-source-id: ba44d7e4121df30d26ac9e0bc796614ac726a9ed
1 parent 67a3ddc commit 268035f

File tree

5 files changed

+72
-41
lines changed

5 files changed

+72
-41
lines changed

captum/_utils/common.py

+13
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,19 @@ def _format_input(inputs: Union[Tensor, Tuple[Tensor, ...]]) -> Tuple[Tensor, ..
159159
return _format_tensor_into_tuples(inputs)
160160

161161

162+
def _format_float_or_tensor_into_tuples(
163+
inputs: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
164+
) -> Tuple[Union[float, Tensor], ...]:
165+
if not isinstance(inputs, tuple):
166+
assert isinstance(
167+
inputs, (torch.Tensor, float)
168+
), "`inputs` must have type float or torch.Tensor but {} found: ".format(
169+
type(inputs)
170+
)
171+
inputs = (inputs,)
172+
return inputs
173+
174+
162175
@overload
163176
def _format_additional_forward_args(additional_forward_args: None) -> None:
164177
...

captum/attr/_utils/summarizer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _copy_stats(self):
4343

4444
return copy.deepcopy(self._stats)
4545

46-
def update(self, x: Union[Tensor, Tuple[Tensor, ...]]):
46+
def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]):
4747
r"""
4848
Calls `update` on each `Stat` object within the summarizer
4949
@@ -57,9 +57,9 @@ def update(self, x: Union[Tensor, Tuple[Tensor, ...]]):
5757
# we want input to be consistently a single input or a tuple
5858
assert not (self._is_inputs_tuple ^ isinstance(x, tuple))
5959

60-
from captum._utils.common import _format_tensor_into_tuples
60+
from captum._utils.common import _format_float_or_tensor_into_tuples
6161

62-
x = _format_tensor_into_tuples(x)
62+
x = _format_float_or_tensor_into_tuples(x)
6363

6464
for i, inp in enumerate(x):
6565
if i >= len(self._summarizers):

captum/robust/_core/metrics/attack_comparator.py

+47-31
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,19 @@
11
#!/usr/bin/env python3
22
import warnings
33
from collections import namedtuple
4-
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
4+
from typing import (
5+
Any,
6+
Callable,
7+
Dict,
8+
Generic,
9+
List,
10+
NamedTuple,
11+
Optional,
12+
Tuple,
13+
TypeVar,
14+
Union,
15+
cast,
16+
)
517

618
from torch import Tensor
719

@@ -15,6 +27,10 @@
1527

1628
ORIGINAL_KEY = "Original"
1729

30+
MetricResultType = TypeVar(
31+
"MetricResultType", float, Tensor, Tuple[Union[float, Tensor], ...]
32+
)
33+
1834

1935
class AttackInfo(NamedTuple):
2036
attack_fn: Union[Perturbation, Callable]
@@ -33,7 +49,7 @@ def agg_metric(inp):
3349
return inp
3450

3551

36-
class AttackComparator:
52+
class AttackComparator(Generic[MetricResultType]):
3753
r"""
3854
Allows measuring model robustness for a given attack or set of attacks. This class
3955
can be used with any metric(s) as well as any set of attacks, either based on
@@ -44,7 +60,7 @@ class AttackComparator:
4460
def __init__(
4561
self,
4662
forward_func: Callable,
47-
metric: Callable[..., Union[float, Tensor, Tuple[Union[float, Tensor], ...]]],
63+
metric: Callable[..., MetricResultType],
4864
preproc_fn: Callable = None,
4965
) -> None:
5066
r"""
@@ -74,10 +90,10 @@ def model_metric(model_out: Tensor, **kwargs: Any)
7490
additional_forward_args provided to evaluate.
7591
"""
7692
self.forward_func = forward_func
77-
self.metric = metric
93+
self.metric: Callable = metric
7894
self.preproc_fn = preproc_fn
79-
self.attacks = {}
80-
self.summary_results = {}
95+
self.attacks: Dict[str, AttackInfo] = {}
96+
self.summary_results: Dict[str, Summarizer] = {}
8197
self.metric_aggregator = agg_metric
8298
self.batch_stats = [Mean, Min, Max]
8399
self.aggregate_stats = [Mean]
@@ -148,7 +164,7 @@ def add_attack(
148164

149165
def _format_summary(
150166
self, summary: Union[Dict, List[Dict]]
151-
) -> Dict[str, Union[float, Tuple[float, ...]]]:
167+
) -> Dict[str, MetricResultType]:
152168
r"""
153169
This method reformats a given summary; particularly for tuples,
154170
the Summarizer's summary format is a list of dictionaries,
@@ -159,12 +175,12 @@ def _format_summary(
159175
if isinstance(summary, dict):
160176
return summary
161177
else:
162-
summary_dict = {}
178+
summary_dict: Dict[str, Tuple] = {}
163179
for key in summary[0]:
164180
summary_dict[key] = tuple(s[key] for s in summary)
165181
if self.out_format:
166182
summary_dict[key] = self.out_format(*summary_dict[key])
167-
return summary_dict
183+
return summary_dict # type: ignore
168184

169185
def _update_out_format(
170186
self, out_metric: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]
@@ -174,7 +190,9 @@ def _update_out_format(
174190
and isinstance(out_metric, tuple)
175191
and hasattr(out_metric, "_fields")
176192
):
177-
self.out_format = namedtuple(type(out_metric).__name__, out_metric._fields)
193+
self.out_format = namedtuple( # type: ignore
194+
type(out_metric).__name__, cast(NamedTuple, out_metric)._fields
195+
)
178196

179197
def _evaluate_batch(
180198
self,
@@ -212,13 +230,10 @@ def _evaluate_batch(
212230
def evaluate(
213231
self,
214232
inputs: Any,
215-
additional_forward_args: Optional[Tuple] = None,
233+
additional_forward_args: Any = None,
216234
perturbations_per_eval: int = 1,
217235
**kwargs,
218-
) -> Dict[
219-
str,
220-
Union[Tensor, Tuple[Tensor, ...], Dict[str, Union[Tensor, Tuple[Tensor, ...]]]],
221-
]:
236+
) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]:
222237
r"""
223238
Evaluate model and attack performance on provided inputs
224239
@@ -385,45 +400,44 @@ def _check_and_evaluate(input_list, key_list):
385400

386401
def _parse_and_update_results(
387402
self, batch_summarizers: Dict[str, Summarizer]
388-
) -> Dict[
389-
str, Union[float, Tuple[float, ...], Dict[str, Union[float, Tuple[float, ...]]]]
390-
]:
391-
results = {
392-
ORIGINAL_KEY: self._format_summary(batch_summarizers[ORIGINAL_KEY].summary)[
393-
"mean"
394-
]
403+
) -> Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]]:
404+
results: Dict[str, Union[MetricResultType, Dict[str, MetricResultType]]] = {
405+
ORIGINAL_KEY: self._format_summary(
406+
cast(Union[Dict, List], batch_summarizers[ORIGINAL_KEY].summary)
407+
)["mean"]
395408
}
396409
self.summary_results[ORIGINAL_KEY].update(
397410
self.metric_aggregator(results[ORIGINAL_KEY])
398411
)
399412
for attack_key in self.attacks:
400413
attack = self.attacks[attack_key]
401-
results[attack.name] = self._format_summary(
402-
batch_summarizers[attack.name].summary
414+
attack_results = self._format_summary(
415+
cast(Union[Dict, List], batch_summarizers[attack.name].summary)
403416
)
417+
results[attack.name] = attack_results
404418

405-
if len(results[attack.name]) == 1:
406-
key = next(iter(results[attack.name]))
419+
if len(attack_results) == 1:
420+
key = next(iter(attack_results))
407421
if attack.name not in self.summary_results:
408422
self.summary_results[attack.name] = Summarizer(
409423
[stat() for stat in self.aggregate_stats]
410424
)
411425
self.summary_results[attack.name].update(
412-
self.metric_aggregator(results[attack.name][key])
426+
self.metric_aggregator(attack_results[key])
413427
)
414428
else:
415-
for key in results[attack.name]:
429+
for key in attack_results:
416430
summary_key = f"{attack.name} {key.title()} Attempt"
417431
if summary_key not in self.summary_results:
418432
self.summary_results[summary_key] = Summarizer(
419433
[stat() for stat in self.aggregate_stats]
420434
)
421435
self.summary_results[summary_key].update(
422-
self.metric_aggregator(results[attack.name][key])
436+
self.metric_aggregator(attack_results[key])
423437
)
424438
return results
425439

426-
def summary(self) -> Dict[str, Dict[str, Union[Tensor, Tuple[Tensor, ...]]]]:
440+
def summary(self) -> Dict[str, Dict[str, MetricResultType]]:
427441
r"""
428442
Returns average results over all previous batches evaluated.
429443
@@ -440,7 +454,9 @@ def summary(self) -> Dict[str, Dict[str, Union[Tensor, Tuple[Tensor, ...]]]]:
440454
per batch.
441455
"""
442456
return {
443-
key: self._format_summary(self.summary_results[key].summary)
457+
key: self._format_summary(
458+
cast(Union[Dict, List], self.summary_results[key].summary)
459+
)
444460
for key in self.summary_results
445461
}
446462

captum/robust/_core/metrics/min_param_perturbation.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
import math
33
from enum import Enum
4-
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
4+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast
55

66
import torch
77
from torch import Tensor
@@ -136,7 +136,9 @@ def correct_fn(model_out: Tensor, **kwargs: Any) -> bool
136136
self.num_attempts = num_attempts
137137
self.preproc_fn = preproc_fn
138138
self.apply_before_preproc = apply_before_preproc
139-
self.correct_fn = correct_fn if correct_fn is not None else default_correct_fn
139+
self.correct_fn = cast(
140+
Callable, correct_fn if correct_fn is not None else default_correct_fn
141+
)
140142

141143
assert (
142144
mode.upper() in MinParamPerturbationMode.__members__
@@ -147,9 +149,9 @@ def _evaluate_batch(
147149
self,
148150
input_list: List,
149151
additional_forward_args: Any,
150-
correct_fn_kwargs: Dict[str, Any],
152+
correct_fn_kwargs: Optional[Dict[str, Any]],
151153
target: TargetType,
152-
) -> None:
154+
) -> Optional[int]:
153155
if additional_forward_args is None:
154156
additional_forward_args = ()
155157

tests/robust/test_min_param_perturbation.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
from typing import List
2+
from typing import List, cast
33

44
import torch
55
from torch import Tensor
@@ -55,7 +55,7 @@ def test_minimal_pert_basic_linear(self) -> None:
5555
target_inp, pert = minimal_pert.evaluate(
5656
inp, target=0, attack_kwargs={"ind": 0}
5757
)
58-
self.assertAlmostEqual(pert, 2.0)
58+
self.assertAlmostEqual(cast(float, pert), 2.0)
5959
assertTensorAlmostEqual(
6060
self, target_inp, torch.tensor([[0.0, -9.0, 9.0, 1.0, -3.0]])
6161
)
@@ -79,7 +79,7 @@ def test_minimal_pert_basic_binary(self) -> None:
7979
attack_kwargs={"ind": 0},
8080
perturbations_per_eval=10,
8181
)
82-
self.assertAlmostEqual(pert, 2.0)
82+
self.assertAlmostEqual(cast(float, pert), 2.0)
8383
assertTensorAlmostEqual(
8484
self, target_inp, torch.tensor([[0.0, -9.0, 9.0, 1.0, -3.0]])
8585
)

0 commit comments

Comments
 (0)