@@ -785,7 +785,9 @@ def purge(self) -> None:
785
785
def named_split_embedding_weights (
786
786
self , prefix : str = "" , recurse : bool = True , remove_duplicate : bool = True
787
787
) -> Iterator [Tuple [str , torch .Tensor ]]:
788
- assert remove_duplicate , "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
788
+ assert (
789
+ remove_duplicate
790
+ ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
789
791
for config , param in zip (
790
792
self ._config .embedding_tables ,
791
793
self .emb_module .split_embedding_weights (),
@@ -897,7 +899,9 @@ def named_parameters(
897
899
def named_split_embedding_weights (
898
900
self , prefix : str = "" , recurse : bool = True , remove_duplicate : bool = True
899
901
) -> Iterator [Tuple [str , torch .Tensor ]]:
900
- assert remove_duplicate , "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
902
+ assert (
903
+ remove_duplicate
904
+ ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
901
905
for config , tensor in zip (
902
906
self ._config .embedding_tables ,
903
907
self .split_embedding_weights (),
@@ -1078,9 +1082,8 @@ def named_parameters(
1078
1082
combined_key = "/" .join (
1079
1083
[config .name for config in self ._config .embedding_tables ]
1080
1084
)
1081
- yield (
1082
- append_prefix (prefix , f"{ combined_key } .weight" ),
1083
- cast (nn .Parameter , self ._emb_module .weights ),
1085
+ yield append_prefix (prefix , f"{ combined_key } .weight" ), cast (
1086
+ nn .Parameter , self ._emb_module .weights
1084
1087
)
1085
1088
1086
1089
@@ -1098,8 +1101,7 @@ def __init__(
1098
1101
self ._pg = pg
1099
1102
1100
1103
self ._pooling : PoolingMode = pooling_type_to_pooling_mode (
1101
- config .pooling ,
1102
- sharding_type , # pyre-ignore[6]
1104
+ config .pooling , sharding_type # pyre-ignore[6]
1103
1105
)
1104
1106
1105
1107
self ._local_rows : List [int ] = []
@@ -1218,7 +1220,9 @@ def purge(self) -> None:
1218
1220
def named_split_embedding_weights (
1219
1221
self , prefix : str = "" , recurse : bool = True , remove_duplicate : bool = True
1220
1222
) -> Iterator [Tuple [str , torch .Tensor ]]:
1221
- assert remove_duplicate , "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
1223
+ assert (
1224
+ remove_duplicate
1225
+ ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
1222
1226
for config , tensor in zip (
1223
1227
self ._config .embedding_tables ,
1224
1228
self .emb_module .split_embedding_weights (),
@@ -1358,7 +1362,9 @@ def named_parameters(
1358
1362
def named_split_embedding_weights (
1359
1363
self , prefix : str = "" , recurse : bool = True , remove_duplicate : bool = True
1360
1364
) -> Iterator [Tuple [str , PartiallyMaterializedTensor ]]:
1361
- assert remove_duplicate , "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
1365
+ assert (
1366
+ remove_duplicate
1367
+ ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
1362
1368
for config , tensor in zip (
1363
1369
self ._config .embedding_tables ,
1364
1370
self .split_embedding_weights (),
@@ -1561,7 +1567,6 @@ def named_parameters(
1561
1567
combined_key = "/" .join (
1562
1568
[config .name for config in self ._config .embedding_tables ]
1563
1569
)
1564
- yield (
1565
- append_prefix (prefix , f"{ combined_key } .weight" ),
1566
- cast (nn .Parameter , self ._emb_module .weights ),
1570
+ yield append_prefix (prefix , f"{ combined_key } .weight" ), cast (
1571
+ nn .Parameter , self ._emb_module .weights
1567
1572
)
0 commit comments