
Commit 3b0531d

Merge pull request #32 from NickLucche/global-refactoring

♻️ Global refactoring

2 parents e7389fe + 97e178a

6 files changed: +247 -165 lines

main.py (+37 -102)
@@ -1,54 +1,35 @@
 from typing import List, Union
-from diffusers import StableDiffusionPipeline
-from diffusers.pipelines.stable_diffusion import (
-    StableDiffusionImg2ImgPipeline,
-    StableDiffusionInpaintPipeline,
-)
 import torch
 from PIL import Image
-from diffusers.pipelines.stable_diffusion.safety_checker import (
-    StableDiffusionSafetyChecker,
-)
 import os
-from utils import ModelParts2GPUsAssigner, get_gpu_setting, dummy_checker, remove_nsfw
+from utils import ModelParts2GPUsAssigner, get_gpu_setting
 from parallel import StableDiffusionModelParallel, StableDiffusionMultiProcessing
-from schedulers import schedulers
 import numpy as np
+from sb import DiffusionModel
 
+# read env variables
 TOKEN = os.environ.get("TOKEN", None)
 MODEL_ID = os.environ.get("MODEL_ID", "stabilityai/stable-diffusion-2-base")
 
+# If you are limited by GPU memory (e.g. <10GB VRAM), please make sure to load in fp16 precision
 fp16 = bool(int(os.environ.get("FP16", 1)))
 # MP = bool(int(os.environ.get("MODEL_PARALLEL", 0)))
 MP = False  # disabled
 MIN_INPAINT_MASK_PERCENT = 0.1
 
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
-
 # FIXME devices=0,1 causes cuda error on memory access..?
-# create and move model to GPU(s), defaults to GPU 0
-multi, devices = get_gpu_setting(os.environ.get("DEVICES", "0"))
-# If you are limited by GPU memory and have less than 10GB of GPU RAM available, please make sure to load the StableDiffusionPipeline in float16 precision
-kwargs = dict(
-    pretrained_model_name_or_path=MODEL_ID,
-    revision="fp16" if fp16 else None,
-    torch_dtype=torch.float16 if fp16 else None,
-    use_auth_token=TOKEN,
-    requires_safety_checker=False,
-)
-
-pipe, safety, safety_extractor = None, None, None
-
-
-def load_pipeline(model_or_path, devices: List[int]):
-    global pipe, safety, safety_extractor
-    if pipe is not None and pipe._pipe_name == model_or_path:
-        # avoid re-loading same model
-        return
-
+IS_MULTI, DEVICES = get_gpu_setting(os.environ.get("DEVICES", "0"))
+
+# TODO docs
+def init_pipeline(model_or_path=MODEL_ID, devices: List[int] = DEVICES) -> Union[DiffusionModel, StableDiffusionMultiProcessing]:
+    kwargs = dict(
+        pretrained_model_name_or_path=model_or_path,
+        revision="fp16" if fp16 else None,
+        torch_dtype=torch.float16 if fp16 else None,
+        use_auth_token=TOKEN,
+        requires_safety_checker=False,
+    )
     model_ass = None
-    print(f"Loading {model_or_path} from disk..")
-    kwargs["pretrained_model_name_or_path"] = model_or_path
     # single-gpu multiple models currently disabled
     if MP and len(devices) > 1:
         # setup for model parallel: find model parts->gpus assignment
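
The hunk above replaces the import-time `load_pipeline` (which mutated `pipe`/`safety` globals) with an explicit `init_pipeline` factory. A minimal usage sketch of the new entry point (hypothetical driver code, not part of this commit; it assumes the repo's modules are importable and the `DEVICES`/`MODEL_ID` env vars shown above):

    from main import init_pipeline

    # IS_MULTI and DEVICES are parsed once from the DEVICES env var at
    # import time, so a bare call loads MODEL_ID onto the configured GPU(s)
    pipe = init_pipeline()
    # per the annotation, this is a DiffusionModel in single-GPU mode, or a
    # StableDiffusionMultiProcessing handle (one worker process per GPU)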
@@ -63,31 +44,27 @@ def load_pipeline(model_or_path, devices: List[int]):
         )
         print("Assignments:", model_ass)
 
-    if multi and pipe is not None:
+    # TODO move logic
+    # if multi and pipe is not None:
         # avoid re-creating processes in multi-gpu mode, have them reload a different model
-        pipe.reload_model(model_or_path)
-    elif multi:
+    # pipe.reload_model(model_or_path)
+    if IS_MULTI:
         # DataParallel: one process *per GPU* (each has a copy of the model)
         # ModelParallel: one process *per model*, each model (possibly) on multiple GPUs
         n_procs = len(devices) if not MP else len(model_ass)
         pipe = StableDiffusionMultiProcessing.from_pretrained(
             n_procs, devices, model_parallel_assignment=model_ass, **kwargs
         )
     else:
-        pipe = StableDiffusionPipeline.from_pretrained(**kwargs)
-        # remove safety checker so it doesn't use up GPU memory
-        safety, safety_extractor = remove_nsfw(pipe)
+        pipe = DiffusionModel.from_pretrained(**kwargs)
     if len(devices):
         pipe.to(f"cuda:{devices[0]}")
 
-    pipe._pipe_name = model_or_path
-    print("Model Loaded!")
-
-
-load_pipeline(MODEL_ID, devices)
+    return pipe
 
 
 def inference(
+    pipe: DiffusionModel,
     prompt,
     num_images=1,
     num_inference_steps=50,
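
Note that the old `pipe is not None` cache and `reload_model` branch are only commented out above (see the TODO), so under the new API switching checkpoints means building a fresh pipeline rather than reloading in place. A hedged sketch (the checkpoint id is an arbitrary example):

    # hypothetical: swap checkpoints by re-running the factory
    pipe = init_pipeline("runwayml/stable-diffusion-v1-5")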
@@ -105,29 +82,28 @@
 ):
     prompt = [prompt] * num_images
     input_kwargs = dict(
+        inference_type="text",
         prompt=prompt,
+        # number of denoising steps run during inference (the higher the better)
         num_inference_steps=num_inference_steps,
         height=height,
         width=width,
         guidance_scale=guidance_scale,
-        generator=None,
+        # NOTE seed with multiple GPUs will be different for each one but fixed!
+        generator=seed,
     )
 
     # input sketch has priority over input image
     if input_sketch is not None:
         input_image = input_sketch
-    # Img2Img: to avoid re-loading the model, we ""cast"" the pipeline
     # TODO batch images by providing a torch tensor
     if input_image is not None:
-        input_image = input_image.resize((width, height))
         # image guided generation
-        if multi:
-            pipe.change_pipeline_type("img2img")
-        else:
-            pipe.__class__ = StableDiffusionImg2ImgPipeline
+        input_image = input_image.resize((width, height))
         # TODO negative prompt?
         input_kwargs["init_image"] = input_image
         input_kwargs["strength"] = 1.0 - inv_strenght
+        input_kwargs["inference_type"] = "img2img"
     elif masked_image is not None:
         # resize to specified shape
         masked_image = {
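
This hunk and the next complete the dispatch change: instead of re-casting the live pipeline's `__class__` (or calling `change_pipeline_type` in multi-GPU mode), each call is tagged with an `inference_type` string and routed inside the model. A sketch of the kwargs the img2img branch above ends up passing (all concrete values hypothetical):

    from PIL import Image

    # hypothetical img2img call: these are the kwargs the branch above builds
    input_kwargs = dict(
        inference_type="img2img",   # overrides the "text" default set earlier
        prompt=["a watercolor castle"] * 2,
        num_inference_steps=50,
        height=512,
        width=512,
        guidance_scale=7.5,
        generator=42,               # plain int seed, picklable across workers
        init_image=Image.open("sketch.png").resize((512, 512)),
        strength=0.7,               # i.e. 1.0 - inv_strenght
    )
    images = pipe(**input_kwargs)["images"]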
@@ -138,61 +114,20 @@
         if np.count_nonzero(masked_image["mask"].convert("1")) < (
             width * height * MIN_INPAINT_MASK_PERCENT
         ):
-            # FIXME error handling
-            raise Exception("ERROR: mask is too small!")
-        if multi:
-            pipe.change_pipeline_type("inpaint")
-        else:
-            pipe.__class__ = StableDiffusionInpaintPipeline
+            raise ValueError("Mask is too small. Please paint over a larger area")
         input_kwargs["image"] = masked_image["image"]
         input_kwargs["mask_image"] = masked_image["mask"]
-    elif multi:
-        # default mode
-        pipe.change_pipeline_type("text")
-    else:
-        pipe.__class__ = StableDiffusionPipeline
-
-    # for repeatable results; tensor generated on cpu for model parallel
-    if multi:
-        # generator cant be pickled
-        # NOTE fixed seed with multiples gpus will be different for each one but fixed!
-        input_kwargs["generator"] = seed
-    elif seed is not None and seed > 0:
-        input_kwargs["generator"] = torch.Generator(
-            f"cuda:{devices[0]}" if not MP else "cpu"
-        ).manual_seed(seed)
-
-    if nsfw_filter:
-        if multi:
-            pipe.safety_checker = None
-        else:
-            pipe.safety_checker = safety.to(f"cuda:{devices[0]}")
-            pipe.feature_extractor = safety_extractor
-    else:
-        if multi:
-            pipe.safety_checker = dummy_checker
-        else:
-            # remove safety network from gpu
-            remove_nsfw(pipe)
-
-    if low_vram:
-        # needed on 16GB RAM 768x768 fp32
-        pipe.enable_attention_slicing()
-    else:
-        pipe.disable_attention_slicing()
+        input_kwargs["inference_type"] = "inpaint"
+
+    pipe.set_nsfw(nsfw_filter)
+
+    # needed on 16GB RAM 768x768 fp32
+    pipe.enable_attention_slicing("auto" if low_vram else None)
 
     # set noise scheduler for inference
-    if noise_scheduler is not None and noise_scheduler in schedulers:
-        if multi:
-            pipe.scheduler = noise_scheduler
-        else:
-            # load scheduler from pre-trained config
-            s = getattr(schedulers[noise_scheduler], "from_config")(
-                pipe.scheduler.config
-            )
-            pipe.scheduler = s
+    if noise_scheduler is not None:
+        pipe.scheduler = noise_scheduler
 
-    # number of denoising steps run during inference (the higher the better)
     with torch.autocast("cuda"):
         images: List[Image.Image] = pipe(**input_kwargs)["images"]
     return images
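
Taken together, the refactored flow is: build a pipeline once with `init_pipeline`, then pass it explicitly to `inference`, which routes text/img2img/inpaint through the `inference_type` flag. An end-to-end sketch (hypothetical driver code; only arguments visible in this diff are used, everything else is left at its defaults):

    from main import init_pipeline, inference

    pipe = init_pipeline()
    images = inference(
        pipe,
        "a photo of an astronaut riding a horse",
        num_images=2,
        num_inference_steps=30,
        seed=42,
        low_vram=True,
    )
    images[0].save("astronaut.png")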
