enable tp on CPU #36299

Merged: 21 commits, Mar 31, 2025
14 changes: 8 additions & 6 deletions docs/source/en/perf_infer_gpu_multi.md
@@ -44,11 +44,6 @@ import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# initialize distributed environment
rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
torch.distributed.init_process_group("nccl", device_id=device)

# enable tensor parallelism
model = AutoModelForCausalLM.from_pretrained(
@@ -59,7 +54,7 @@ model = AutoModelForCausalLM.from_pretrained(
# prepare input tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# distributed run
outputs = model(inputs)
@@ -71,6 +66,13 @@ Launch the inference script above on [torchrun](https://pytorch.org/docs/stable/
torchrun --nproc-per-node 4 demo.py
```

For CPU, bind each rank to a different socket. For example, if you are using 4th Gen Intel Xeon:
```bash
export OMP_NUM_THREADS=56
numactl -C 0-55 -m 0 torchrun --nnodes=2 --node_rank=0 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py &
numactl -C 56-111 -m 1 torchrun --nnodes=2 --node_rank=1 --master_addr="127.0.0.1" --master_port=29500 --nproc-per-node 1 demo.py &
wait
```
The CPU benchmark data will be released soon.

You can benefit from considerable speed ups for inference, especially for inputs with large batch size or long sequences.

For a single forward pass on [Llama](./model_doc/llama) with a sequence length of 512 and various batch sizes, you can expect the following speed ups.
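For reference, a consolidated version of the updated `demo.py` from the doc above might look like the sketch below. It is not part of the diff: the dtype choice is illustrative, and it assumes `tp_plan="auto"` now takes care of initializing the process group on both CUDA and CPU, as the `modeling_utils.py` change further down implements.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# No manual torch.distributed setup: with tp_plan="auto" the library detects
# the accelerator (CUDA or CPU) and initializes the process group itself.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype="auto",  # illustrative; pick the dtype that fits your hardware
    tp_plan="auto",
)

# Tokenize on CPU, then move the input ids to wherever the sharded model lives.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
inputs = tokenizer("Can I help", return_tensors="pt").input_ids.to(model.device)

# Distributed forward pass; launch with torchrun (GPU) or numactl + torchrun (CPU).
outputs = model(inputs)
```

The script is launched exactly as shown in the documentation: `torchrun --nproc-per-node 4 demo.py` on GPU, or the pair of `numactl` commands on CPU.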
29 changes: 20 additions & 9 deletions src/transformers/modeling_utils.py
@@ -774,7 +774,8 @@ def _load_state_dict_into_meta_model(
"""
tensor_device = "cpu"
if device_map is not None and device_map.get("", None) is not None:
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
if device_map[""] not in ("cpu", torch.device("cpu")):
tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
if device_map is not None:
device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])

@@ -4110,24 +4111,34 @@ def from_pretrained(
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")

# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type

if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
if device_type == "cuda":
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "cpu":
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)

except Exception as e:
raise EnvironmentError(
"We tried to initialize torch.distributed for you, but it failed, make"
"sure you init torch distributed in your script to use `tp_plan='auto'`"
) from e

# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type
tp_device = torch.device(device_type, torch.cuda.current_device())
if tp_device.index > 0:
# Get device with index assuming equal number of devices per host
index = None if device_type == "cpu" else torch.cuda.current_device()
tp_device = torch.device(device_type, index)

if index is not None and index > 0:
import sys

sys.stdout = open(os.devnull, "w")
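To make the control flow above easier to follow, here is a standalone sketch of the backend-selection logic the `from_pretrained` change implements. The helper name `init_tp_process_group` is hypothetical; the calls mirror the diff: NCCL on CUDA, oneCCL when `CCL_WORKER_COUNT` is set, gloo otherwise.

```python
import os

import torch


def init_tp_process_group() -> torch.device:
    """Hypothetical helper mirroring the from_pretrained() logic in the diff above."""
    # Detect the accelerator; torch._C._get_accelerator() falls back to CPU.
    device_type = torch._C._get_accelerator().type
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    if device_type == "cuda":
        # GPU path: NCCL backend, one process per local device.
        torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size, init_method="env://")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
        return torch.device("cuda", torch.cuda.current_device())

    # CPU path: prefer oneCCL when its worker threads are configured, otherwise gloo.
    cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
    torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
    return torch.device("cpu")
```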
126 changes: 28 additions & 98 deletions tests/tensor_parallel/test_tensor_parallel.py
@@ -12,18 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import tempfile
import textwrap

# TORCH_LOGS=+dtensor CUDA_LAUNCH_BLOCKING=1 TORCH_USE_CUDA_DSA=1 PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py
from transformers import is_torch_available
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from transformers.testing_utils import (
TestCasePlus,
execute_subprocess_async,
get_torch_dist_unique_port,
require_torch_multi_gpu,
)
@@ -33,15 +28,18 @@
import torch


# RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
class TestTensorParallel(TestCasePlus):
nproc_per_node = 2

def torchrun(self, script: str):
"""Run the `script` using `torchrun` command for multi-processing in a subprocess. Captures errors as necessary."""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp:
tmp.write(script)
tmp.flush()
tmp.seek(0)
cmd = (
f"torchrun --nproc_per_node {torch.cuda.device_count()} --master_port {get_torch_dist_unique_port()} {tmp.name}"
f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}"
).split()

# Note that the subprocess will be waited for here, and raise an error if not successful
@@ -50,44 +48,39 @@ def torchrun(self, script: str):
except subprocess.CalledProcessError as e:
raise Exception(f"The following error was captured: {e.stderr}")

@require_torch_multi_gpu
def test_tp(self):
distributed_args = f"""--nproc_per_node={torch.cuda.device_count()}
--master_port={get_torch_dist_unique_port()}
{self.test_file_dir}/test_tp.py
""".split()
output_dir = self.get_auto_remove_tmp_dir()
args = f"--output_dir {output_dir} --report_to none".split()
cmd = ["torchrun"] + distributed_args + args
print(cmd)
execute_subprocess_async(cmd, env=self.get_env())
# successful return here == success - any errors would have caused an error in the sub-call

@require_torch_multi_gpu
def test_loading_memory_consumption(self):
def test_model_forward(self):
script_to_run = textwrap.dedent(
"""
import torch
import os
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model_id = "JackFram/llama-68m"

rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, tp_plan="auto")
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", tp_plan="auto")
torch.distributed.barrier()

# The expected model memory footprint. We add 1 as not all the modules are split (e.g. the embeddings)
expected_model_memory_per_device = (16 / world_size) + 1
overhead_factor = 1.2
has_dtensor = 0
for name, parameter in model.named_parameters():
if isinstance(parameter.data, torch.distributed.tensor.DTensor):
has_dtensor = 1
break

assert has_dtensor == 1, "TP model must have DTensor"

tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Can I help"

inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
outputs = model(inputs)

# Check that we do not use more than the expected sharded size during initialization
if torch.cuda.max_memory_allocated(device) / 1024**3 > expected_model_memory_per_device * overhead_factor:
raise ValueError("Loading the model used more than the expected fraction of model size per device")
next_token_logits = outputs[0][:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
response = tokenizer.decode(next_token)
assert response == "with"

torch.distributed.barrier()
torch.distributed.destroy_process_group()
Expand All @@ -96,69 +89,6 @@ def test_loading_memory_consumption(self):
self.torchrun(script_to_run)


if __name__ == "__main__":
# The script below is meant to be run under torch.distributed, on a machine with multiple GPUs:
# CUDA_VISIBLE_DEVICES=0,1 RUN_SLOW=1 pytest -sv tests/tp/test_tp.py
# or
# PYTHONPATH="src" python -m torch.distributed.run --nproc_per_node 2 ./tests/tp/test_tp.py

if not is_torch_available():
exit(0)

# Test settings
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
bs = 1
seqlen = 4096
# Get distributed settings
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

# Initialize distributed
device = torch.device(f"cuda:{rank}")
torch.distributed.init_process_group("nccl", device_id=device)
device_mesh = torch.distributed.init_device_mesh("cuda", (world_size,))

# Get model config
config = LlamaConfig.from_pretrained(model_id)
config.hidden_size = 2048
config.attention_bias = False
# Instantiate model
with device:
model = LlamaModel(config).to(dtype=torch.float16)

model.eval()
# Tensor Parallel
if world_size > 1:
model.tensor_parallel(device_mesh)
# Run model

inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)

# Test cuda graphing explicitly
with torch.cuda.device(device):
print("Cuda graphing")
with torch.no_grad():
inputs = torch.randint(config.vocab_size, (bs, seqlen), device=device)
# CUDA Graph setup
s = torch.cuda.Stream(device=device)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for i in range(3):
out = model(inputs)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
out = model(inputs)

for _ in range(2):
g.replay()
s.synchronize()

assert out.last_hidden_state.shape == torch.Size([bs, seqlen, config.hidden_size])

# Test compile
with torch.no_grad():
out = model(inputs)
model.forward = torch.compile(model.forward, mode="reduce-overhead")
out = model(inputs)
out = model(inputs)
@require_torch_multi_gpu
class TestTensorParallelCuda(TestTensorParallel):
nproc_per_node = torch.cuda.device_count()
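
With the refactor above, the base `TestTensorParallel` class runs with two ranks and no GPU requirement, while `TestTensorParallelCuda` overrides `nproc_per_node` with the visible GPU count and is gated by `@require_torch_multi_gpu`. A typical invocation, taken from the comment at the top of the test file, would be:

```bash
# On a CPU-only machine this exercises TestTensorParallel over gloo (or ccl);
# TestTensorParallelCuda is skipped when fewer than two GPUs are visible.
RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py
```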