@@ -693,9 +693,9 @@ def get_submodule(self, target: str) -> "Module":
693
693
torch.nn.Module: The submodule referenced by ``target``
694
694
695
695
Raises:
696
- AttributeError: If the target string references an invalid
697
- path or resolves to something that is not an
698
- ``nn.Module``
696
+ AttributeError: If at any point along the path resulting from
697
+ the target string the (sub)path resolves to a non-existent
698
+ attribute name or an object that is not an instance of ``nn.Module``.
699
699
"""
700
700
if target == "" :
701
701
return self
@@ -716,10 +716,18 @@ def get_submodule(self, target: str) -> "Module":
716
716
717
717
return mod
718
718
719
- def set_submodule (self , target : str , module : "Module" ) -> None :
719
+ def set_submodule (
720
+ self , target : str , module : "Module" , strict : bool = False
721
+ ) -> None :
720
722
"""
721
723
Set the submodule given by ``target`` if it exists, otherwise throw an error.
722
724
725
+ .. note::
726
+ If ``strict`` is set to ``False`` (default), the method will replace an existing submodule
727
+ or create a new submodule if the parent module exists. If ``strict`` is set to ``True``,
728
+ the method will only attempt to replace an existing submodule and throw an error if
729
+ the submodule does not exist.
730
+
723
731
For example, let's say you have an ``nn.Module`` ``A`` that
724
732
looks like this:
725
733
@@ -728,52 +736,66 @@ def set_submodule(self, target: str, module: "Module") -> None:
728
736
A(
729
737
(net_b): Module(
730
738
(net_c): Module(
731
- (conv): Conv2d(16, 33, kernel_size=( 3, 3), stride=(2, 2) )
739
+ (conv): Conv2d(3, 3, 3 )
732
740
)
733
- (linear): Linear(in_features=100, out_features=200, bias=True )
741
+ (linear): Linear(3, 3 )
734
742
)
735
743
)
736
744
737
745
(The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested
738
746
submodule ``net_b``, which itself has two submodules ``net_c``
739
747
and ``linear``. ``net_c`` then has a submodule ``conv``.)
740
748
741
- To overide the ``Conv2d`` with a new submodule ``Linear``, you
742
- would call
743
- ``set_submodule("net_b.net_c.conv", nn.Linear(33, 16))``.
749
+ To override the ``Conv2d`` with a new submodule ``Linear``, you
750
+ could call ``set_submodule("net_b.net_c.conv", nn.Linear(1, 1))``
751
+ where ``strict`` could be ``True`` or ``False``
752
+
753
+ To add a new submodule ``Conv2d`` to the existing ``net_b`` module,
754
+ you would call ``set_submodule("net_b.conv", nn.Conv2d(1, 1, 1))``.
755
+
756
+ In the above if you set ``strict=True`` and call
757
+ ``set_submodule("net_b.conv", nn.Conv2d(1, 1, 1), strict=True)``, an AttributeError
758
+ will be raised because ``net_b`` does not have a submodule named ``conv``.
744
759
745
760
Args:
746
761
target: The fully-qualified string name of the submodule
747
762
to look for. (See above example for how to specify a
748
763
fully-qualified string.)
749
764
module: The module to set the submodule to.
765
+ strict: If ``False``, the method will replace an existing submodule
766
+ or create a new submodule if the parent module exists. If ``True``,
767
+ the method will only attempt to replace an existing submodule and throw an error
768
+ if the submodule doesn't already exist.
750
769
751
770
Raises:
752
- ValueError: If the target string is empty
753
- AttributeError: If the target string references an invalid
754
- path or resolves to something that is not an
755
- ``nn.Module``
771
+ ValueError: If the `` target`` string is empty or if ``module`` is not an instance of ``nn.Module``.
772
+ AttributeError: If at any point along the path resulting from
773
+ the ``target`` string the (sub)path resolves to a non-existent
774
+ attribute name or an object that is not an instance of ``nn.Module``.
756
775
"""
757
776
if target == "" :
758
777
raise ValueError ("Cannot set the submodule without a target name!" )
759
778
760
779
atoms : list [str ] = target .split ("." )
761
- name = atoms .pop (- 1 )
762
- mod : torch .nn .Module = self
763
-
764
- for item in atoms :
765
- if not hasattr (mod , item ):
766
- raise AttributeError (
767
- mod ._get_name () + " has no attribute `" + item + "`"
768
- )
769
-
770
- mod = getattr (mod , item )
780
+ if not isinstance (module , torch .nn .Module ):
781
+ raise ValueError (
782
+ "`" + "module" + f"` is not an nn.Module, found { type (module )} "
783
+ )
784
+ if len (atoms ) == 1 :
785
+ parent : torch .nn .Module = self
786
+ else :
787
+ parent_key = "." .join (atoms [:- 1 ])
788
+ parent = self .get_submodule (parent_key )
771
789
772
- # Use isinstance instead of type here to also handle subclass of nn.Module
790
+ if strict and not hasattr (parent , atoms [- 1 ]):
791
+ raise AttributeError (
792
+ parent ._get_name () + " has no attribute `" + atoms [- 1 ] + "`"
793
+ )
794
+ if hasattr (parent , atoms [- 1 ]):
795
+ mod = getattr (parent , atoms [- 1 ])
773
796
if not isinstance (mod , torch .nn .Module ):
774
- raise AttributeError ("`" + item + "` is not an nn.Module" )
775
-
776
- setattr (mod , name , module )
797
+ raise AttributeError ("`" + atoms [- 1 ] + "` is not an nn.Module" )
798
+ setattr (parent , atoms [- 1 ], module )
777
799
778
800
def get_parameter (self , target : str ) -> "Parameter" :
779
801
"""Return the parameter given by ``target`` if it exists, otherwise throw an error.
0 commit comments