Skip to content

Commit fd79b31

Browse files
authored
Example and benchmark of APIs to offload states (#942)
* add benchmarking for offloading states * fix api names
1 parent be0a0e1 commit fd79b31

File tree

4 files changed

+233
-0
lines changed

4 files changed

+233
-0
lines changed

training/offload_states/README.md

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Offloading States Example
2+
3+
The script `offload_states.py` demonstrates how to offload the state of a model. Here is the example usage.
4+
5+
```bash
6+
$ deepspeed --num_gpus=4 offload_states.py --hidden_dim 32768 --nlayers 4 --pin_memory --non_blocking
7+
...
8+
Memory usage (0): include=None, pin_memory=True, non_blocking=True alloc_before_offload=18198419456 alloc_after_offload=17763840
9+
Memory usage (1): include=None, pin_memory=True, non_blocking=True alloc_before_offload=18198760960 alloc_after_offload=17763840
10+
...
11+
Summary: pin_memory=True non_blocking=True offload=5.643414640426636 load=2.4087101459503173
12+
```
13+
14+
`run_benchmark.sh` shows how to run the script with different configurations. The script outputs the time for offloading and loading the states.
15+
16+
```bash
17+
$ ./run_benchmark.sh
18+
...
19+
| |pin_memory=0_non_blocking=0|pin_memory=0_non_blocking=1|pin_memory=1_non_blocking=0|pin_memory=1_non_blocking=1|
20+
|--:|---------------------------|---------------------------|---------------------------|---------------------------|
21+
| 1|4.34 / 3.42 |4.99 / 2.37 |6.5 / 2.42 |6.0 / 2.39 |
22+
| 2|9.9 / 3.28 |5.1 / 2.34 |6.21 / 2.42 |6.25 / 2.45 |
23+
| 3|9.92 / 3.19 |6.71 / 2.35 |6.33 / 2.38 |5.93 / 2.42 |
24+
| 4|9.55 / 2.82 |7.11 / 2.39 |6.9 / 2.38 |6.5 / 2.43 |...
25+
```
+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
import time
7+
import argparse
8+
9+
import deepspeed.comm as dist
10+
from deepspeed.accelerator import get_accelerator
11+
import torch
12+
13+
import deepspeed
14+
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
15+
16+
17+
class SimpleModel(torch.nn.Module):
18+
19+
def __init__(self, hidden_dim, empty_grad=False, nlayers=1):
20+
super(SimpleModel, self).__init__()
21+
self.linears = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(nlayers)])
22+
if empty_grad:
23+
self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
24+
self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
25+
26+
def forward(self, x, y):
27+
for l in self.linears:
28+
x = l(x)
29+
return self.cross_entropy_loss(x, y)
30+
31+
32+
def random_dataset(total_samples, hidden_dim, device, dtype):
33+
train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
34+
train_label = torch.empty(total_samples, dtype=torch.long, device=device).random_(hidden_dim)
35+
train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
36+
return train_dataset
37+
38+
39+
def random_dataloader(model, total_samples, hidden_dim, device, dtype):
40+
batch_size = model.train_micro_batch_size_per_gpu()
41+
train_dataset = random_dataset(total_samples, hidden_dim, device, dtype=dtype)
42+
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
43+
return train_loader
44+
45+
46+
def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking, iteration, warmup):
47+
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
48+
data_loader = random_dataloader(model=model,
49+
total_samples=iteration,
50+
hidden_dim=hidden_dim,
51+
device=model.device,
52+
dtype=dtype)
53+
54+
time_offload_list = []
55+
time_load_list = []
56+
57+
dist.barrier()
58+
for i, batch in enumerate(data_loader):
59+
loss = model(batch[0], batch[1])
60+
model.backward(loss)
61+
model.step()
62+
63+
# Start offloading
64+
alloc_before_offload = get_accelerator().memory_allocated()
65+
dist.barrier()
66+
67+
time_start = time.time()
68+
model.offload_states(include=include,
69+
device=OffloadDeviceEnum.cpu,
70+
pin_memory=pin_memory,
71+
non_blocking=non_blocking)
72+
dist.barrier()
73+
time_after_offload = time.time()
74+
alloc_after_offload = get_accelerator().memory_allocated()
75+
assert alloc_after_offload < alloc_before_offload, f"Allocated memory should decrease after offload"
76+
77+
# Load offloaded states back
78+
model.reload_states()
79+
dist.barrier()
80+
time_after_load = time.time()
81+
82+
time_offload_list.append(time_after_offload - time_start)
83+
time_load_list.append(time_after_load - time_after_offload)
84+
85+
assert alloc_after_offload < get_accelerator().memory_allocated(
86+
), f"Allocated memory should increase after offload back"
87+
88+
if dist.get_rank() == 0:
89+
print(
90+
f"Memory usage ({i}): include={include}, pin_memory={pin_memory}, non_blocking={non_blocking} alloc_before_offload={alloc_before_offload} alloc_after_offload={alloc_after_offload}"
91+
)
92+
93+
# remove warmup
94+
time_offload_list = time_offload_list[warmup:]
95+
time_load_list = time_load_list[warmup:]
96+
97+
if dist.get_rank() == 0:
98+
with open("offload_states.log", "a") as f:
99+
offload_time = sum(time_offload_list) / len(time_offload_list)
100+
load_time = sum(time_load_list) / len(time_load_list)
101+
msg = f"{1 if pin_memory else 0},{1 if non_blocking else 0},{offload_time},{load_time}"
102+
f.write(f"{msg}\n")
103+
print(f"Summary: pin_memory={pin_memory} non_blocking={non_blocking} offload={offload_time} load={load_time}")
104+
105+
# Needed in ZeRO 3. Not doing so can give memory leak
106+
model.destroy()
107+
108+
109+
def main():
110+
parser = argparse.ArgumentParser(description="Test Offload States")
111+
parser.add_argument("--included_state", type=str, choices=[e.name for e in OffloadStateTypeEnum] + [None], default=None, help="State to include")
112+
parser.add_argument("--pin_memory", action='store_true', help="Pin memory")
113+
parser.add_argument("--non_blocking", action='store_true', help="Non blocking")
114+
parser.add_argument("--nlayers", type=int, default=1, help="Number of layers")
115+
parser.add_argument("--hidden_dim", type=int, default=1024, help="Hidden dimension")
116+
parser.add_argument('--dtype', choices=['torch.bfloat16', 'torch.float16', 'torch.float32'], default='torch.bfloat16', help='Data type')
117+
parser.add_argument("--local_rank", type=int, default=-1, help="Local rank")
118+
parser.add_argument("--iteration", type=int, default=10, help="Warmup")
119+
parser.add_argument("--warmup", type=int, default=5, help="Warmup")
120+
121+
args = parser.parse_args()
122+
123+
dtype = eval(args.dtype)
124+
hidden_dim = args.hidden_dim
125+
126+
config_dict = {
127+
"train_micro_batch_size_per_gpu": 1,
128+
"optimizer": {
129+
"type": "Adam",
130+
"params": {
131+
"lr": 1e-6
132+
}
133+
},
134+
"zero_optimization": {
135+
"stage": 3,
136+
},
137+
}
138+
139+
if dtype == torch.float16:
140+
config_dict["fp16"] = {"enabled": True, "initial_scale_power": 8}
141+
elif dtype == torch.bfloat16:
142+
config_dict["bf16"] = {"enabled": True}
143+
144+
with deepspeed.zero.Init(config_dict_or_path=config_dict):
145+
model = SimpleModel(hidden_dim, nlayers=args.nlayers)
146+
147+
included_state = None if args.included_state is None else [OffloadStateTypeEnum[args.included_state]]
148+
run_model(model, config_dict, hidden_dim, dtype, included_state, args.pin_memory, args.non_blocking, args.iteration, args.warmup)
149+
150+
151+
if __name__ == "__main__":
152+
main()
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pandas as pd
2+
from pytablewriter import MarkdownTableWriter
3+
4+
5+
def read_csv(file_path):
6+
return pd.read_csv(file_path)
7+
8+
df = read_csv('offload_states.log')
9+
df.columns = ['pin_memory', 'non_blocking', 'offload_time', 'load_time']
10+
11+
df['ratio_string'] = df['offload_time'].round(2).astype(str) + " / " + df['load_time'].round(2).astype(str)
12+
13+
result_df = pd.DataFrame({
14+
'pin_memory=0_non_blocking=0': df[(df['pin_memory'] == 0) & (df['non_blocking'] == 0)]['ratio_string'].reset_index(drop=True),
15+
'pin_memory=0_non_blocking=1': df[(df['pin_memory'] == 0) & (df['non_blocking'] == 1)]['ratio_string'].reset_index(drop=True),
16+
'pin_memory=1_non_blocking=0': df[(df['pin_memory'] == 1) & (df['non_blocking'] == 0)]['ratio_string'].reset_index(drop=True),
17+
'pin_memory=1_non_blocking=1': df[(df['pin_memory'] == 1) & (df['non_blocking'] == 1)]['ratio_string'].reset_index(drop=True)
18+
})
19+
result_df = result_df.dropna()
20+
result_df.index = range(1, len(result_df) + 1)
21+
result_df.index.name = 'trial'
22+
# print(result_df)
23+
24+
writer = MarkdownTableWriter()
25+
writer.from_dataframe(result_df,
26+
add_index_column=True,
27+
)
28+
writer.write_table()
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
NGPUS=4
2+
HIDDEN_SIZE=32768
3+
NUM_LAYERS=4
4+
5+
TRIALS=10
6+
7+
PIN_MEMORY_OPTS=(0 1)
8+
NON_BLOCKING_OPTS=(0 1)
9+
10+
for i in $(seq 1 $TRIALS); do
11+
for PIN_MEMORY in "${PIN_MEMORY_OPTS[@]}"; do
12+
PIN_MEMORY_ARG=""
13+
if [ $PIN_MEMORY -eq 1 ]; then
14+
PIN_MEMORY_ARG="--pin_memory"
15+
fi
16+
17+
for NON_BLOCKING in "${NON_BLOCKING_OPTS[@]}"; do
18+
NON_BLOCKING_ARG=""
19+
if [ $NON_BLOCKING -eq 1 ]; then
20+
NON_BLOCKING_ARG="--non_blocking"
21+
fi
22+
23+
echo "Running iteration $i"
24+
deepspeed --num_gpus=$NGPUS offload_states.py --hidden_dim $HIDDEN_SIZE --nlayers $NUM_LAYERS $PIN_MEMORY_ARG $NON_BLOCKING_ARG
25+
done
26+
done
27+
done
28+
python output_table.py

0 commit comments

Comments
 (0)