Skip to content

Commit 286393f

Browse files
authored
enable tp on CPU (#36299)
* enable tp on CPU Signed-off-by: jiqing-feng <[email protected]> * get rank from cpu Signed-off-by: jiqing-feng <[email protected]> * update Signed-off-by: jiqing-feng <[email protected]> * enable TP tests Signed-off-by: jiqing-feng <[email protected]> * fix comment Signed-off-by: jiqing-feng <[email protected]> * em print Signed-off-by: jiqing-feng <[email protected]> * fix model id Signed-off-by: jiqing-feng <[email protected]> * fix conflict Signed-off-by: jiqing-feng <[email protected]> * fix index and add doc Signed-off-by: jiqing-feng <[email protected]> --------- Signed-off-by: jiqing-feng <[email protected]>
1 parent 4705b04 commit 286393f

File tree

3 files changed

+56
-113
lines changed

3 files changed

+56
-113
lines changed

docs/source/en/perf_infer_gpu_multi.md

+8-6
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ import os
4444
import torch
4545
from transformers import AutoModelForCausalLM, AutoTokenizer
4646

47-
# initialize distributed environment
48-
rank = int(os.environ["RANK"])
49-
device = torch.device(f"cuda:{rank}")
50-
torch.cuda.set_device(device)
51-
torch.distributed.init_process_group("nccl", device_id=device)
5247

5348
# enable tensor parallelism
5449
model = AutoModelForCausalLM.from_pretrained(
@@ -59,7 +54,7 @@ model = AutoModelForCausalLM.from_pretrained(
5954
# prepare input tokens
6055
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
6156
prompt = "Can I help"
62-
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
57+
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
6358

6459
# distributed run
6560
outputs = model(inputs)
@@ -71,6 +66,13 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
7166
torchrun --nproc-per-node 4 demo.py
7267
```
7368

69+
For CPU, please binding different socket on each rank. For example, if you are using Intel 4th Gen Xeon:
70+
```bash
71+
export OMP_NUM_THREADS=56
72+
numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py & wait
73+
```
74+
The CPU benchmark data will be released soon.
75+
7476
You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.
7577

7678
For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.

src/transformers/modeling_utils.py

+20-9
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,8 @@ def _load_state_dict_into_meta_model(
774774
"""
775775
tensor_device = "cpu"
776776
if device_map is not None and device_map.get("", None) is not None:
777-
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
777+
if device_map[""] not in ("cpu", torch.device("cpu")):
778+
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
778779
if device_map is not None:
779780
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
780781

@@ -4110,24 +4111,34 @@ def from_pretrained(
41104111
if tp_plan is not None:
41114112
if not is_torch_greater_or_equal("2.5"):
41124113
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")
4114+
4115+
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
4116+
device_type = torch._C._get_accelerator().type
4117+
41134118
if not torch.distributed.is_initialized():
41144119
try:
41154120
rank = int(os.environ["RANK"])
41164121
world_size = int(os.environ["WORLD_SIZE"])
4117-
torch.distributed.init_process_group(
4118-
"nccl", rank=rank, world_size=world_size, init_method="env://"
4119-
)
4120-
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
4122+
if device_type == "cuda":
4123+
torch.distributed.init_process_group(
4124+
"nccl", rank=rank, world_size=world_size, init_method="env://"
4125+
)
4126+
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
4127+
elif device_type == "cpu":
4128+
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
4129+
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
4130+
41214131
except Exception as e:
41224132
raise EnvironmentError(
41234133
"We tried to initialize torch.distributed for you, but it failed, make"
41244134
"sure you init torch distributed in your script to use `tp_plan='auto'`"
41254135
) from e
41264136

4127-
# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
4128-
device_type = torch._C._get_accelerator().type
4129-
tp_device = torch.device(device_type, torch.cuda.current_device())
4130-
if tp_device.index > 0:
4137+
# Get device with index assuming equal number of devices per host
4138+
index = None if device_type == "cpu" else torch.cuda.current_device()
4139+
tp_device = torch.device(device_type, index)
4140+
4141+
if index is not None and index > 0:
41314142
import sys
41324143

41334144
sys.stdout = open(os.devnull, "w")

tests/tensor_parallel/test_tensor_parallel.py

+28-98
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615
import subprocess
1716
import tempfile
1817
import textwrap
1918

20-
# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
2119
from transformers import is_torch_available
22-
from transformers.models.llama.configuration_llama import LlamaConfig
23-
from transformers.models.llama.modeling_llama import LlamaModel
2420
from transformers.testing_utils import (
2521
TestCasePlus,
26-
execute_subprocess_async,
2722
get_torch_dist_unique_port,
2823
require_torch_multi_gpu,
2924
)
@@ -33,15 +28,18 @@
3328
import torch
3429

3530

31+
# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
3632
class TestTensorParallel(TestCasePlus):
33+
nproc_per_node = 2
34+
3735
def torchrun(self, script: str):
3836
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
3937
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
4038
tmp.write(script)
4139
tmp.flush()
4240
tmp.seek(0)
4341
cmd = (
44-
f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}"
42+
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
4543
).split()
4644

4745
# Note that the subprocess will be waited for here, and raise an error if not successful
@@ -50,44 +48,39 @@ def torchrun(self, script: str):
5048
except subprocess.CalledProcessError as e:
5149
raise Exception(f"The following error was captured: {e.stderr}")
5250

53-
@require_torch_multi_gpu
54-
def test_tp(self):
55-
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
56-
--master_port={get_torch_dist_unique_port()}
57-
{self.test_file_dir}/test_tp.py
58-
""".split()
59-
output_dir = self.get_auto_remove_tmp_dir()
60-
args = f"--output_dir {output_dir} --report_to none".split()
61-
cmd = ["torchrun"] + distributed_args + args
62-
print(cmd)
63-
execute_subprocess_async(cmd, env=self.get_env())
64-
# successful return here == success - any errors would have caused an error in the sub-call
65-
66-
@require_torch_multi_gpu
67-
def test_loading_memory_consumption(self):
51+
def test_model_forward(self):
6852
script_to_run = textwrap.dedent(
6953
"""
7054
import torch
7155
import os
72-
from transformers import AutoModelForCausalLM
56+
from transformers import AutoModelForCausalLM, AutoTokenizer
7357
74-
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
58+
model_id = "JackFram/llama-68m"
7559
7660
rank = int(os.environ["RANK"])
7761
world_size = int(os.environ["WORLD_SIZE"])
78-
device = torch.device(f"cuda:{rank}")
79-
torch.distributed.init_process_group("nccl", device_id=device)
8062
81-
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
63+
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
8264
torch.distributed.barrier()
8365
84-
# The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)
85-
expected_model_memory_per_device = (16 / world_size) + 1
86-
overhead_factor = 1.2
66+
has_dtensor = 0
67+
for name, parameter in model.named_parameters():
68+
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
69+
has_dtensor = 1
70+
break
71+
72+
assert has_dtensor == 1, "TP model must has DTensor"
73+
74+
tokenizer = AutoTokenizer.from_pretrained(model_id)
75+
prompt = "Can I help"
76+
77+
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
78+
outputs = model(inputs)
8779
88-
# Check that we do not use more than the expected sharded size during initialization
89-
if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor:
90-
raise ValueError("Loading the model used more than the expected fraction of model size per device")
80+
next_token_logits = outputs[0][:, -1, :]
81+
next_token = torch.argmax(next_token_logits, dim=-1)
82+
response = tokenizer.decode(next_token)
83+
assert response == "with"
9184
9285
torch.distributed.barrier()
9386
torch.distributed.destroy_process_group()
@@ -96,69 +89,6 @@ def test_loading_memory_consumption(self):
9689
self.torchrun(script_to_run)
9790

9891

99-
if __name__ == "__main__":
100-
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
101-
# CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py
102-
# or
103-
# PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
104-
105-
if not is_torch_available():
106-
exit(0)
107-
108-
# Test settings
109-
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
110-
bs = 1
111-
seqlen = 4096
112-
# Get distributed settings
113-
rank = int(os.environ["RANK"])
114-
world_size = int(os.environ["WORLD_SIZE"])
115-
116-
# Initialize distributed
117-
device = torch.device(f"cuda:{rank}")
118-
torch.distributed.init_process_group("nccl", device_id=device)
119-
device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
120-
121-
# Get model config
122-
config = LlamaConfig.from_pretrained(model_id)
123-
config.hidden_size = 2048
124-
config.attention_bias = False
125-
# Instantiate model
126-
with device:
127-
model = LlamaModel(config).to(dtype=torch.float16)
128-
129-
model.eval()
130-
# Tensor Parallel
131-
if world_size > 1:
132-
model.tensor_parallel(device_mesh)
133-
# Run model
134-
135-
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
136-
137-
# Test cuda graphing explicitly
138-
with torch.cuda.device(device):
139-
print("Cuda graphing")
140-
with torch.no_grad():
141-
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
142-
# CUDA Graph setup
143-
s = torch.cuda.Stream(device=device)
144-
s.wait_stream(torch.cuda.current_stream())
145-
with torch.cuda.stream(s):
146-
for i in range(3):
147-
out = model(inputs)
148-
torch.cuda.current_stream().wait_stream(s)
149-
g = torch.cuda.CUDAGraph()
150-
with torch.cuda.graph(g):
151-
out = model(inputs)
152-
153-
for _ in range(2):
154-
g.replay()
155-
s.synchronize()
156-
157-
assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])
158-
159-
# Test compile
160-
with torch.no_grad():
161-
out = model(inputs)
162-
model.forward = torch.compile(model.forward, mode="reduce-overhead")
163-
out = model(inputs)
164-
out = model(inputs)
92+
@require_torch_multi_gpu
93+
class TestTensorParallelCuda(TestTensorParallel):
94+
nproc_per_node = torch.cuda.device_count()

0 commit comments

Comments
 (0)