Commit 0e10c4b

Adding LoRA-Distillation SD training example (#788)
Co-authored-by: Xiaoxia (Shirley) Wu <[email protected]>
1 parent b116838 commit 0e10c4b

6 files changed: +2120 -0 lines

training/stable_diffusion/README.md

@@ -0,0 +1,44 @@
# LoRA-enhanced distillation on the Stable Diffusion model

This repository contains an implementation of LoRA-enhanced distillation applied to the Stable Diffusion (SD) model. By combining the LoRA technique with distillation, we achieve a significant reduction in inference time and a 50% decrease in memory consumption, while maintaining image quality and alignment with the provided prompt. For additional details on this work, please consult our technical report [TODO: add link].
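
For readers new to the combination: LoRA freezes the pretrained weights and trains only small low-rank adapters, while distillation regresses the student's prediction onto a frozen teacher's. The sketch below is purely illustrative (generic PyTorch, not this repository's training code; all names are made up):

<pre>
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank update: W x + (alpha/r) * B(A(x))."""
    def __init__(self, base: nn.Linear, r: int = 4, alpha: float = 4.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False         # freeze the pretrained weights
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # the low-rank update starts at zero
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

def distill_step(student_unet, teacher_unet, noisy_latents, timesteps, text_emb):
    """One generic distillation step: match the student's noise prediction to the teacher's."""
    with torch.no_grad():
        target = teacher_unet(noisy_latents, timesteps, text_emb).sample
    pred = student_unet(noisy_latents, timesteps, text_emb).sample
    return nn.functional.mse_loss(pred, target)
</pre>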
In this implementation, we have adapted the DreamBooth fine-tuning [code](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#dreambooth-training-example) as our baseline. Below, you'll find information regarding input data, training, and inference.

## Installation

You need the Hugging Face [diffusers](https://github.com/huggingface/diffusers) library installed on your machine. Then install the requirements:

<pre>
pip install -r requirements.txt
</pre>

## Training

### Training Data

Our training data is a large dataset of images pre-generated by [SD](https://github.com/poloclub/diffusiondb) (DiffusionDB). You are not required to download the input data; instead, you can specify or modify it within the training code (`train_sd_distill_lora.py`) as needed.
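
If you want to inspect the data itself, DiffusionDB is also published on the Hugging Face Hub. The snippet below is a minimal illustration of loading a small subset with the `datasets` library; the subset name `2m_first_1k` and the field names come from the dataset card, not from this repository's code:

<pre>
from datasets import load_dataset

# Illustration only: pull the first 1k records of DiffusionDB's 2M split.
dataset = load_dataset("poloclub/diffusiondb", "2m_first_1k", split="train")

sample = dataset[0]
print(sample["prompt"])             # the text prompt used to generate the image
sample["image"].save("sample.png")  # the pre-generated SD image (PIL)
</pre>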

### Training Script

To train the model, follow these steps:

1. Run the `mytrainbash.sh` file.
2. The fine-tuned model will be saved inside the output directory.

Here's an example command to run the training script:

<pre>
bash mytrainbash.sh
</pre>

Make sure to customize the training parameters in the script to suit your specific requirements.

## Inference

For inference, you can use the `inf_txt2img_loop.py` Python script. Follow these steps:

1. Provide your desired prompts as input in the script.
2. Run the `inf_txt2img_loop.py` script.

Here's an example command to run the inference script:

<pre>
deepspeed inf_txt2img_loop.py
</pre>
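
Note that the script reads `LOCAL_RANK` and `WORLD_SIZE` from the environment, so the same command also works with DeepSpeed's multi-GPU launcher (for example, `deepspeed --num_gpus 2 inf_txt2img_loop.py`); images from both the fine-tuned and baseline models are written to `--out_dir`.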
training/stable_diffusion/inf_txt2img_loop.py

@@ -0,0 +1,56 @@
import argparse
import os

import deepspeed
import torch
from diffusers import StableDiffusionPipeline as StableDiffusionPipelineBaseline
from local_pipeline_stable_diffusion import StableDiffusionPipeline

seed = 123450011
parser = argparse.ArgumentParser()
parser.add_argument("--ft_model", default="new_sd-distill-v21-10k-1e", type=str, help="Path to the fine-tuned model")
parser.add_argument("--b_model", default="stabilityai/stable-diffusion-2-1-base", type=str, help="Path to the baseline model")
parser.add_argument("--out_dir", default="image_out/", type=str, help="Path to the generated images")
parser.add_argument("--guidance_scale", type=float, default=7.5, help="Guidance scale")
parser.add_argument("--use_local_pipe", action="store_true", help="Use local SD pipeline")
parser.add_argument("--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0")), help="local rank")
args = parser.parse_args()

local_rank = int(os.getenv("LOCAL_RANK", "0"))
device = torch.device(f"cuda:{local_rank}")
world_size = int(os.getenv("WORLD_SIZE", "1"))

if not os.path.exists(args.out_dir):
    os.makedirs(args.out_dir)
    print(f"Directory '{args.out_dir}' has been created to store the generated images.")
else:
    print(f"Directory '{args.out_dir}' already exists and stores the generated images.")

prompts = ["A boy is watching TV",
           "A photo of a person dancing in the rain",
           "A photo of a boy jumping over a fence",
           "A photo of a boy is kicking a ball",
           "A beach with a lot of waves on it",
           "A road that is going down a hill",
           "3d rendering of 5 tennis balls on top of a cake",
           "A person holding a drink of soda",
           "A person is squeezing a lemon",
           "A person holding a cat"]

# Load each pipeline once, outside the prompt loop; reloading the weights for
# every prompt would be needlessly slow.
pipe_new = StableDiffusionPipeline.from_pretrained(args.ft_model, torch_dtype=torch.float16).to(device)
pipe_new = deepspeed.init_inference(pipe_new, mp_size=world_size, dtype=torch.half)

pipe_baseline = StableDiffusionPipelineBaseline.from_pretrained(args.b_model, torch_dtype=torch.float16).to(device)
pipe_baseline = deepspeed.init_inference(pipe_baseline, mp_size=world_size, dtype=torch.half)

for prompt in prompts:
    # --- image from the fine-tuned (distilled) model; a freshly seeded
    # generator per prompt gives both models identical starting noise
    generator = torch.Generator(device).manual_seed(seed)
    image_new = pipe_new(prompt, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator).images[0]
    image_new.save(os.path.join(args.out_dir, f"NEW__seed_{seed}_{prompt[0:100]}.png"))

    # --- baseline image with the same seed and sampler settings
    generator = torch.Generator(device).manual_seed(seed)
    image_baseline = pipe_baseline(prompt, num_inference_steps=50, guidance_scale=args.guidance_scale, generator=generator).images[0]
    image_baseline.save(os.path.join(args.out_dir, f"BASELINE_seed_{seed}_{prompt[0:100]}.png"))
