Skip to content

Commit ea1461b

Browse files
authored
Merge pull request #591 from vivekmig/v0.3.1
Captum v0.3.1
2 parents 26cad25 + d0920f6 commit ea1461b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+2949
-950
lines changed

Diff for: .circleci/config.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ commands:
102102
steps:
103103
- run:
104104
name: "Simple PIP install"
105-
command: python -m pip install -e .[dev]
105+
command: |
106+
python -m pip install --upgrade pip
107+
python -m pip install -e .[dev]
106108
107109
py_3_7_setup:
108110
description: "Set python version to 3.7 and install pip and pytest"

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,4 @@ website/static/js/*
114114
!website/static/js/code_block_buttons.js
115115
website/static/_sphinx-sources/
116116
node_modules
117+
captum/insights/attr_vis/widget/static

Diff for: README.md

+8-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
[![CircleCI](https://circleci.com/gh/pytorch/captum.svg?style=shield)](https://circleci.com/gh/pytorch/captum)
88

99
Captum is a model interpretability and understanding library for PyTorch.
10-
Captum means comprehension in latin and contains general purpose implementations
10+
Captum means comprehension in Latin and contains general purpose implementations
1111
of integrated gradients, saliency maps, smoothgrad, vargrad and others for
1212
PyTorch models. It has quick integration for models built with domain-specific
1313
libraries such as torchvision, torchtext, and others.
@@ -175,12 +175,12 @@ Convergence Delta: tensor([2.3842e-07, -4.7684e-07])
175175
The algorithm outputs an attribution score for each input element and a
176176
convergence delta. The lower the absolute value of the convergence delta the better
177177
is the approximation. If we choose not to return delta,
178-
we can simply not provide `return_convergence_delta` input
178+
we can simply not provide the `return_convergence_delta` input
179179
argument. The absolute value of the returned deltas can be interpreted as an
180180
approximation error for each input sample.
181181
It can also serve as a proxy of how accurate the integral approximation for given
182182
inputs and baselines is.
183-
If the approximation error is large, we can try larger number of integral
183+
If the approximation error is large, we can try a larger number of integral
184184
approximation steps by setting `n_steps` to a larger value. Not all algorithms
185185
return approximation error. Those which do, though, compute it based on the
186186
completeness property of the algorithms.
@@ -224,7 +224,7 @@ in order to get per example average delta.
224224

225225

226226
Below is an example of how we can apply `DeepLift` and `DeepLiftShap` on the
227-
`ToyModel` described above. Current implementation of DeepLift supports only
227+
`ToyModel` described above. The current implementation of DeepLift supports only the
228228
`Rescale` rule.
229229
For more details on alternative implementations, please see the [DeepLift paper](https://arxiv.org/abs/1704.02685).
230230

@@ -286,7 +286,7 @@ In order to smooth and improve the quality of the attributions we can run
286286
to smoothen the attributions by aggregating them for multiple noisy
287287
samples that were generated by adding gaussian noise.
288288

289-
Here is an example how we can use `NoiseTunnel` with `IntegratedGradients`.
289+
Here is an example of how we can use `NoiseTunnel` with `IntegratedGradients`.
290290

291291
```python
292292
ig = IntegratedGradients(model)
@@ -338,7 +338,7 @@ It is an extension of path integrated gradients for hidden layers and holds the
338338
completeness property as well.
339339

340340
It doesn't attribute the contribution scores to the input features
341-
but shows the importance of each neuron in selected layer.
341+
but shows the importance of each neuron in the selected layer.
342342
```python
343343
lc = LayerConductance(model, model.lin1)
344344
attributions, delta = lc.attribute(input, baselines=baseline, target=0, return_convergence_delta=True)
@@ -412,6 +412,8 @@ See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out.
412412
## Talks and Papers
413413
The slides of our presentation from NeurIPS 2019 can be found [here](docs/presentations/Captum_NeurIPS_2019_final.key)
414414

415+
The slides of our presentation from KDD 2020 tutorial can be found [here](https://pytorch-tutorial-assets.s3.amazonaws.com/Captum_KDD_2020.pdf)
416+
415417
## References of Algorithms
416418

417419
* `IntegratedGradients`, `LayerIntegratedGradients`: [Axiomatic Attribution for Deep Networks, Mukund Sundararajan et al. 2017](https://arxiv.org/abs/1703.01365) and [Did the Model Understand the Question?, Pramod K. Mudrakarta, et al. 2018](https://arxiv.org/abs/1805.05492)

Diff for: captum/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/usr/bin/env python3
22

3-
__version__ = "0.3.0"
3+
__version__ = "0.3.1"

Diff for: captum/_utils/common.py

+63
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,20 @@ def _expand_target(
249249
return target
250250

251251

252+
def _expand_feature_mask(
253+
feature_mask: Union[Tensor, Tuple[Tensor, ...]], n_samples: int
254+
):
255+
is_feature_mask_tuple = _is_tuple(feature_mask)
256+
feature_mask = _format_tensor_into_tuples(feature_mask)
257+
feature_mask_new = tuple(
258+
feature_mask_elem.repeat_interleave(n_samples, dim=0)
259+
if feature_mask_elem.size(0) > 1
260+
else feature_mask_elem
261+
for feature_mask_elem in feature_mask
262+
)
263+
return _format_output(is_feature_mask_tuple, feature_mask_new)
264+
265+
252266
def _expand_and_update_baselines(
253267
inputs: Tuple[Tensor, ...],
254268
n_samples: int,
@@ -317,6 +331,18 @@ def _expand_and_update_target(n_samples: int, kwargs: dict):
317331
kwargs["target"] = target
318332

319333

334+
def _expand_and_update_feature_mask(n_samples: int, kwargs: dict):
335+
if "feature_mask" not in kwargs:
336+
return
337+
338+
feature_mask = kwargs["feature_mask"]
339+
if feature_mask is None:
340+
return
341+
342+
feature_mask = _expand_feature_mask(feature_mask, n_samples)
343+
kwargs["feature_mask"] = feature_mask
344+
345+
320346
@typing.overload
321347
def _format_output(
322348
is_inputs_tuple: Literal[True], output: Tuple[Tensor, ...]
@@ -354,6 +380,43 @@ def _format_output(
354380
return output if is_inputs_tuple else output[0]
355381

356382

383+
@typing.overload
384+
def _format_outputs(
385+
is_multiple_inputs: Literal[False], outputs: List[Tuple[Tensor, ...]]
386+
) -> Union[Tensor, Tuple[Tensor, ...]]:
387+
...
388+
389+
390+
@typing.overload
391+
def _format_outputs(
392+
is_multiple_inputs: Literal[True], outputs: List[Tuple[Tensor, ...]]
393+
) -> List[Union[Tensor, Tuple[Tensor, ...]]]:
394+
...
395+
396+
397+
@typing.overload
398+
def _format_outputs(
399+
is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
400+
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
401+
...
402+
403+
404+
def _format_outputs(
405+
is_multiple_inputs: bool, outputs: List[Tuple[Tensor, ...]]
406+
) -> Union[Tensor, Tuple[Tensor, ...], List[Union[Tensor, Tuple[Tensor, ...]]]]:
407+
assert isinstance(outputs, list), "Outputs must be a list"
408+
assert is_multiple_inputs or len(outputs) == 1, (
409+
"outputs should contain multiple inputs or have a single output"
410+
f"however the number of outputs is: {len(outputs)}"
411+
)
412+
413+
return (
414+
[_format_output(len(output) > 1, output) for output in outputs]
415+
if is_multiple_inputs
416+
else _format_output(len(outputs[0]) > 1, outputs[0])
417+
)
418+
419+
357420
def _run_forward(
358421
forward_func: Callable,
359422
inputs: Union[Tensor, Tuple[Tensor, ...]],

Diff for: captum/_utils/models/linear_model/train.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time
2+
import warnings
23
from typing import Any, Callable, Dict, List, Optional
34

45
import torch
@@ -286,10 +287,11 @@ def sklearn_train_linear_model(
286287
except ImportError:
287288
raise ValueError("sklearn is not available. Please install sklearn >= 0.23")
288289

289-
assert (
290-
sklearn.__version__ >= "0.23.0"
291-
), "Must have sklearn version 0.23.0 or higher to use "
292-
"sample_weight in Lasso regression."
290+
if not sklearn.__version__ >= "0.23.0":
291+
warnings.warn(
292+
"Must have sklearn version 0.23.0 or higher to use "
293+
"sample_weight in Lasso regression."
294+
)
293295

294296
num_batches = 0
295297
xs, ys, ws = [], [], []
@@ -323,7 +325,16 @@ def sklearn_train_linear_model(
323325
sklearn_model = reduce(
324326
lambda val, el: getattr(val, el), [sklearn] + sklearn_trainer.split(".")
325327
)(**construct_kwargs)
326-
sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
328+
try:
329+
sklearn_model.fit(x, y, sample_weight=w, **fit_kwargs)
330+
except TypeError:
331+
sklearn_model.fit(x, y, **fit_kwargs)
332+
warnings.warn(
333+
"Sample weight is not supported for the provided linear model!"
334+
" Trained model without weighting inputs. For Lasso, please"
335+
" upgrade sklearn to a version >= 0.23.0."
336+
)
337+
327338
t2 = time.time()
328339

329340
# Convert weights to pytorch

Diff for: captum/attr/_core/feature_ablation.py

+1
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def attribute(
346346
eval_diff = (
347347
initial_eval - modified_eval.reshape((-1, num_outputs))
348348
).reshape((-1, num_outputs) + (len(inputs[i].shape) - 1) * (1,))
349+
eval_diff = eval_diff.to(total_attrib[i].device)
349350
if self.use_weights:
350351
weights[i] += current_mask.float().sum(dim=0)
351352
total_attrib[i] += (eval_diff * current_mask.to(attrib_type)).sum(

Diff for: captum/attr/_core/gradient_shap.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def attribute(
274274
nt, # self
275275
inputs,
276276
nt_type="smoothgrad",
277-
n_samples=n_samples,
277+
nt_samples=n_samples,
278278
stdevs=stdevs,
279279
draw_baseline_from_distrib=True,
280280
baselines=baselines,

Diff for: captum/attr/_core/kernel_shap.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from captum._utils.models.linear_model import SkLearnLinearRegression
1010
from captum._utils.typing import BaselineType, TargetType, TensorOrTupleOfTensorsGeneric
1111
from captum.attr._core.lime import Lime
12+
from captum.attr._utils.common import lime_n_perturb_samples_deprecation_decorator
1213
from captum.log import log_usage
1314

1415

@@ -72,14 +73,15 @@ def __init__(self, forward_func: Callable) -> None:
7273
)
7374

7475
@log_usage()
76+
@lime_n_perturb_samples_deprecation_decorator
7577
def attribute( # type: ignore
7678
self,
7779
inputs: TensorOrTupleOfTensorsGeneric,
7880
baselines: BaselineType = None,
7981
target: TargetType = None,
8082
additional_forward_args: Any = None,
8183
feature_mask: Union[None, Tensor, Tuple[Tensor, ...]] = None,
82-
n_perturb_samples: int = 25,
84+
n_samples: int = 25,
8385
perturbations_per_eval: int = 1,
8486
return_input_shape: bool = True,
8587
) -> TensorOrTupleOfTensorsGeneric:
@@ -213,9 +215,9 @@ def attribute( # type: ignore
213215
If None, then a feature mask is constructed which assigns
214216
each scalar within a tensor as a separate feature.
215217
Default: None
216-
n_perturb_samples (int, optional): The number of samples of the original
218+
n_samples (int, optional): The number of samples of the original
217219
model used to train the surrogate interpretable model.
218-
Default: `50` if `n_perturb_samples` is not provided.
220+
Default: `50` if `n_samples` is not provided.
219221
perturbations_per_eval (int, optional): Allows multiple samples
220222
to be processed simultaneously in one call to forward_fn.
221223
Each forward pass will contain a maximum of
@@ -266,7 +268,7 @@ def attribute( # type: ignore
266268
>>> ks = KernelShap(net)
267269
>>> # Computes attribution, with each of the 4 x 4 = 16
268270
>>> # features as a separate interpretable feature
269-
>>> attr = ks.attribute(input, target=1, n_perturb_samples=200)
271+
>>> attr = ks.attribute(input, target=1, n_samples=200)
270272
271273
>>> # Alternatively, we can group each 2x2 square of the inputs
272274
>>> # as one 'interpretable' feature and perturb them together.
@@ -299,7 +301,7 @@ def attribute( # type: ignore
299301
target=target,
300302
additional_forward_args=additional_forward_args,
301303
feature_mask=feature_mask,
302-
n_perturb_samples=n_perturb_samples,
304+
n_samples=n_samples,
303305
perturbations_per_eval=perturbations_per_eval,
304306
return_input_shape=return_input_shape,
305307
)

Diff for: captum/attr/_core/layer/layer_gradient_shap.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def attribute(
305305
nt, # self
306306
inputs,
307307
nt_type="smoothgrad",
308-
n_samples=n_samples,
308+
nt_samples=n_samples,
309309
stdevs=stdevs,
310310
draw_baseline_from_distrib=True,
311311
baselines=baselines,

0 commit comments

Comments
 (0)