Distributed support cifar #1301

Open · wants to merge 3 commits into main
14 changes: 14 additions & 0 deletions cifar/README.md
@@ -48,3 +48,17 @@ Note this was run on an M1 MacBook Pro with 16GB RAM.

At the time of writing, `mlx` doesn't have built-in learning rate schedules.
We intend to update this example once these features are added.

## Distributed training

The example also supports distributed data-parallel training. You can launch
distributed training as follows:

```shell
$ cat >hostfile.json
[
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
{"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
]
$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
```
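
Each launched process joins the same distributed group and can query its rank and the world size. A minimal sketch of what each node sees, using only the `mx.distributed` calls that `main.py` itself relies on:

```python
import mlx.core as mx

group = mx.distributed.init()  # a size-1 group when run without mlx.launch
print(f"Rank {group.rank()} of {group.size()}", flush=True)

# all_sum is the basic collective: every rank receives the sum over all ranks.
x = mx.array([float(group.rank() + 1)])
print(mx.distributed.all_sum(x))  # with 2 ranks: array([3], dtype=float32) on both
```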
11 changes: 10 additions & 1 deletion cifar/dataset.py
@@ -1,3 +1,4 @@
import mlx.core as mx
import numpy as np
from mlx.data.datasets import load_cifar10

@@ -12,8 +13,11 @@ def normalize(x):
x = x.astype("float32") / 255.0
return (x - mean) / std

group = mx.distributed.init()

tr_iter = (
tr.shuffle()
.partition_if(group.size() > 1, group.size(), group.rank())
.to_stream()
.image_random_h_flip("image", prob=0.5)
.pad("image", 0, 4, 4, 0.0)
@@ -25,6 +29,11 @@ def normalize(x):
)

test = load_cifar10(root=root, train=False)
test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
test_iter = (
test.to_stream()
.partition_if(group.size() > 1, group.size(), group.rank())
.key_transform("image", normalize)
.batch(batch_size)
)

return tr_iter, test_iter
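
`partition_if` shards the stream only when more than one process is running, so each rank sees a disjoint `1/world_size` slice of the data while single-process runs are unaffected. Conceptually the split behaves like the toy sketch below (the actual sample ordering is an `mlx.data` implementation detail):

```python
# Toy model of partitioning a dataset across ranks; illustrative only.
def partition(samples, num_partitions, index):
    # One plausible scheme: rank `index` keeps every num_partitions-th sample.
    return samples[index::num_partitions]

samples = list(range(10))
print(partition(samples, 2, 0))  # rank 0 sees [0, 2, 4, 6, 8]
print(partition(samples, 2, 1))  # rank 1 sees [1, 3, 5, 7, 9]
```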
94 changes: 63 additions & 31 deletions cifar/main.py
@@ -23,6 +23,13 @@
parser.add_argument("--cpu", action="store_true", help="use cpu only")


def print_zero(group, *args, **kwargs):
if group.rank() != 0:
return
flush = kwargs.pop("flush", True)
print(*args, **kwargs, flush=flush)
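
(`print_zero` emits a log line only on rank 0, so each message appears once rather than once per process; it flushes by default so output shows up promptly when stdout is forwarded from remote hosts.)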


def eval_fn(model, inp, tgt):
return mx.mean(mx.argmax(model(inp), axis=1) == tgt)

@@ -34,16 +41,31 @@ def train_step(model, inp, tgt):
acc = mx.mean(mx.argmax(output, axis=1) == tgt)
return loss, acc

losses = []
accs = []
samples_per_sec = []
world = mx.distributed.init()
losses = 0
accuracies = 0
samples_per_sec = 0
count = 0

def average_stats(stats, count):
if world.size() == 1:
return [s / count for s in stats]

with mx.stream(mx.cpu):
stats = mx.distributed.all_sum(mx.array(stats))
count = mx.distributed.all_sum(count)
mx.eval(stats, count)
count = count.item()

return [s / count for s in stats.tolist()]
Comment on lines +57 to +60

Member:

Suggested change:

-        mx.eval(stats, count)
-        count = count.item()
-
-    return [s / count for s in stats.tolist()]
+    return (stats / count).tolist()

Member Author:

Haha I don't know what I was thinking 🤦‍♂️
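
With the suggestion applied, the distributed branch of `average_stats` collapses to a single expression. Dividing the all-summed statistics by the all-summed count still gives the correct global averages, since numerator and denominator are both summed over ranks; for example, if rank 0 accumulates a loss sum of 4.0 over 2 batches and rank 1 a loss sum of 6.0 over 3 batches, every rank computes the same global mean 10.0 / 5 = 2.0. The resulting function would read:

```python
def average_stats(stats, count):
    if world.size() == 1:
        return [s / count for s in stats]

    with mx.stream(mx.cpu):
        stats = mx.distributed.all_sum(mx.array(stats))
        count = mx.distributed.all_sum(count)

    return (stats / count).tolist()
```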


state = [model.state, optimizer.state]

@partial(mx.compile, inputs=state, outputs=state)
def step(inp, tgt):
train_step_fn = nn.value_and_grad(model, train_step)
(loss, acc), grads = train_step_fn(model, inp, tgt)
grads = nn.utils.average_gradients(grads)
optimizer.update(model, grads)
return loss, acc
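
`nn.utils.average_gradients` all-reduces the gradients so every rank applies the same parameter update to its model replica. A rough sketch of the idea, not the library's actual implementation:

```python
import mlx.core as mx
from mlx.utils import tree_map

def average_gradients_sketch(grads, group):
    # All-sum each gradient array across ranks, then divide by the world size,
    # leaving identical averaged gradients on every rank.
    if group.size() == 1:
        return grads
    return tree_map(lambda g: mx.distributed.all_sum(g) / group.size(), grads)
```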

@@ -52,69 +74,79 @@ def step(inp, tgt):
y = mx.array(batch["label"])
tic = time.perf_counter()
loss, acc = step(x, y)
mx.eval(state)
mx.eval(loss, acc, state)
toc = time.perf_counter()
loss = loss.item()
acc = acc.item()
losses.append(loss)
accs.append(acc)
throughput = x.shape[0] / (toc - tic)
samples_per_sec.append(throughput)
losses += loss.item()
accuracies += acc.item()
samples_per_sec += x.shape[0] / (toc - tic)
count += 1
if batch_counter % 10 == 0:
print(
l, a, s = average_stats(
[losses, accuracies, world.size() * samples_per_sec],
count,
)
print_zero(
world,
" | ".join(
(
f"Epoch {epoch:02d} [{batch_counter:03d}]",
f"Train loss {loss:.3f}",
f"Train acc {acc:.3f}",
f"Throughput: {throughput:.2f} images/second",
f"Train loss {l:.3f}",
f"Train acc {a:.3f}",
f"Throughput: {s:.2f} images/second",
)
)
),
)

mean_tr_loss = mx.mean(mx.array(losses))
mean_tr_acc = mx.mean(mx.array(accs))
samples_per_sec = mx.mean(mx.array(samples_per_sec))
return mean_tr_loss, mean_tr_acc, samples_per_sec
return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
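
(The local throughput sum is scaled by `world.size()` before averaging, so the reported images/second approximates aggregate cluster throughput rather than the per-rank rate.)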


def test_epoch(model, test_iter, epoch):
accs = []
accuracies = 0
count = 0
for batch_counter, batch in enumerate(test_iter):
x = mx.array(batch["image"])
y = mx.array(batch["label"])
acc = eval_fn(model, x, y)
acc_value = acc.item()
accs.append(acc_value)
mean_acc = mx.mean(mx.array(accs))
return mean_acc
accuracies += acc.item()
count += 1

with mx.stream(mx.cpu):
accuracies = mx.distributed.all_sum(accuracies)
count = mx.distributed.all_sum(count)
return (accuracies / count).item()
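
One caveat: `test_epoch` gives every batch equal weight, which is exact only when all batches have the same size. If the final batch is smaller, counting correct predictions directly would be exact; a hypothetical variant:

```python
def test_epoch_exact(model, test_iter):
    # Hypothetical variant: weight by samples rather than by batches.
    correct = 0
    total = 0
    for batch in test_iter:
        x = mx.array(batch["image"])
        y = mx.array(batch["label"])
        correct += (mx.argmax(model(x), axis=1) == y).sum().item()
        total += y.shape[0]

    with mx.stream(mx.cpu):
        correct = mx.distributed.all_sum(correct)
        total = mx.distributed.all_sum(total)
    return (correct / total).item()
```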


def main(args):
mx.random.seed(args.seed)

# Initialize the distributed group and report the nodes that showed up
world = mx.distributed.init()
if world.size() > 1:
print(f"Starting rank {world.rank()} of {world.size()}", flush=True)

model = getattr(resnet, args.arch)()

print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")

optimizer = optim.Adam(learning_rate=args.lr)

train_data, test_data = get_cifar10(args.batch_size)
for epoch in range(args.epochs):
tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
print(
print_zero(
world,
" | ".join(
(
f"Epoch: {epoch}",
f"avg. Train loss {tr_loss.item():.3f}",
f"avg. Train acc {tr_acc.item():.3f}",
f"Throughput: {throughput.item():.2f} images/sec",
f"avg. Train loss {tr_loss:.3f}",
f"avg. Train acc {tr_acc:.3f}",
f"Throughput: {throughput:.2f} images/sec",
)
)
),
)

test_acc = test_epoch(model, test_data, epoch)
print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")

train_data.reset()
test_data.reset()