
Commit c2f7d61

Thomas Polasek authored and facebook-github-bot committed
Back out "Convert directory fbcode/torchrec to use the Ruff Formatter"
Summary: Original commit changeset: ee300de21222
Original Phabricator Diff: D66013071
bypass-github-export-checks
Reviewed By: aporialiao
Differential Revision: D66198773
fbshipit-source-id: 4a8e5a124937a8329d7ed39444d3fdc5e4f2d10c
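Every hunk in this commit is a formatting-only change: the removed ("-") lines carry the layout produced by the Ruff Formatter conversion being backed out (D66013071), and the added ("+") lines restore the repository's previous layout, so no runtime behavior should change. As a reading aid, the most common before/after pattern, lifted from the torchrec/distributed/batched_embedding_kernel.py hunks below and shown out of context, is:

# Removed by this backout: one-line assert with a long message.
assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"

# Restored by this backout: the condition wrapped in parentheses so the message sits on its own line.
assert (
    remove_duplicate
), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"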
1 parent 34cdb1d commit c2f7d61

107 files changed: +472 / -369 lines changed

benchmarks/ebc_benchmarks.py

+2
@@ -163,6 +163,7 @@ def get_fused_ebc_uvm_time(
     location: EmbeddingLocation,
     epochs: int = 100,
 ) -> Tuple[float, float]:
+
     fused_ebc = FusedEmbeddingBagCollection(
         tables=embedding_bag_configs,
         optimizer_type=torch.optim.SGD,
@@ -194,6 +195,7 @@ def get_ebc_comparison(
     device: torch.device,
     epochs: int = 100,
 ) -> Tuple[float, float, float, float, float]:
+
     # Simple EBC module wrapping a list of nn.EmbeddingBag
     ebc = EmbeddingBagCollection(
         tables=embedding_bag_configs,

benchmarks/ebc_benchmarks_utils.py

+4
@@ -26,6 +26,7 @@ def get_random_dataset(
     embedding_bag_configs: List[EmbeddingBagConfig],
     pooling_factors: Optional[Dict[str, int]] = None,
 ) -> IterableDataset[Batch]:
+
     if pooling_factors is None:
         pooling_factors = {}
 
@@ -56,6 +57,7 @@ def train_one_epoch(
     dataset: IterableDataset[Batch],
     device: torch.device,
 ) -> float:
+
     start_time = time.perf_counter()
 
     for data in dataset:
@@ -80,6 +82,7 @@ def train_one_epoch_fused_optimizer(
     dataset: IterableDataset[Batch],
     device: torch.device,
 ) -> float:
+
     start_time = time.perf_counter()
 
     for data in dataset:
@@ -103,6 +106,7 @@ def train(
     device: torch.device,
     epochs: int = 100,
 ) -> Tuple[float, float]:
+
     training_time = []
     for _ in range(epochs):
         if optimizer:

examples/bert4rec/bert4rec_main.py

+1
@@ -35,6 +35,7 @@
 
 # OSS import
 try:
+
     # pyre-ignore[21]
     # @manual=//torchrec/github/examples/bert4rec:bert4rec_metrics
     from bert4rec_metrics import recalls_and_ndcgs_for_ks

examples/golden_training/train_dlrm_data_parallel.py

+1 -1
@@ -160,7 +160,7 @@ def train(
     )
 
 def dense_filter(
-    named_parameters: Iterator[Tuple[str, nn.Parameter]],
+    named_parameters: Iterator[Tuple[str, nn.Parameter]]
 ) -> Iterator[Tuple[str, nn.Parameter]]:
     for fqn, param in named_parameters:
         if "sparse" not in fqn:

examples/retrieval/two_tower_retrieval.py

+1
@@ -27,6 +27,7 @@
 
 # OSS import
 try:
+
     # pyre-ignore[21]
     # @manual=//torchrec/github/examples/retrieval:knn_index
     from knn_index import get_index

tools/lint/black_linter.py

+1 -3
@@ -179,9 +179,7 @@ def main() -> None:
         level=(
             logging.NOTSET
             if args.verbose
-            else logging.DEBUG
-            if len(args.filenames) < 1000
-            else logging.INFO
+            else logging.DEBUG if len(args.filenames) < 1000 else logging.INFO
         ),
         stream=sys.stderr,
     )

torchrec/datasets/criteo.py

+2 -3
@@ -351,9 +351,8 @@ def get_file_row_ranges_and_remainder(
 
         # If the ranges overlap.
         if rank_left_g <= file_right_g and rank_right_g >= file_left_g:
-            overlap_left_g, overlap_right_g = (
-                max(rank_left_g, file_left_g),
-                min(rank_right_g, file_right_g),
+            overlap_left_g, overlap_right_g = max(rank_left_g, file_left_g), min(
+                rank_right_g, file_right_g
             )
 
             # Convert overlap in global numbers to (local) numbers specific to the

torchrec/datasets/random.py

+2
@@ -33,6 +33,7 @@ def __init__(
         *,
         min_ids_per_features: Optional[List[int]] = None,
     ) -> None:
+
         self.keys = keys
         self.keys_length: int = len(keys)
         self.batch_size = batch_size
@@ -75,6 +76,7 @@ def __next__(self) -> Batch:
         return batch
 
     def _generate_batch(self) -> Batch:
+
         values = []
         lengths = []
         for key_idx, _ in enumerate(self.keys):

torchrec/datasets/test_utils/criteo_test_utils.py

+1
@@ -103,6 +103,7 @@ def _create_dataset_npys(
         labels: Optional[np.ndarray] = None,
     ) -> Generator[Tuple[str, ...], None, None]:
         with tempfile.TemporaryDirectory() as tmpdir:
+
             if filenames is None:
                 filenames = [filename]

torchrec/distributed/batched_embedding_kernel.py

+17 -12
@@ -785,7 +785,9 @@ def purge(self) -> None:
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, param in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -897,7 +899,9 @@ def named_parameters(
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(),
@@ -1078,9 +1082,8 @@ def named_parameters(
         combined_key = "/".join(
             [config.name for config in self._config.embedding_tables]
         )
-        yield (
-            append_prefix(prefix, f"{combined_key}.weight"),
-            cast(nn.Parameter, self._emb_module.weights),
+        yield append_prefix(prefix, f"{combined_key}.weight"), cast(
+            nn.Parameter, self._emb_module.weights
         )
 
 
@@ -1098,8 +1101,7 @@ def __init__(
         self._pg = pg
 
         self._pooling: PoolingMode = pooling_type_to_pooling_mode(
-            config.pooling,
-            sharding_type,  # pyre-ignore[6]
+            config.pooling, sharding_type  # pyre-ignore[6]
         )
 
         self._local_rows: List[int] = []
@@ -1218,7 +1220,9 @@ def purge(self) -> None:
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, torch.Tensor]]:
-        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.emb_module.split_embedding_weights(),
@@ -1358,7 +1362,9 @@ def named_parameters(
     def named_split_embedding_weights(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
-        assert remove_duplicate, "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
+        assert (
+            remove_duplicate
+        ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
             self.split_embedding_weights(),
@@ -1561,7 +1567,6 @@ def named_parameters(
         combined_key = "/".join(
             [config.name for config in self._config.embedding_tables]
         )
-        yield (
-            append_prefix(prefix, f"{combined_key}.weight"),
-            cast(nn.Parameter, self._emb_module.weights),
+        yield append_prefix(prefix, f"{combined_key}.weight"), cast(
+            nn.Parameter, self._emb_module.weights
         )

torchrec/distributed/benchmark/benchmark_inference.py

+3 -1
@@ -250,7 +250,9 @@ def main() -> None:
         mb = int(float(num * dim) / 1024 / 1024)
         tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] u8: {mb:6}Mb"
 
-    report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    report: str = (
+        f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    )
     report += f"Module: {module_name}\n"
     report += tables_info
     report += "\n"

torchrec/distributed/benchmark/benchmark_train.py

+6 -2
@@ -157,7 +157,9 @@ def main() -> None:
         tables_info += f"\nTABLE[{i}][{num:9}, {dim:4}] {mb:6}Mb"
 
     ### Benchmark no VBE
-    report: str = f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    report: str = (
+        f"REPORT BENCHMARK {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    )
     report += f"Module: {module_name}\n"
     report += tables_info
     report += "\n"
@@ -179,7 +181,9 @@ def main() -> None:
     )
 
     ### Benchmark with VBE
-    report: str = f"REPORT BENCHMARK (VBE) {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    report: str = (
+        f"REPORT BENCHMARK (VBE) {datetime_sfx} world_size:{args.world_size} batch_size:{args.batch_size}\n"
+    )
     report += f"Module: {module_name} (VBE)\n"
     report += tables_info
     report += "\n"

torchrec/distributed/benchmark/benchmark_utils.py

+4 -2
@@ -128,7 +128,6 @@ def __str__(self) -> str:
 @dataclass
 class BenchmarkResult:
     "Class for holding results of benchmark runs"
-
     short_name: str
     elapsed_time: torch.Tensor  # milliseconds
     mem_stats: List[MemoryStats]  # memory stats per rank
@@ -555,7 +554,9 @@ def fx_script_module(eager_module: torch.nn.Module) -> torch.nn.Module:
         return fx_script_module(
             # pyre-fixme[6]: For 1st argument expected `Module` but got
             # `Optional[Module]`.
-            sharded_module if not benchmark_unsharded_module else module
+            sharded_module
+            if not benchmark_unsharded_module
+            else module
         )
     else:
         # pyre-fixme[7]: Expected `Module` but got `Optional[Module]`.
@@ -966,6 +967,7 @@ def multi_process_benchmark(
     # pyre-ignore
     **kwargs,
 ) -> BenchmarkResult:
+
     def setUp() -> None:
         if "MASTER_ADDR" not in os.environ:
             os.environ["MASTER_ADDR"] = str("localhost")

torchrec/distributed/comm_ops.py

+1
@@ -477,6 +477,7 @@ def variable_batch_alltoall_pooled(
     group: Optional[dist.ProcessGroup] = None,
     codecs: Optional[QuantizedCommCodecs] = None,
 ) -> Awaitable[Tensor]:
+
     if group is None:
         group = dist.distributed_c10d._get_default_group()

torchrec/distributed/composable/tests/test_embedding.py

+1
@@ -210,6 +210,7 @@ def test_sharding_ebc(
         use_apply_optimizer_in_backward: bool,
         use_index_dedup: bool,
     ) -> None:
+
         WORLD_SIZE = 2
 
         embedding_config = [

torchrec/distributed/composable/tests/test_embeddingbag.py

+1
@@ -292,6 +292,7 @@ def test_sharding_ebc(
         sharding_type: str,
         use_apply_optimizer_in_backward: bool,
     ) -> None:
+
         # TODO DistributedDataParallel needs full support of registering fused optims before we can enable this.
         assume(
             not (

torchrec/distributed/embedding.py

+2
@@ -167,6 +167,7 @@ def create_sharding_infos_by_sharding(
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     fused_params: Optional[Dict[str, Any]],
 ) -> Dict[str, List[EmbeddingShardingInfo]]:
+
     if fused_params is None:
         fused_params = {}
 
@@ -248,6 +249,7 @@ def create_sharding_infos_by_sharding_device_group(
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     fused_params: Optional[Dict[str, Any]],
 ) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+
     if fused_params is None:
         fused_params = {}

torchrec/distributed/embedding_kernel.py

+1 -9
@@ -105,15 +105,7 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
 
         assert embedding_table.local_rows == param.size(  # pyre-ignore[16]
             0
-        ), (
-            # pyre-fixme[16]: Item `Tuple` of `PartiallyMaterializedTensor | Tensor
-            # | Module | Tuple[Tensor, Optional[Tensor], Optional[Tensor]]` has no
-            # attribute `size`.
-            # pyre-fixme[16]: Item `Tuple` of `PartiallyMaterializedTensor | Tensor
-            # | Module | Tuple[Tensor, Optional[Tensor], Optional[Tensor]]` has no
-            # attribute `shape`.
-            f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}"
-        )
+        ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}"  # pyre-ignore[16]
 
         if qscale is not None:
             assert embedding_table.local_cols == param.size(1)  # pyre-ignore[16]

torchrec/distributed/embedding_tower_sharding.py

+7 -1
@@ -237,6 +237,7 @@ def input_dist(
         features: KeyedJaggedTensor,
         optional_features: Optional[KeyedJaggedTensor] = None,
     ) -> Awaitable[Awaitable[KJTList]]:
+
         # optional_features are populated only if both kjt and weighted kjt present in tower
         if self._wkjt_feature_names and self._kjt_feature_names:
             kjt_features = features
@@ -505,7 +506,9 @@ def __init__(
                 if lt_tables.issubset(pt_tables):
                     found_physical_tower = True
                     break
-            assert found_physical_tower, f"tables in a logical tower must be in the same physical tower, logical tower tables: {lt_tables}, tables_per_pt: {tables_per_pt}"
+            assert (
+                found_physical_tower
+            ), f"tables in a logical tower must be in the same physical tower, logical tower tables: {lt_tables}, tables_per_pt: {tables_per_pt}"
 
         logical_to_physical_order: List[List[int]] = [
             [] for _ in range(self._cross_pg_world_size)
@@ -604,6 +607,7 @@ def _create_input_dist(
         kjt_feature_names: List[str],
         wkjt_feature_names: List[str],
     ) -> None:
+
         if self._kjt_feature_names != kjt_feature_names:
             self._has_kjt_features_permute = True
             for f in self._kjt_feature_names:
@@ -940,6 +944,7 @@ def __init__(
        fused_params: Optional[Dict[str, Any]] = None,
        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
    ) -> None:
+
        super().__init__(
            fused_params=fused_params, qcomm_codecs_registry=qcomm_codecs_registry
        )
@@ -955,6 +960,7 @@ def shard(
        device: Optional[torch.device] = None,
        module_fqn: Optional[str] = None,
    ) -> ShardedEmbeddingTowerCollection:
+
        return ShardedEmbeddingTowerCollection(
            module=module,
            table_name_to_parameter_sharding=params,

torchrec/distributed/embedding_types.py

+1
@@ -514,6 +514,7 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
         return types
 
     def shardable_parameters(self, module: M) -> Dict[str, nn.Parameter]:
+
         shardable_params: Dict[str, nn.Parameter] = {}
         for name, param in module.state_dict().items():
             if name.endswith(".weight"):

torchrec/distributed/embeddingbag.py

+2
@@ -208,6 +208,7 @@ def create_sharding_infos_by_sharding(
     fused_params: Optional[Dict[str, Any]],
     suffix: Optional[str] = "weight",
 ) -> Dict[str, List[EmbeddingShardingInfo]]:
+
     if fused_params is None:
         fused_params = {}
 
@@ -312,6 +313,7 @@ def create_sharding_infos_by_sharding_device_group(
     fused_params: Optional[Dict[str, Any]],
     suffix: Optional[str] = "weight",
 ) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+
     if fused_params is None:
         fused_params = {}

torchrec/distributed/fp_embeddingbag.py

+2
@@ -115,6 +115,7 @@ def compute(
         ctx: EmbeddingBagCollectionContext,
         dist_input: KJTList,
     ) -> List[torch.Tensor]:
+
         fp_features = self.apply_feature_processors_to_kjt_list(dist_input)
         return self._embedding_bag_collection.compute(ctx, fp_features)
 
@@ -185,6 +186,7 @@ def shard(
         device: Optional[torch.device] = None,
         module_fqn: Optional[str] = None,
     ) -> ShardedFeatureProcessedEmbeddingBagCollection:
+
         if device is None:
             device = torch.device("cuda")

0 commit comments
