From 8a76b421a0f04adcfd8fc153a3feded231570b02 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 25 Feb 2025 17:07:30 -0800
Subject: [PATCH 1/3] Add distributed support in the CIFAR example

---
 cifar/dataset.py | 11 +++++-
 cifar/main.py    | 91 +++++++++++++++++++++++++++++++-----------------
 2 files changed, 70 insertions(+), 32 deletions(-)

diff --git a/cifar/dataset.py b/cifar/dataset.py
index 22b229f88..8967591e9 100644
--- a/cifar/dataset.py
+++ b/cifar/dataset.py
@@ -1,3 +1,4 @@
+import mlx.core as mx
 import numpy as np
 from mlx.data.datasets import load_cifar10
 
@@ -12,8 +13,11 @@ def normalize(x):
         x = x.astype("float32") / 255.0
         return (x - mean) / std
 
+    group = mx.distributed.init()
+
     tr_iter = (
         tr.shuffle()
+        .partition_if(group.size() > 1, group.size(), group.rank())
         .to_stream()
         .image_random_h_flip("image", prob=0.5)
         .pad("image", 0, 4, 4, 0.0)
@@ -25,6 +29,11 @@ def normalize(x):
     )
 
     test = load_cifar10(root=root, train=False)
-    test_iter = test.to_stream().key_transform("image", normalize).batch(batch_size)
+    test_iter = (
+        test.to_stream()
+        .partition_if(group.size() > 1, group.size(), group.rank())
+        .key_transform("image", normalize)
+        .batch(batch_size)
+    )
 
     return tr_iter, test_iter
diff --git a/cifar/main.py b/cifar/main.py
index 378bc4241..270741337 100644
--- a/cifar/main.py
+++ b/cifar/main.py
@@ -23,6 +23,13 @@
 parser.add_argument("--cpu", action="store_true", help="use cpu only")
 
 
+def print_zero(group, *args, **kwargs):
+    if group.rank() != 0:
+        return
+    flush = kwargs.pop("flush", True)
+    print(*args, **kwargs, flush=flush)
+
+
 def eval_fn(model, inp, tgt):
     return mx.mean(mx.argmax(model(inp), axis=1) == tgt)
 
@@ -34,9 +41,23 @@ def train_step(model, inp, tgt):
         acc = mx.mean(mx.argmax(output, axis=1) == tgt)
         return loss, acc
 
-    losses = []
-    accs = []
-    samples_per_sec = []
+    world = mx.distributed.init()
+    losses = 0
+    accuracies = 0
+    samples_per_sec = 0
+    count = 0
+
+    def average_stats(stats, count):
+        if world.size() == 1:
+            return [s / count for s in stats]
+
+        with mx.stream(mx.cpu):
+            stats = mx.distributed.all_sum(mx.array(stats))
+            count = mx.distributed.all_sum(count)
+            mx.eval(stats, count)
+            count = count.item()
+
+        return [s / count for s in stats.tolist()]
 
     state = [model.state, optimizer.state]
 
@@ -44,6 +65,7 @@
     def step(inp, tgt):
         train_step_fn = nn.value_and_grad(model, train_step)
         (loss, acc), grads = train_step_fn(model, inp, tgt)
+        grads = nn.utils.average_gradients(grads)
         optimizer.update(model, grads)
         return loss, acc
 
@@ -52,69 +74,76 @@ def step(inp, tgt):
         y = mx.array(batch["label"])
         tic = time.perf_counter()
         loss, acc = step(x, y)
-        mx.eval(state)
+        mx.eval(loss, acc, state)
         toc = time.perf_counter()
-        loss = loss.item()
-        acc = acc.item()
-        losses.append(loss)
-        accs.append(acc)
-        throughput = x.shape[0] / (toc - tic)
-        samples_per_sec.append(throughput)
+        losses += loss.item()
+        accuracies += acc.item()
+        samples_per_sec += x.shape[0] / (toc - tic)
+        count += 1
         if batch_counter % 10 == 0:
-            print(
+            l, a, s = average_stats([losses, accuracies, samples_per_sec], count)
+            print_zero(
+                world,
                 " | ".join(
                     (
                         f"Epoch {epoch:02d} [{batch_counter:03d}]",
-                        f"Train loss {loss:.3f}",
-                        f"Train acc {acc:.3f}",
-                        f"Throughput: {throughput:.2f} images/second",
+                        f"Train loss {l:.3f}",
+                        f"Train acc {a:.3f}",
+                        f"Throughput: {s:.2f} images/second",
                     )
-                )
+                ),
             )
 
-    mean_tr_loss = mx.mean(mx.array(losses))
-    mean_tr_acc = mx.mean(mx.array(accs))
-    samples_per_sec = mx.mean(mx.array(samples_per_sec))
-    return mean_tr_loss, mean_tr_acc, samples_per_sec
+    return average_stats([losses, accuracies, samples_per_sec], count)
 
 
 def test_epoch(model, test_iter, epoch):
-    accs = []
+    accuracies = 0
+    count = 0
     for batch_counter, batch in enumerate(test_iter):
         x = mx.array(batch["image"])
         y = mx.array(batch["label"])
         acc = eval_fn(model, x, y)
-        acc_value = acc.item()
-        accs.append(acc_value)
-    mean_acc = mx.mean(mx.array(accs))
-    return mean_acc
+        accuracies += acc.item()
+        count += 1
+
+    with mx.stream(mx.cpu):
+        accuracies = mx.distributed.all_sum(accuracies)
+        count = mx.distributed.all_sum(count)
+        return (accuracies / count).item()
 
 
 def main(args):
     mx.random.seed(args.seed)
 
+    # Initialize the distributed group and report the nodes that showed up
+    world = mx.distributed.init()
+    if world.size() > 1:
+        print(f"{world.rank()} of {world.size()}", flush=True)
+
     model = getattr(resnet, args.arch)()
-    print("Number of params: {:0.04f} M".format(model.num_params() / 1e6))
+    print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
 
     optimizer = optim.Adam(learning_rate=args.lr)
     train_data, test_data = get_cifar10(args.batch_size)
 
     for epoch in range(args.epochs):
         tr_loss, tr_acc, throughput = train_epoch(model, train_data, optimizer, epoch)
-        print(
+        print_zero(
+            world,
             " | ".join(
                 (
                     f"Epoch: {epoch}",
-                    f"avg. Train loss {tr_loss.item():.3f}",
-                    f"avg. Train acc {tr_acc.item():.3f}",
-                    f"Throughput: {throughput.item():.2f} images/sec",
+                    f"avg. Train loss {tr_loss:.3f}",
+                    f"avg. Train acc {tr_acc:.3f}",
+                    f"Throughput: {throughput:.2f} images/sec",
                 )
-            )
+            ),
         )
 
         test_acc = test_epoch(model, test_data, epoch)
-        print(f"Epoch: {epoch} | Test acc {test_acc.item():.3f}")
+        print_zero(world, f"Epoch: {epoch} | Test acc {test_acc:.3f}")
 
         train_data.reset()
         test_data.reset()

From 14faec4ca2794d56793a7080b8f3c5fe073c087c Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 25 Feb 2025 17:20:15 -0800
Subject: [PATCH 2/3] Fix the throughput calculation

---
 cifar/main.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/cifar/main.py b/cifar/main.py
index 270741337..7eb6efdf5 100644
--- a/cifar/main.py
+++ b/cifar/main.py
@@ -81,7 +81,10 @@ def step(inp, tgt):
         samples_per_sec += x.shape[0] / (toc - tic)
         count += 1
         if batch_counter % 10 == 0:
-            l, a, s = average_stats([losses, accuracies, samples_per_sec], count)
+            l, a, s = average_stats(
+                [losses, accuracies, world.size() * samples_per_sec],
+                count,
+            )
             print_zero(
                 world,
                 " | ".join(
@@ -94,7 +97,7 @@ def step(inp, tgt):
                 ),
             )
 
-    return average_stats([losses, accuracies, samples_per_sec], count)
+    return average_stats([losses, accuracies, world.size() * samples_per_sec], count)
 
 
 def test_epoch(model, test_iter, epoch):

From d20413a54d0017291ca93390df67d829fdf1694b Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 25 Feb 2025 17:40:24 -0800
Subject: [PATCH 3/3] Add it to the readme and fix the rank printing in main

---
 cifar/README.md | 14 ++++++++++++++
 cifar/main.py   |  2 +-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/cifar/README.md b/cifar/README.md
index 763e641dd..2016200df 100644
--- a/cifar/README.md
+++ b/cifar/README.md
@@ -48,3 +48,17 @@ Note this was run on an M1 Macbook Pro with 16GB RAM.
 
 At the time of writing, `mlx` doesn't have built-in learning rate schedules.
 We intend to update this example once these features are added.
+
+## Distributed training
+
+The example also supports distributed data parallel training. You can launch
+distributed training as follows:
+
+```shell
+$ cat >hostfile.json
+[
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]},
+    {"ssh": "host-to-ssh-to", "ips": ["ip-to-bind-to"]}
+]
+$ mlx.launch --verbose --hostfile hostfile.json main.py --batch 256 --epochs 5 --arch resnet20
+```
diff --git a/cifar/main.py b/cifar/main.py
index 7eb6efdf5..3fe5d2e09 100644
--- a/cifar/main.py
+++ b/cifar/main.py
@@ -122,7 +122,7 @@ def main(args):
     # Initialize the distributed group and report the nodes that showed up
     world = mx.distributed.init()
     if world.size() > 1:
-        print(f"{world.rank()} of {world.size()}", flush=True)
+        print(f"Starting rank {world.rank()} of {world.size()}", flush=True)
 
     model = getattr(resnet, args.arch)()
     print_zero(world, f"Number of params: {model.num_params() / 1e6:0.04f} M")
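
For readers following the series, the reduction pattern introduced in patch 1 (accumulate per-rank totals, `all_sum` them across ranks, then divide by the global count) can be exercised on its own. The sketch below is illustrative only and not part of the patches; the variable names and numeric values are invented, and it assumes the script is started with `mlx.launch` as described in the README hunk (run as a single process it falls back to the `size() == 1` branch).

```python
# Minimal standalone sketch of the cross-rank averaging used by
# train_epoch/average_stats and test_epoch in this patch series.
import mlx.core as mx

world = mx.distributed.init()

# Per-rank running totals (hypothetical values for illustration).
local_stats = [12.5, 0.87]  # e.g. accumulated loss and accuracy on this rank
local_count = 10            # batches processed by this rank

if world.size() == 1:
    averaged = [s / local_count for s in local_stats]
else:
    # Sum the per-rank totals over all ranks, then divide by the global
    # batch count; the reduction is issued on the CPU stream, as in the patch.
    with mx.stream(mx.cpu):
        stats = mx.distributed.all_sum(mx.array(local_stats))
        count = mx.distributed.all_sum(mx.array(local_count))
        mx.eval(stats, count)
    averaged = [s / count.item() for s in stats.tolist()]

print(f"Rank {world.rank()}: averaged stats = {averaged}")
```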