Skip to content

Commit 49bdc41

Browse files
mariovas3pytorchmergebot
authored andcommitted
Add strict kwarg to nn.Module.set_submodule and fix bug for non dot delineated strings (pytorch#143455)
Before fixing set_submodule, it used to create leaf modules when the target was not a dot-delimited string. After the fix it will not create a new attribute if target is a non-dot-delimited string. If you want to create leaf nodes of `nn.Module` parent nodes, you can use `replace_or_create_new_leaf_module`. Fixes pytorch#143441 Pull Request resolved: pytorch#143455 Approved by: https://github.com/mikaylagawarecki
1 parent e3c4d1b commit 49bdc41

File tree

2 files changed

+77
-28
lines changed

2 files changed

+77
-28
lines changed

test/test_nn.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -1468,17 +1468,44 @@ def test_add_module(self):
14681468
lambda: getattr(net, fn)(None, l))
14691469

14701470
def test_set_submodule(self):
1471+
# test the docstring example
1472+
A = nn.Module()
1473+
A.set_submodule("net_b", nn.Module())
1474+
A.set_submodule("net_b.net_c", nn.Module())
1475+
A.set_submodule("net_b.net_c.conv", nn.Conv2d(3, 3, 3))
1476+
A.set_submodule("net_b.linear", nn.Linear(3, 3))
1477+
new_linear = nn.Linear(1, 1)
1478+
A.set_submodule("net_b.net_c.conv", new_linear)
1479+
self.assertEqual(A.get_submodule("net_b.net_c.conv"), new_linear)
1480+
new_linear = nn.Linear(1, 2)
1481+
A.set_submodule("net_b.net_c.conv", new_linear, True)
1482+
self.assertEqual(A.get_submodule("net_b.net_c.conv"), new_linear)
1483+
new_conv = nn.Conv2d(1, 1, 1)
1484+
self.assertRaises(AttributeError, A.set_submodule, "net_b.conv", new_conv, True)
1485+
A.set_submodule("net_b.conv", new_conv)
1486+
self.assertEqual(A.get_submodule("net_b.conv"), new_conv)
1487+
1488+
# more tests
14711489
net = nn.Module()
14721490
net.t = nn.Module()
14731491
l = nn.Linear(1, 2)
14741492
target = "t.l"
1475-
net.set_submodule(target, l)
1493+
net.t.l = l
14761494
self.assertEqual(net.get_submodule(target), l)
14771495
l2 = nn.Linear(2, 1)
14781496
net.set_submodule(target, l2)
14791497
self.assertEqual(net.get_submodule(target), l2)
14801498
self.assertRaises(ValueError, net.set_submodule, "", l)
14811499
self.assertRaises(AttributeError, net.set_submodule, "a.l", l)
1500+
self.assertRaises(AttributeError, net.set_submodule, "0", l, True)
1501+
net.set_submodule("0", l, False)
1502+
self.assertEqual(net.get_submodule("0"), l)
1503+
l3 = nn.Linear(1, 1)
1504+
net.set_submodule("0", l3, True)
1505+
self.assertEqual(net.get_submodule("0"), l3)
1506+
net.foo = "bar"
1507+
self.assertRaises(AttributeError, net.set_submodule, "foo", l)
1508+
self.assertRaises(ValueError, net.set_submodule, "t.l", "bazz")
14821509

14831510
def test_module_to_argparse(self):
14841511
net = nn.Sequential(nn.Linear(3, 3))

torch/nn/modules/module.py

+49-27
Original file line numberDiff line numberDiff line change
@@ -693,9 +693,9 @@ def get_submodule(self, target: str) -> "Module":
693693
torch.nn.Module: The submodule referenced by ``target``
694694
695695
Raises:
696-
AttributeError: If the target string references an invalid
697-
path or resolves to something that is not an
698-
``nn.Module``
696+
AttributeError: If at any point along the path resulting from
697+
the target string the (sub)path resolves to a non-existent
698+
attribute name or an object that is not an instance of ``nn.Module``.
699699
"""
700700
if target == "":
701701
return self
@@ -716,10 +716,18 @@ def get_submodule(self, target: str) -> "Module":
716716

717717
return mod
718718

719-
def set_submodule(self, target: str, module: "Module") -> None:
719+
def set_submodule(
720+
self, target: str, module: "Module", strict: bool = False
721+
) -> None:
720722
"""
721723
Set the submodule given by ``target`` if it exists, otherwise throw an error.
722724
725+
.. note::
726+
If ``strict`` is set to ``False`` (default), the method will replace an existing submodule
727+
or create a new submodule if the parent module exists. If ``strict`` is set to ``True``,
728+
the method will only attempt to replace an existing submodule and throw an error if
729+
the submodule does not exist.
730+
723731
For example, let's say you have an ``nn.Module`` ``A`` that
724732
looks like this:
725733
@@ -728,52 +736,66 @@ def set_submodule(self, target: str, module: "Module") -> None:
728736
A(
729737
(net_b): Module(
730738
(net_c): Module(
731-
(conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))
739+
(conv): Conv2d(3, 3, 3)
732740
)
733-
(linear): Linear(in_features=100, out_features=200, bias=True)
741+
(linear): Linear(3, 3)
734742
)
735743
)
736744
737745
(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
738746
submodule ``net_b``, which itself has two submodules ``net_c``
739747
and ``linear``. ``net_c`` then has a submodule ``conv``.)
740748
741-
To overide the ``Conv2d`` with a new submodule ``Linear``, you
742-
would call
743-
``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``.
749+
To override the ``Conv2d`` with a new submodule ``Linear``, you
750+
could call ``set_submodule("net_b.net_c.conv", nn.Linear(1, 1))``
751+
where ``strict`` could be ``True`` or ``False``
752+
753+
To add a new submodule ``Conv2d`` to the existing ``net_b`` module,
754+
you would call ``set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))``.
755+
756+
In the above if you set ``strict=True`` and call
757+
``set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True)``, an AttributeError
758+
will be raised because ``net_b`` does not have a submodule named ``conv``.
744759
745760
Args:
746761
target: The fully-qualified string name of the submodule
747762
to look for. (See above example for how to specify a
748763
fully-qualified string.)
749764
module: The module to set the submodule to.
765+
strict: If ``False``, the method will replace an existing submodule
766+
or create a new submodule if the parent module exists. If ``True``,
767+
the method will only attempt to replace an existing submodule and throw an error
768+
if the submodule doesn't already exist.
750769
751770
Raises:
752-
ValueError: If the target string is empty
753-
AttributeError: If the target string references an invalid
754-
path or resolves to something that is not an
755-
``nn.Module``
771+
ValueError: If the ``target`` string is empty or if ``module`` is not an instance of ``nn.Module``.
772+
AttributeError: If at any point along the path resulting from
773+
the ``target`` string the (sub)path resolves to a non-existent
774+
attribute name or an object that is not an instance of ``nn.Module``.
756775
"""
757776
if target == "":
758777
raise ValueError("Cannot set the submodule without a target name!")
759778

760779
atoms: list[str] = target.split(".")
761-
name = atoms.pop(-1)
762-
mod: torch.nn.Module = self
763-
764-
for item in atoms:
765-
if not hasattr(mod, item):
766-
raise AttributeError(
767-
mod._get_name() + " has no attribute `" + item + "`"
768-
)
769-
770-
mod = getattr(mod, item)
780+
if not isinstance(module, torch.nn.Module):
781+
raise ValueError(
782+
"`" + "module" + f"` is not an nn.Module, found {type(module)}"
783+
)
784+
if len(atoms) == 1:
785+
parent: torch.nn.Module = self
786+
else:
787+
parent_key = ".".join(atoms[:-1])
788+
parent = self.get_submodule(parent_key)
771789

772-
# Use isinstance instead of type here to also handle subclass of nn.Module
790+
if strict and not hasattr(parent, atoms[-1]):
791+
raise AttributeError(
792+
parent._get_name() + " has no attribute `" + atoms[-1] + "`"
793+
)
794+
if hasattr(parent, atoms[-1]):
795+
mod = getattr(parent, atoms[-1])
773796
if not isinstance(mod, torch.nn.Module):
774-
raise AttributeError("`" + item + "` is not an nn.Module")
775-
776-
setattr(mod, name, module)
797+
raise AttributeError("`" + atoms[-1] + "` is not an nn.Module")
798+
setattr(parent, atoms[-1], module)
777799

778800
def get_parameter(self, target: str) -> "Parameter":
779801
"""Return the parameter given by ``target`` if it exists, otherwise throw an error.

0 commit comments

Comments
 (0)