9
9
10
10
#!/usr/bin/env python3
11
11
12
- import os
13
12
import tempfile
14
13
import unittest
15
- import uuid
16
14
17
15
import torch
18
- from torch import distributed as dist
19
16
from torch .distributed ._composable import replicate
20
17
from torch .distributed ._shard .api import ShardedTensor
21
18
from torch .distributed .checkpoint import (
24
21
load_state_dict ,
25
22
save_state_dict ,
26
23
)
27
- from torch .distributed .launcher .api import elastic_launch , LaunchConfig
28
24
from torchrec .distributed .shard import shard as trec_shard , shard_modules
29
25
from torchrec .distributed .sharding_plan import column_wise
26
+ from torchrec .distributed .test_utils .multi_process import (
27
+ MultiProcessContext ,
28
+ MultiProcessTestBase ,
29
+ )
30
30
from torchrec .distributed .test_utils .test_model import ModelInput , TestSparseNN
31
31
from torchrec .modules .embedding_configs import EmbeddingBagConfig
32
32
from torchrec .test_utils import skip_if_asan
33
33
34
34
35
- class DDPTest (unittest . TestCase ):
35
+ class DDPTest (MultiProcessTestBase ):
36
36
@classmethod
def _run_init(cls, rank: int, world_size: int) -> None:
    """Per-rank body of the ``init_params`` test.

    Builds a ``TestSparseNN`` inside a ``MultiProcessContext`` opened with
    the "nccl" backend, replaces the root module's tensor parameters with
    meta-device placeholders, calls ``shard_modules`` with
    ``init_params=True`` and asserts that every resulting parameter lives
    on ``ctx.device``.

    Args:
        rank: rank handed to ``MultiProcessContext``.
        world_size: number of ranks; also scales each table's embedding dim.
    """
    with MultiProcessContext(rank, world_size, "nccl") as ctx:
        num_float_features = 32

        def _table_configs(count, table_prefix, feature_prefix):
            # One config per index; sizes grow with the index and the
            # embedding dim is a multiple of world_size.
            return [
                EmbeddingBagConfig(
                    num_embeddings=(idx + 1) * 10,
                    embedding_dim=(idx + 1) * 4 * world_size,
                    name=table_prefix + str(idx),
                    feature_names=[feature_prefix + str(idx)],
                )
                for idx in range(count)
            ]

        tables = _table_configs(3, "table_", "feature_")
        weighted_tables = _table_configs(
            2, "weighted_table_", "weighted_feature_"
        )
        model = TestSparseNN(
            tables=tables,
            num_float_features=num_float_features,
            weighted_tables=weighted_tables,
            dense_device=ctx.device,
        )

        # Swap tensors for meta-device placeholders so that init_params has
        # something to materialize.
        # NOTE(review): `_parameters` presumably only holds parameters
        # registered directly on the root module, not on submodules —
        # confirm that this is the intended coverage.
        for key, current in model._parameters.items():
            if not isinstance(current, torch.Tensor):
                continue
            model._parameters[key] = torch.nn.Parameter(
                torch.empty_like(current, device="meta"),
                requires_grad=current.requires_grad,
            )

        shard_modules(model, device=ctx.device, init_params=True)

        # init_params is expected to leave every parameter on ctx.device.
        for materialized in model.parameters():
            assert materialized.device == ctx.device
92
78
93
79
@classmethod
def _run(cls, rank: int, world_size: int, path: str) -> None:
    """Per-rank body of the checkpoint round-trip test.

    Inside a ``MultiProcessContext`` opened with the "nccl" backend this
    builds a ``TestSparseNN``, shards both embedding-bag collections
    column-wise across all ranks, wraps ``over`` and ``dense`` with
    ``replicate``, runs one forward/backward pass, saves the state dict
    under ``path``, zeroes the local parameters, and loads the checkpoint
    back.

    Args:
        rank: rank handed to ``MultiProcessContext``.
        world_size: number of ranks; also scales each table's embedding dim.
        path: directory used by the filesystem checkpoint writer/reader.
    """
    with MultiProcessContext(rank, world_size, "nccl") as ctx:
        num_float_features = 32

        def _table_configs(count, table_prefix, feature_prefix):
            # One config per index; sizes grow with the index and the
            # embedding dim is a multiple of world_size.
            return [
                EmbeddingBagConfig(
                    num_embeddings=(idx + 1) * 10,
                    embedding_dim=(idx + 1) * 4 * world_size,
                    name=table_prefix + str(idx),
                    feature_names=[feature_prefix + str(idx)],
                )
                for idx in range(count)
            ]

        def _local_view(param):
            # Plain tensors pass through; a ShardedTensor yields its local
            # tensor, or None when this rank holds no shard of it.
            if not isinstance(param, ShardedTensor):
                return param
            if not param.local_shards():
                return None
            return param.local_tensor()

        tables = _table_configs(3, "table_", "feature_")
        weighted_tables = _table_configs(
            2, "weighted_table_", "weighted_feature_"
        )
        model = TestSparseNN(
            tables=tables,
            num_float_features=num_float_features,
            weighted_tables=weighted_tables,
            dense_device=ctx.device,
        )
        model.sparse.ebc = trec_shard(
            module=model.sparse.ebc,
            device=ctx.device,
            plan=column_wise(ranks=list(range(world_size))),
        )
        model.sparse.weighted_ebc = trec_shard(
            module=model.sparse.weighted_ebc,
            device=ctx.device,
            plan=column_wise(ranks=list(range(world_size))),
        )
        model.over = replicate(model.over)
        model.dense = replicate(model.dense)

        ######## run one iteration ########
        _, local_batch = ModelInput.generate(
            batch_size=8,
            world_size=world_size,
            num_float_features=num_float_features,
            tables=tables,
            weighted_tables=weighted_tables,
        )
        batch = local_batch[0].to(ctx.device)
        outputs = model(batch)
        outputs[1].sum().backward()

        ######## save, wipe, reload ########
        state_dict = model.state_dict()
        writer = FileSystemWriter(path=path)
        reader = FileSystemReader(path=path)
        save_state_dict(state_dict, writer)

        # Record the parameter sum, then zero everything this rank owns so
        # the subsequent load has to restore the values.
        p_sum = torch.zeros(1, device=ctx.device)
        for param in model.parameters():
            with torch.no_grad():
                local = _local_view(param)
                if local is None:
                    continue
                p_sum += local.sum()
                local.zero_()
                assert local.sum() == 0
        load_state_dict(state_dict, reader)
        model.load_state_dict(state_dict)

        p_sum_loaded = torch.zeros(1, device=ctx.device)
        for param in model.parameters():
            with torch.no_grad():
                local = _local_view(param)
                if local is None:
                    continue
                p_sum_loaded += local.sum()
        # TODO: debug why failing on OSS
        # assert p_sum.allclose(p_sum_loaded)
188
160
189
161
@skip_if_asan
190
162
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@@ -195,18 +167,10 @@ def _run(cls, path: str) -> None:
195
167
)
196
168
def test_checkpoint(self) -> None:
    """Run ``_run`` through the multi-process harness with a temp checkpoint dir.

    The temporary directory is passed to ``_run`` as its ``path`` argument
    and is removed when the ``with`` block exits.
    """
    with tempfile.TemporaryDirectory() as checkpoint_dir:
        self._run_multi_process_test(callable=self._run, path=checkpoint_dir)
210
174
211
175
@skip_if_asan
212
176
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@@ -216,15 +180,6 @@ def test_checkpoint(self) -> None:
216
180
"Not enough GPUs, this test requires at least two GPUs" ,
217
181
)
218
182
def test_init_params(self) -> None:
    """Run ``_run_init`` through the multi-process harness."""
    self._run_multi_process_test(callable=self._run_init)
0 commit comments