Skip to content

Commit ab27f16

Browse files
dstaay-fb authored and facebook-github-bot committed
Modernize DDP tests (#2658)
Summary: Pull Request resolved: #2658. These tests are timing out on internal testing infra; modernizing them to resolve the timeouts. Reviewed By: sarckk. Differential Revision: D67720829. fbshipit-source-id: b4508a18b25b81365ea2a3af727cbd36c861bf2a
1 parent f059a49 commit ab27f16

File tree

1 file changed

+129
-174
lines changed

1 file changed

+129
-174
lines changed

Diff for: torchrec/distributed/composable/tests/test_ddp.py

+129-174
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,10 @@
99

1010
#!/usr/bin/env python3
1111

12-
import os
1312
import tempfile
1413
import unittest
15-
import uuid
1614

1715
import torch
18-
from torch import distributed as dist
1916
from torch.distributed._composable import replicate
2017
from torch.distributed._shard.api import ShardedTensor
2118
from torch.distributed.checkpoint import (
@@ -24,167 +21,142 @@
2421
load_state_dict,
2522
save_state_dict,
2623
)
27-
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
2824
from torchrec.distributed.shard import shard as trec_shard, shard_modules
2925
from torchrec.distributed.sharding_plan import column_wise
26+
from torchrec.distributed.test_utils.multi_process import (
27+
MultiProcessContext,
28+
MultiProcessTestBase,
29+
)
3030
from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN
3131
from torchrec.modules.embedding_configs import EmbeddingBagConfig
3232
from torchrec.test_utils import skip_if_asan
3333

3434

35-
class DDPTest(unittest.TestCase):
35+
class DDPTest(MultiProcessTestBase):
3636
@classmethod
37-
def _run_init_parameters(cls, path: str) -> None:
38-
rank = int(os.environ["LOCAL_RANK"])
39-
world_size = int(os.environ["WORLD_SIZE"])
40-
if torch.cuda.is_available():
41-
device: torch.device = torch.device(f"cuda:{rank}")
42-
backend = "nccl"
43-
torch.cuda.set_device(device)
44-
else:
45-
device: torch.device = torch.device("cpu")
46-
backend = "gloo"
47-
dist.init_process_group(
48-
backend=backend,
49-
rank=rank,
50-
world_size=world_size,
51-
init_method=f"file://{os.path.join(path, 'dist_rdvz')}",
52-
)
53-
num_float_features = 32
54-
55-
tables = [
56-
EmbeddingBagConfig(
57-
num_embeddings=(i + 1) * 10,
58-
embedding_dim=(i + 1) * 4 * world_size,
59-
name="table_" + str(i),
60-
feature_names=["feature_" + str(i)],
61-
)
62-
for i in range(3)
63-
]
64-
weighted_tables = [
65-
EmbeddingBagConfig(
66-
num_embeddings=(i + 1) * 10,
67-
embedding_dim=(i + 1) * 4 * world_size,
68-
name="weighted_table_" + str(i),
69-
feature_names=["weighted_feature_" + str(i)],
70-
)
71-
for i in range(2)
72-
]
73-
m = TestSparseNN(
74-
tables=tables,
75-
num_float_features=num_float_features,
76-
weighted_tables=weighted_tables,
77-
dense_device=device,
78-
)
79-
# Put all tensors on meta device, then init_params should
80-
# materialize them.
81-
for name, param in m._parameters.items():
82-
if isinstance(param, torch.Tensor):
83-
m._parameters[name] = torch.nn.Parameter(
84-
torch.empty_like(param, device="meta"),
85-
requires_grad=param.requires_grad,
37+
def _run_init(cls, rank: int, world_size: int) -> None:
38+
with MultiProcessContext(rank, world_size, "nccl") as ctx:
39+
num_float_features = 32
40+
41+
tables = [
42+
EmbeddingBagConfig(
43+
num_embeddings=(i + 1) * 10,
44+
embedding_dim=(i + 1) * 4 * world_size,
45+
name="table_" + str(i),
46+
feature_names=["feature_" + str(i)],
8647
)
87-
88-
shard_modules(m, device=device, init_params=True)
89-
# init_params should move m to `device`
90-
for p in m.parameters():
91-
assert p.device == device
48+
for i in range(3)
49+
]
50+
weighted_tables = [
51+
EmbeddingBagConfig(
52+
num_embeddings=(i + 1) * 10,
53+
embedding_dim=(i + 1) * 4 * world_size,
54+
name="weighted_table_" + str(i),
55+
feature_names=["weighted_feature_" + str(i)],
56+
)
57+
for i in range(2)
58+
]
59+
m = TestSparseNN(
60+
tables=tables,
61+
num_float_features=num_float_features,
62+
weighted_tables=weighted_tables,
63+
dense_device=ctx.device,
64+
)
65+
# Put all tensors on meta device, then init_params should
66+
# materialize them.
67+
for name, param in m._parameters.items():
68+
if isinstance(param, torch.Tensor):
69+
m._parameters[name] = torch.nn.Parameter(
70+
torch.empty_like(param, device="meta"),
71+
requires_grad=param.requires_grad,
72+
)
73+
74+
shard_modules(m, device=ctx.device, init_params=True)
75+
# init_params should move m to `device`
76+
for p in m.parameters():
77+
assert p.device == ctx.device
9278

9379
@classmethod
94-
def _run(cls, path: str) -> None:
95-
rank = int(os.environ["LOCAL_RANK"])
96-
world_size = int(os.environ["WORLD_SIZE"])
97-
if torch.cuda.is_available():
98-
device: torch.device = torch.device(f"cuda:{rank}")
99-
backend = "nccl"
100-
torch.cuda.set_device(device)
101-
else:
102-
device: torch.device = torch.device("cpu")
103-
backend = "gloo"
104-
dist.init_process_group(
105-
backend=backend,
106-
rank=rank,
107-
world_size=world_size,
108-
init_method=f"file://{os.path.join(path, 'dist_rdvz')}",
109-
)
110-
num_float_features = 32
111-
112-
tables = [
113-
EmbeddingBagConfig(
114-
num_embeddings=(i + 1) * 10,
115-
embedding_dim=(i + 1) * 4 * world_size,
116-
name="table_" + str(i),
117-
feature_names=["feature_" + str(i)],
80+
def _run(cls, rank: int, world_size: int, path: str) -> None:
81+
with MultiProcessContext(rank, world_size, "nccl") as ctx:
82+
num_float_features = 32
83+
84+
tables = [
85+
EmbeddingBagConfig(
86+
num_embeddings=(i + 1) * 10,
87+
embedding_dim=(i + 1) * 4 * world_size,
88+
name="table_" + str(i),
89+
feature_names=["feature_" + str(i)],
90+
)
91+
for i in range(3)
92+
]
93+
weighted_tables = [
94+
EmbeddingBagConfig(
95+
num_embeddings=(i + 1) * 10,
96+
embedding_dim=(i + 1) * 4 * world_size,
97+
name="weighted_table_" + str(i),
98+
feature_names=["weighted_feature_" + str(i)],
99+
)
100+
for i in range(2)
101+
]
102+
m = TestSparseNN(
103+
tables=tables,
104+
num_float_features=num_float_features,
105+
weighted_tables=weighted_tables,
106+
dense_device=ctx.device,
118107
)
119-
for i in range(3)
120-
]
121-
weighted_tables = [
122-
EmbeddingBagConfig(
123-
num_embeddings=(i + 1) * 10,
124-
embedding_dim=(i + 1) * 4 * world_size,
125-
name="weighted_table_" + str(i),
126-
feature_names=["weighted_feature_" + str(i)],
108+
m.sparse.ebc = trec_shard(
109+
module=m.sparse.ebc,
110+
device=ctx.device,
111+
plan=column_wise(ranks=list(range(world_size))),
127112
)
128-
for i in range(2)
129-
]
130-
m = TestSparseNN(
131-
tables=tables,
132-
num_float_features=num_float_features,
133-
weighted_tables=weighted_tables,
134-
dense_device=device,
135-
)
136-
m.sparse.ebc = trec_shard(
137-
module=m.sparse.ebc,
138-
device=device,
139-
plan=column_wise(ranks=list(range(world_size))),
140-
)
141-
m.sparse.weighted_ebc = trec_shard(
142-
module=m.sparse.weighted_ebc,
143-
device=device,
144-
plan=column_wise(ranks=list(range(world_size))),
145-
)
146-
m.over = replicate(m.over)
147-
m.dense = replicate(m.dense)
148-
149-
######## run one iteration ########
150-
_, local_batch = ModelInput.generate(
151-
batch_size=8,
152-
world_size=world_size,
153-
num_float_features=num_float_features,
154-
tables=tables,
155-
weighted_tables=weighted_tables,
156-
)
157-
batch = local_batch[0].to(device)
158-
m(batch)[1].sum().backward()
159-
160-
state_dict = m.state_dict()
161-
writer = FileSystemWriter(path=path)
162-
reader = FileSystemReader(path=path)
163-
save_state_dict(state_dict, writer)
164-
165-
p_sum = torch.zeros(1, device=device)
166-
for p in m.parameters():
167-
with torch.no_grad():
168-
if isinstance(p, ShardedTensor):
169-
if not p.local_shards():
170-
continue
171-
p = p.local_tensor()
172-
p_sum += p.sum()
173-
p.zero_()
174-
assert p.sum() == 0
175-
load_state_dict(state_dict, reader)
176-
m.load_state_dict(state_dict)
177-
178-
p_sum_loaded = torch.zeros(1, device=device)
179-
for p in m.parameters():
180-
with torch.no_grad():
181-
if isinstance(p, ShardedTensor):
182-
if not p.local_shards():
183-
continue
184-
p = p.local_tensor()
185-
p_sum_loaded += p.sum()
186-
# TODO: debug why failing on OSS
187-
# assert p_sum.allclose(p_sum_loaded)
113+
m.sparse.weighted_ebc = trec_shard(
114+
module=m.sparse.weighted_ebc,
115+
device=ctx.device,
116+
plan=column_wise(ranks=list(range(world_size))),
117+
)
118+
m.over = replicate(m.over)
119+
m.dense = replicate(m.dense)
120+
121+
######## run one iteration ########
122+
_, local_batch = ModelInput.generate(
123+
batch_size=8,
124+
world_size=world_size,
125+
num_float_features=num_float_features,
126+
tables=tables,
127+
weighted_tables=weighted_tables,
128+
)
129+
batch = local_batch[0].to(ctx.device)
130+
m(batch)[1].sum().backward()
131+
132+
state_dict = m.state_dict()
133+
writer = FileSystemWriter(path=path)
134+
reader = FileSystemReader(path=path)
135+
save_state_dict(state_dict, writer)
136+
137+
p_sum = torch.zeros(1, device=ctx.device)
138+
for p in m.parameters():
139+
with torch.no_grad():
140+
if isinstance(p, ShardedTensor):
141+
if not p.local_shards():
142+
continue
143+
p = p.local_tensor()
144+
p_sum += p.sum()
145+
p.zero_()
146+
assert p.sum() == 0
147+
load_state_dict(state_dict, reader)
148+
m.load_state_dict(state_dict)
149+
150+
p_sum_loaded = torch.zeros(1, device=ctx.device)
151+
for p in m.parameters():
152+
with torch.no_grad():
153+
if isinstance(p, ShardedTensor):
154+
if not p.local_shards():
155+
continue
156+
p = p.local_tensor()
157+
p_sum_loaded += p.sum()
158+
# TODO: debug why failing on OSS
159+
# assert p_sum.allclose(p_sum_loaded)
188160

189161
@skip_if_asan
190162
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@@ -195,18 +167,10 @@ def _run(cls, path: str) -> None:
195167
)
196168
def test_checkpoint(self) -> None:
197169
with tempfile.TemporaryDirectory() as path:
198-
lc = LaunchConfig(
199-
min_nodes=1,
200-
max_nodes=1,
201-
nproc_per_node=2,
202-
run_id=str(uuid.uuid4()),
203-
rdzv_backend="c10d",
204-
rdzv_endpoint="localhost:0",
205-
start_method="spawn",
206-
monitor_interval=1,
207-
max_restarts=0,
170+
self._run_multi_process_test(
171+
callable=self._run,
172+
path=path,
208173
)
209-
elastic_launch(config=lc, entrypoint=self._run)(path)
210174

211175
@skip_if_asan
212176
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@@ -216,15 +180,6 @@ def test_checkpoint(self) -> None:
216180
"Not enough GPUs, this test requires at least two GPUs",
217181
)
218182
def test_init_params(self) -> None:
219-
with tempfile.TemporaryDirectory() as path:
220-
lc = LaunchConfig(
221-
min_nodes=1,
222-
max_nodes=1,
223-
nproc_per_node=2,
224-
run_id=str(uuid.uuid4()),
225-
rdzv_backend="c10d",
226-
start_method="spawn",
227-
monitor_interval=1,
228-
max_restarts=0,
229-
)
230-
elastic_launch(config=lc, entrypoint=self._run_init_parameters)(path)
183+
self._run_multi_process_test(
184+
callable=self._run_init,
185+
)

0 commit comments

Comments
 (0)