Skip to content

Commit 8ff25e0

Browse files
lukehindspre-commit-ci[bot]ericwb
authored
Pytorch fix (#1231)
* Fix pytorch weights check * B614: Fix PyTorch plugin to handle weights_only parameter correctly The PyTorch plugin (B614) has been updated to properly handle the weights_only parameter in torch.load calls. When weights_only=True is specified, PyTorch will only deserialize known safe types, making the operation more secure. I also removed torch.save as there is no certain insecure element as such, saving any file or artifact requires consideration of what it is you are saving. Changes: - Update plugin to only check torch.load calls (not torch.save) - Fix weights_only check to handle both string and boolean True values - Remove map_location check as it doesn't affect security - Update example file to demonstrate both safe and unsafe cases - Update plugin documentation to mention weights_only as a safe alternative The plugin now correctly identifies unsafe torch.load calls while allowing safe usage with weights_only=True to pass without warning. Fixes: #1224 * Fix E501 line too long * Rename files to new test scope * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update doc/source/plugins/b614_pytorch_load.rst Co-authored-by: Eric Brown <[email protected]> * Update pytorch_load.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Brown <[email protected]>
1 parent def123a commit 8ff25e0

File tree

7 files changed

+63
-50
lines changed

7 files changed

+63
-50
lines changed

bandit/plugins/pytorch_load_save.py bandit/plugins/pytorch_load.py

+25-17
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,26 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
r"""
5-
==========================================
6-
B614: Test for unsafe PyTorch load or save
7-
==========================================
5+
==================================
6+
B614: Test for unsafe PyTorch load
7+
==================================
88
9-
This plugin checks for the use of `torch.load` and `torch.save`. Using
10-
`torch.load` with untrusted data can lead to arbitrary code execution, and
11-
improper use of `torch.save` might expose sensitive data or lead to data
12-
corruption. A safe alternative is to use `torch.load` with the `safetensors`
13-
library from hugingface, which provides a safe deserialization mechanism.
9+
This plugin checks for unsafe use of `torch.load`. Using `torch.load` with
10+
untrusted data can lead to arbitrary code execution. There are two safe
11+
alternatives:
12+
1. Use `torch.load` with `weights_only=True` where only tensor data is
13+
extracted, and no arbitrary Python objects are deserialized
14+
2. Use the `safetensors` library from huggingface, which provides a safe
15+
deserialization mechanism
16+
17+
With `weights_only=True`, PyTorch enforces a strict type check, ensuring
18+
that only torch.Tensor objects are loaded.
1419
1520
:Example:
1621
1722
.. code-block:: none
1823
19-
>> Issue: Use of unsafe PyTorch load or save
24+
>> Issue: Use of unsafe PyTorch load
2025
Severity: Medium Confidence: High
2126
CWE: CWE-94 (https://cwe.mitre.org/data/definitions/94.html)
2227
Location: examples/pytorch_load_save.py:8
@@ -42,12 +47,11 @@
4247

4348
@test.checks("Call")
4449
@test.test_id("B614")
45-
def pytorch_load_save(context):
50+
def pytorch_load(context):
4651
"""
47-
This plugin checks for the use of `torch.load` and `torch.save`. Using
48-
`torch.load` with untrusted data can lead to arbitrary code execution,
49-
and improper use of `torch.save` might expose sensitive data or lead
50-
to data corruption.
52+
This plugin checks for unsafe use of `torch.load`. Using `torch.load`
53+
with untrusted data can lead to arbitrary code execution. The safe
54+
alternative is to use `weights_only=True` or the safetensors library.
5155
"""
5256
imported = context.is_module_imported_exact("torch")
5357
qualname = context.call_function_name_qual
@@ -59,14 +63,18 @@ def pytorch_load_save(context):
5963
if all(
6064
[
6165
"torch" in qualname_list,
62-
func in ["load", "save"],
63-
not context.check_call_arg_value("map_location", "cpu"),
66+
func == "load",
6467
]
6568
):
69+
# For torch.load, check if weights_only=True is specified
70+
weights_only = context.get_call_arg_value("weights_only")
71+
if weights_only == "True" or weights_only is True:
72+
return
73+
6674
return bandit.Issue(
6775
severity=bandit.MEDIUM,
6876
confidence=bandit.HIGH,
69-
text="Use of unsafe PyTorch load or save",
77+
text="Use of unsafe PyTorch load",
7078
cwe=issue.Cwe.DESERIALIZATION_OF_UNTRUSTED_DATA,
7179
lineno=context.get_lineno_for_call_arg("load"),
7280
)
+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
------------------
2+
B614: pytorch_load
3+
------------------
4+
5+
.. automodule:: bandit.plugins.pytorch_load

doc/source/plugins/b614_pytorch_load_save.rst

-5
This file was deleted.

examples/pytorch_load.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch
2+
import torchvision.models as models
3+
4+
# Example of saving a model
5+
model = models.resnet18(pretrained=True)
6+
torch.save(model.state_dict(), 'model_weights.pth')
7+
8+
# Example of loading the model weights in an insecure way (should trigger B614)
9+
loaded_model = models.resnet18()
10+
loaded_model.load_state_dict(torch.load('model_weights.pth'))
11+
12+
# Example of loading with weights_only=True (should NOT trigger B614)
13+
safe_model = models.resnet18()
14+
safe_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
15+
16+
# Example of loading with weights_only=False (should trigger B614)
17+
unsafe_model = models.resnet18()
18+
unsafe_model.load_state_dict(torch.load('model_weights.pth', weights_only=False))
19+
20+
# Example of loading with map_location but no weights_only (should trigger B614)
21+
cpu_model = models.resnet18()
22+
cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
23+
24+
# Example of loading with both map_location and weights_only=True (should NOT trigger B614)
25+
safe_cpu_model = models.resnet18()
26+
safe_cpu_model.load_state_dict(torch.load('model_weights.pth', map_location='cpu', weights_only=True))

examples/pytorch_load_save.py

-21
This file was deleted.

setup.cfg

+2-2
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ bandit.plugins =
155155
#bandit/plugins/tarfile_unsafe_members.py
156156
tarfile_unsafe_members = bandit.plugins.tarfile_unsafe_members:tarfile_unsafe_members
157157

158-
#bandit/plugins/pytorch_load_save.py
159-
pytorch_load_save = bandit.plugins.pytorch_load_save:pytorch_load_save
158+
#bandit/plugins/pytorch_load.py
159+
pytorch_load = bandit.plugins.pytorch_load:pytorch_load
160160

161161
# bandit/plugins/trojansource.py
162162
trojansource = bandit.plugins.trojansource:trojansource

tests/functional/test_functional.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -872,13 +872,13 @@ def test_tarfile_unsafe_members(self):
872872
}
873873
self.check_example("tarfile_extractall.py", expect)
874874

875-
def test_pytorch_load_save(self):
876-
"""Test insecure usage of torch.load and torch.save."""
875+
def test_pytorch_load(self):
876+
"""Test insecure usage of torch.load."""
877877
expect = {
878-
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 4, "HIGH": 0},
879-
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 4},
878+
"SEVERITY": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 3, "HIGH": 0},
879+
"CONFIDENCE": {"UNDEFINED": 0, "LOW": 0, "MEDIUM": 0, "HIGH": 3},
880880
}
881-
self.check_example("pytorch_load_save.py", expect)
881+
self.check_example("pytorch_load.py", expect)
882882

883883
def test_trojansource(self):
884884
expect = {

0 commit comments

Comments
 (0)