Partial support of Apple M1/M2 (via CPU mode) #504

Open
wants to merge 2 commits into base: main
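For orientation, a minimal sketch (not part of this diff; the paths and port are placeholders) of how the patched Llama.build could be invoked in a single process on an Apple M1/M2 without CUDA. The environment variables are only what torch.distributed's default env:// rendezvous needs for a one-process "gloo" group:

import os
from llama import Llama

# single-process rendezvous for torch.distributed.init_process_group("gloo")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

generator = Llama.build(
    ckpt_dir="llama-2-7b/",            # placeholder checkpoint directory
    tokenizer_path="tokenizer.model",  # placeholder tokenizer path
    max_seq_len=128,
    max_batch_size=1,
)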
34 changes: 21 additions & 13 deletions llama/generation.py
@@ -7,6 +7,7 @@
 import time
 from pathlib import Path
 from typing import List, Literal, Optional, Tuple, TypedDict
+from .utils import default_device, model_device, distrubuted_device
 
 import torch
 import torch.nn.functional as F
@@ -58,21 +59,26 @@ def build(
         max_batch_size: int,
         model_parallel_size: Optional[int] = None,
     ) -> "Llama":
 
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            torch.distributed.init_process_group("gloo")
         if not model_parallel_is_initialized():
             if model_parallel_size is None:
                 model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
-            initialize_model_parallel(model_parallel_size)
-
-        local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+            initialize_model_parallel(
+                model_parallel_size,
+                model_parallel_backend=distrubuted_device()
+            )
 
+        if torch.cuda.is_available():
+            local_rank = int(os.environ.get("LOCAL_RANK", 0))
+            torch.cuda.set_device(local_rank)
+            if local_rank > 0:
+                sys.stdout = open(os.devnull, "w")
+        device = default_device()
         # seed must be the same in all processes
         torch.manual_seed(1)
 
-        if local_rank > 0:
-            sys.stdout = open(os.devnull, "w")
-
         start_time = time.time()
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
@@ -81,7 +87,7 @@ def build(
             checkpoints
         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu")
+        checkpoint = torch.load(ckpt_path, map_location=model_device())
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
@@ -92,7 +98,9 @@ def build(
         )
         tokenizer = Tokenizer(model_path=tokenizer_path)
         model_args.vocab_size = tokenizer.n_words
-        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        model_args.device = device
+        if torch.cuda.is_available():
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
         model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -123,14 +131,14 @@ def generate(
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
 
         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=default_device())
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=default_device())
         if logprobs:
-            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+            token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=default_device())
 
         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz, device=default_device())
         input_text_mask = tokens != pad_id
         for cur_pos in range(min_prompt_len, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
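As a side note, the placement pattern above can be exercised on its own; a small sketch, assuming the PR's llama/utils.py is importable as llama.utils, showing that the generation buffers now follow whatever default_device() resolves to:

import torch
from llama.utils import default_device

device = default_device()  # cuda if available, otherwise mps/cpu per utils.py
tokens = torch.full((2, 16), 0, dtype=torch.long, device=device)
eos_reached = torch.tensor([False] * 2, device=device)
print(tokens.device, eos_reached.device)  # prints "cpu cpu" on an M1/M2 without CUDA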
6 changes: 4 additions & 2 deletions llama/model.py
@@ -4,6 +4,7 @@
 import math
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
+from .utils import default_device
 
 import fairscale.nn.model_parallel.initialize as fs_init
 import torch
@@ -95,6 +96,7 @@ def __init__(self, args: ModelArgs):
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads
+        self.device = args.device if args.device is not None else default_device()
 
         self.wq = ColumnParallelLinear(
             args.dim,
@@ -132,15 +134,15 @@ def __init__(self, args: ModelArgs):
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(self.device)
         self.cache_v = torch.zeros(
             (
                 args.max_batch_size,
                 args.max_seq_len,
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(self.device)
 
     def forward(
         self,
37 changes: 37 additions & 0 deletions llama/utils.py
@@ -0,0 +1,37 @@
+import platform
+import torch
+
+# set to False since MPS does not yet support BFloat16, which Llama 2 requires
+enable_mps = False
+
+
+def is_it_apple_arm():
+    if platform.system() != 'Darwin':
+        return False
+    if platform.machine() != 'arm64':
+        return False
+    return True
+
+
+def distrubuted_device():
[Review comment] Small typo, I recommend refactoring it (also here and here).
Suggested change:
-def distrubuted_device():
+def distributed_device():
+    if torch.cuda.is_available():
+        return "nccl"
+    else:
+        return "gloo"
+
+
+def default_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif is_it_apple_arm() and enable_mps:
+        return torch.device("mps")
+    else:
+        return torch.device("cpu")
+
+
+def model_device():
+    if is_it_apple_arm() and enable_mps:
+        return torch.device("mps")
+    else:
+        # for CUDA we also want to use the CPU for the model
+        return torch.device("cpu")
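To summarize what these helpers resolve to, a quick check (assuming the module is importable as llama.utils) that one could run on an Apple M1/M2 without CUDA and with enable_mps left at False:

from llama.utils import default_device, model_device, distrubuted_device

print(distrubuted_device())  # "gloo": backend handed to init_process_group and fairscale
print(default_device())      # cpu: device used for generation-time tensors
print(model_device())        # cpu: map_location used when loading the checkpoint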