# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
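
# Benchmark for DeepSpeed's offload_states()/reload_states() APIs: it times moving
# ZeRO-3 states to CPU and back and reports the change in accelerator memory.
# Example launch (script name and GPU count are illustrative assumptions, not from the repo):
#   deepspeed --num_gpus=4 offload_states.py --pin_memory --non_blocking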

import time
import argparse

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator
import torch

import deepspeed
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

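# Toy model: a stack of Linear layers feeding a cross-entropy loss, just large enough
# to give ZeRO-3 some parameters, gradients, and optimizer state to offload.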
class SimpleModel(torch.nn.Module):

    def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
        super().__init__()
        self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        for layer in self.linears:
            x = layer(x)
        return self.cross_entropy_loss(x, y)

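# Build an in-memory dataset of random features with random integer labels in [0, hidden_dim).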
def random_dataset(total_samples, hidden_dim, device, dtype):
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    return train_dataset

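# Wrap the random dataset in a DataLoader sized to the engine's micro-batch size.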
def random_dataloader(model, total_samples, hidden_dim, device, dtype):
    batch_size = model.train_micro_batch_size_per_gpu()
    train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader

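# Core benchmark: after each training step, offload the selected ZeRO states to CPU,
# time the offload and the reload, and check that accelerator memory shrinks and grows back.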
def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking, iteration, warmup):
    model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
    data_loader = random_dataloader(model=model,
                                    total_samples=iteration,
                                    hidden_dim=hidden_dim,
                                    device=model.device,
                                    dtype=dtype)

    time_offload_list = []
    time_load_list = []

    dist.barrier()
    for i, batch in enumerate(data_loader):
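        # One forward/backward/optimizer step per iteration; offload and reload are timed below.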
        loss = model(batch[0], batch[1])
        model.backward(loss)
        model.step()

        # Start offloading
        alloc_before_offload = get_accelerator().memory_allocated()
        dist.barrier()

        time_start = time.time()
        model.offload_states(include=include,
                             device=OffloadDeviceEnum.cpu,
                             pin_memory=pin_memory,
                             non_blocking=non_blocking)
        dist.barrier()
        time_after_offload = time.time()
        alloc_after_offload = get_accelerator().memory_allocated()
        assert alloc_after_offload < alloc_before_offload, "Allocated memory should decrease after offload"

        # Load offloaded states back
        model.reload_states()
        dist.barrier()
        time_after_load = time.time()

        time_offload_list.append(time_after_offload - time_start)
        time_load_list.append(time_after_load - time_after_offload)

        assert alloc_after_offload < get_accelerator().memory_allocated(), \
            "Allocated memory should increase again after the offloaded states are reloaded"

        if dist.get_rank() == 0:
            print(
                f"Memory usage ({i}): include={include}, pin_memory={pin_memory}, non_blocking={non_blocking} alloc_before_offload={alloc_before_offload} alloc_after_offload={alloc_after_offload}"
            )

    # Drop the warmup iterations before averaging
    time_offload_list = time_offload_list[warmup:]
    time_load_list = time_load_list[warmup:]

    if dist.get_rank() == 0:
        with open("offload_states.log", "a") as f:
            offload_time = sum(time_offload_list) / len(time_offload_list)
            load_time = sum(time_load_list) / len(time_load_list)
            msg = f"{1 if pin_memory else 0},{1 if non_blocking else 0},{offload_time},{load_time}"
            f.write(f"{msg}\n")
        print(f"Summary: pin_memory={pin_memory} non_blocking={non_blocking} offload={offload_time} load={load_time}")

    # Required with ZeRO-3; skipping destroy() can leak memory
    model.destroy()

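# Parse CLI options, build a minimal ZeRO-3 config, construct the model under zero.Init,
# and run the offload/reload benchmark.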
def main():
    parser = argparse.ArgumentParser(description="Test Offload States")
    parser.add_argument("--included_state",
                        type=str,
                        choices=[e.name for e in OffloadStateTypeEnum] + [None],
                        default=None,
                        help="State to include")
    parser.add_argument("--pin_memory", action='store_true', help="Pin memory")
    parser.add_argument("--non_blocking", action='store_true', help="Non blocking")
    parser.add_argument("--nlayers", type=int, default=1, help="Number of layers")
    parser.add_argument("--hidden_dim", type=int, default=1024, help="Hidden dimension")
    parser.add_argument('--dtype',
                        choices=['torch.bfloat16', 'torch.float16', 'torch.float32'],
                        default='torch.bfloat16',
                        help='Data type')
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank")
    parser.add_argument("--iteration", type=int, default=10, help="Number of iterations")
    parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations excluded from timing")

    args = parser.parse_args()

    dtype = eval(args.dtype)
    hidden_dim = args.hidden_dim

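    # Minimal DeepSpeed config: ZeRO stage 3 with a small Adam optimizer; precision flags are added below.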
    config_dict = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-6
            }
        },
        "zero_optimization": {
            "stage": 3,
        },
    }

    if dtype == torch.float16:
        config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
    elif dtype == torch.bfloat16:
        config_dict["bf16"] = {"enabled": True}

    with deepspeed.zero.Init(config_dict_or_path=config_dict):
        model = SimpleModel(hidden_dim, nlayers=args.nlayers)

    included_state = None if args.included_state is None else [OffloadStateTypeEnum[args.included_state]]
    run_model(model, config_dict, hidden_dim, dtype, included_state, args.pin_memory, args.non_blocking,
              args.iteration, args.warmup)

if __name__ == "__main__":
    main()