
Commit 18cbe26

Author: yupeng.zhou (committed)
Commit message: twoperson
1 parent: ae1398a

3 files changed: +898 -115 lines

gradio_app_sdxl_specific_id_low_vram.py: +116 -115
@@ -21,7 +21,7 @@
         AttnProcessor2_0 as AttnProcessor
 else:
     from utils.gradio_utils import AttnProcessor
-
+import datetime
 import diffusers
 from diffusers import StableDiffusionXLPipeline
 from utils import PhotoMakerStableDiffusionXLPipeline
@@ -181,83 +181,6 @@ def __call__(
             cur_step += 1
             indices1024,indices4096 = cal_attn_indice_xl_effcient_memory(self.total_length,self.id_length,sa32,sa64,height,width, device=self.device, dtype= self.dtype)

-        return hidden_states
-    def __call1__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        attn_indices = None,
-    ):
-        # print("hidden state shape",hidden_states.shape,self.id_length)
-        residual = hidden_states
-        # if encoder_hidden_states is not None:
-        #     raise Exception("not implement")
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            total_batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)
-        total_batch_size,nums_token,channel = hidden_states.shape
-        img_nums = total_batch_size//2
-        hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel)
-        batch_size, sequence_length, _ = hidden_states.shape
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states  # B, N, C
-        else:
-            encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        # print(key.shape,value.shape,query.shape,attention_mask.shape)
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        # print(query.shape,key.shape,value.shape,attention_mask.shape)
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        # if input_ndim == 4:
-        #     tile_hidden_states = tile_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        # if attn.residual_connection:
-        #     tile_hidden_states = tile_hidden_states + residual
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-        hidden_states = hidden_states / attn.rescale_output_factor
-        # print(hidden_states.shape)
         return hidden_states
     def __call2__(
         self,
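Note: the hunk above removes the `__call1__` attention path. For readers skimming the diff, here is a minimal, self-contained sketch of the reshape-attend-reshape pattern that function implemented via PyTorch's `F.scaled_dot_product_attention`; the function name, shapes, and test tensor below are illustrative, not the app's actual modules.

```python
import torch
import torch.nn.functional as F

def sdpa_multihead(query, key, value, heads):
    # query/key/value: (batch, tokens, inner_dim), as produced by to_q/to_k/to_v
    batch, _, inner_dim = query.shape
    head_dim = inner_dim // heads
    # Split the channel dim into heads: (batch, heads, tokens, head_dim)
    q = query.view(batch, -1, heads, head_dim).transpose(1, 2)
    k = key.view(batch, -1, heads, head_dim).transpose(1, 2)
    v = value.view(batch, -1, heads, head_dim).transpose(1, 2)
    # The output of sdp is (batch, heads, tokens, head_dim)
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    # Merge heads back: (batch, tokens, inner_dim)
    return out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

x = torch.randn(2, 64, 320)
print(sdpa_multihead(x, x, x, heads=8).shape)  # torch.Size([2, 64, 320])
```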
@@ -393,6 +316,60 @@ def set_attention_processor(unet,id_length,is_ipadapter = False):
 <style>
 '''

+def save_single_character_weights(unet,character,description, filepath):
+    """
+    Save the list of id_bank GPU tensors from the attention_processor instances to the given file.
+    Args:
+    - model: the model containing the attention_processor instances.
+    - filepath: the file path the weights are saved to.
+    """
+    weights_to_save = {}
+    weights_to_save["description"] = description
+    weights_to_save["character"] = character
+    for attn_name, attn_processor in unet.attn_processors.items():
+        if isinstance(attn_processor, SpatialAttnProcessor2_0):
+            # Move each tensor to the CPU so that it can be serialized.
+            weights_to_save[attn_name] = {}
+            for step_key in attn_processor.id_bank[character].keys():
+                weights_to_save[attn_name][step_key] = [tensor.cpu() for tensor in attn_processor.id_bank[character][step_key]]
+    # Persist the weights with torch.save.
+    torch.save(weights_to_save, filepath)
+
+def load_single_character_weights(unet, filepath):
+    """
+    Load weights from the given file into the id_bank of the attention_processor instances.
+    Args:
+    - model: the model containing the attention_processor instances.
+    - filepath: the path of the weight file.
+    """
+    # Read the weights with torch.load.
+    weights_to_load = torch.load(filepath, map_location=torch.device('cpu'))
+    character = weights_to_load['character']
+    description = weights_to_load['description']
+    for attn_name, attn_processor in unet.attn_processors.items():
+        if isinstance(attn_processor, SpatialAttnProcessor2_0):
+            # Move the weights to the GPU (if one is available) and assign them to id_bank.
+            attn_processor.id_bank[character] = {}
+            for step_key in weights_to_load[attn_name].keys():
+                attn_processor.id_bank[character][step_key] = [tensor.to(unet.device) for tensor in weights_to_load[attn_name][step_key]]
+
+def save_results(unet,img_list):
+
+    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
+    folder_name = f'results/{timestamp}'
+    weight_folder_name = f'{folder_name}/weights'
+    # Create the output folders.
+    if not os.path.exists(folder_name):
+        os.makedirs(folder_name)
+        os.makedirs(weight_folder_name)
+
+    for idx, img in enumerate(img_list):
+        file_path = os.path.join(folder_name, f'image_{idx}.png')  # image file name
+        img.save(file_path)
+    global character_dict
+    # for char in character_dict:
+    #     description = character_dict[char]
+    #     save_single_character_weights(unet,char,description,os.path.join(weight_folder_name, f'{char}.pt'))

 #################################################
 title = r"""
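Note: the new helpers serialize a character's id_bank as a plain dict checkpoint. Below is a hypothetical sketch of the resulting `.pt` layout and a CPU round-trip, following `save_single_character_weights` / `load_single_character_weights` above; the attention-layer name, character string, and tensor shape are made-up placeholders.

```python
import torch

# Hypothetical layout of a saved character file: top-level metadata plus a
# per-attention-layer dict mapping step keys to lists of id_bank tensors.
weights = {
    "description": "a man with short black hair",
    "character": "[Taylor]",
    "up_blocks.0.attentions.1.transformer_blocks.0.attn1.processor": {
        0: [torch.randn(2, 1024, 640).cpu()],  # id_bank tensors for one step
    },
}
torch.save(weights, "Taylor.pt")

loaded = torch.load("Taylor.pt", map_location=torch.device("cpu"))
print(loaded["character"], loaded["description"])
```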
@@ -426,14 +403,14 @@ def set_attention_processor(unet,id_length,is_ipadapter = False):
 ```
 📋 **License**
 <br>
-The Contents you create are under Apache-2.0 LICENSE. The Code are under Attribution-NonCommercial 4.0 International.
+Apache-2.0 LICENSE.

 📧 **Contact**
 <br>
 If you have any questions, please feel free to reach me out at <b>[email protected]</b>.
 """
 version = r"""
-<h3 align="center">StoryDiffusion Version 0.01 (test version)</h3>
+<h3 align="center">StoryDiffusion Version 0.02 (test version)</h3>

 <h5 >1. Support image ref image. (Cartoon Ref image is not support now)</h5>
 <h5 >2. Support Typesetting Style and Captioning.(By default, the prompt is used as the caption for each image. If you need to change the caption, add a # at the end of each line. Only the part after the # will be added as a caption to the image.)</h5>
@@ -528,9 +505,29 @@ def change_visiale_by_model_type(_model_type):
     else:
         raise ValueError("Invalid model type",_model_type)

+def load_character_files(character_files:str):
+    if character_files == "":
+        raise gr.Error("Please set a character file!")
+    character_files_arr = character_files.splitlines()
+    primarytext = []
+    for character_file_name in character_files_arr:
+        character_file = torch.load(character_file_name, map_location=torch.device('cpu'))
+        primarytext.append(character_file["character"] + character_file["description"])
+    return array2string(primarytext)
+
+
+def load_character_files_on_running(unet,character_files:str):
+    if character_files == "":
+        return False
+    character_files_arr = character_files.splitlines()
+    for character_file in character_files_arr:
+        load_single_character_weights(unet, character_file)
+    return True

 ######### Image Generation ##############
-def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_name, _Ip_Adapter_Strength ,_style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt,prompt_array,G_height,G_width,_comic_type, font_choice): # Corrected font_choice usage
+def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_name, _Ip_Adapter_Strength ,_style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt,prompt_array,G_height,G_width,_comic_type, font_choice,_char_files): # Corrected font_choice usage
+    if len(general_prompt.splitlines()) >= 3:
+        raise gr.Error("Support for more than three characters is temporarily unavailable due to VRAM limitations, but this issue will be resolved soon.")
     _model_type = "Photomaker" if _model_type == "Using Ref Images" else "original"
     if _model_type == "Photomaker" and "img" not in general_prompt:
         raise gr.Error("Please add the triger word \" img \" behind the class word you want to customize, such as: man img or woman img")
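Note: `load_character_files` only rebuilds the general-prompt text; the weights themselves are loaded later by `load_character_files_on_running`. A sketch of that prompt reconstruction, reusing the hypothetical `Taylor.pt` from the earlier example and assuming `array2string` (defined outside this diff) joins entries with newlines:

```python
import torch

primarytext = []
for character_file_name in ["Taylor.pt"]:  # one file path per textbox line
    character_file = torch.load(character_file_name, map_location=torch.device("cpu"))
    primarytext.append(character_file["character"] + character_file["description"])
# Assumed behavior of array2string: join the entries with newlines.
print("\n".join(primarytext))  # [Taylor]a man with short black hair
```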
@@ -574,17 +571,13 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
     unet = pipe.unet
     # unet.set_attn_processor(copy.deepcopy(attn_procs))

-
+    load_chars = load_character_files_on_running(unet,character_files=_char_files)

     prompts = prompt_array.splitlines()
     global character_dict,character_index_dict,invert_character_index_dict,ref_indexs_dict,ref_totals
     character_dict,character_list = character_to_dict(general_prompt)

-
-
-
-
     start_merge_step = int(float(_style_strength_ratio) / 100 * _num_steps)
     if start_merge_step > 30:
         start_merge_step = 30
@@ -627,33 +620,37 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
     id_images = []
     results_dict = {}
     global cur_character
-    for character_key in character_dict.keys():
-        cur_character = [character_key]
-        ref_indexs = ref_indexs_dict[character_key]
-        print(character_key,ref_indexs)
-        current_prompts = [replace_prompts[ref_ind] for ref_ind in ref_indexs]
-        print(current_prompts)
-        setup_seed(seed_)
-        generator = torch.Generator(device="cuda").manual_seed(seed_)
-        cur_step = 0
-        cur_positive_prompts, negative_prompt = apply_style(style_name, current_prompts, negative_prompt)
-        if _model_type == "original":
-            id_images = pipe(cur_positive_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
-        elif _model_type == "Photomaker":
-            id_images = pipe(cur_positive_prompts,input_id_images=input_id_images_dict[character_key], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
-        else:
-            raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
-
-        # total_results = id_images + total_results
-        # yield total_results
-        print(id_images)
-        for ind,img in enumerate(id_images):
-            print(ref_indexs[ind])
-            results_dict[ref_indexs[ind]] = img
-        # real_images = []
+    if not load_chars:
+        for character_key in character_dict.keys():
+            cur_character = [character_key]
+            ref_indexs = ref_indexs_dict[character_key]
+            print(character_key,ref_indexs)
+            current_prompts = [replace_prompts[ref_ind] for ref_ind in ref_indexs]
+            print(current_prompts)
+            setup_seed(seed_)
+            generator = torch.Generator(device="cuda").manual_seed(seed_)
+            cur_step = 0
+            cur_positive_prompts, negative_prompt = apply_style(style_name, current_prompts, negative_prompt)
+            if _model_type == "original":
+                id_images = pipe(cur_positive_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
+            elif _model_type == "Photomaker":
+                id_images = pipe(cur_positive_prompts,input_id_images=input_id_images_dict[character_key], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator).images
+            else:
+                raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
+
+            # total_results = id_images + total_results
+            # yield total_results
+            print(id_images)
+            for ind,img in enumerate(id_images):
+                print(ref_indexs[ind])
+                results_dict[ref_indexs[ind]] = img
+            # real_images = []
+            yield [results_dict[ind] for ind in results_dict.keys()]
     write = False
-
-    real_prompts_inds = [ind for ind in range(len(prompts)) if ind not in ref_totals]
+    if not load_chars:
+        real_prompts_inds = [ind for ind in range(len(prompts)) if ind not in ref_totals]
+    else:
+        real_prompts_inds = [ind for ind in range(len(prompts))]
     print(real_prompts_inds)

     for real_prompts_ind in real_prompts_inds:
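Note: the effect of the new `load_chars` branch on frame selection can be checked in isolation. When characters are preloaded from files, every prompt line becomes a "real" frame; otherwise the identity-reference frames are excluded. The index values below are made up for the example.

```python
prompts = ["p0", "p1", "p2", "p3"]
ref_totals = [0, 1]  # indices already consumed as identity references

for load_chars in (False, True):
    if not load_chars:
        real_prompts_inds = [ind for ind in range(len(prompts)) if ind not in ref_totals]
    else:
        real_prompts_inds = [ind for ind in range(len(prompts))]
    print(load_chars, real_prompts_inds)
# False [2, 3]
# True [0, 1, 2, 3]
```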
@@ -672,7 +669,7 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
             results_dict[real_prompts_ind] = (pipe(real_prompt, input_id_images=input_id_images_dict[cur_character[0]] if real_prompts_ind not in nc_indexs else input_id_images_dict[character_list[0]], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step = start_merge_step, height = height, width = width,negative_prompt = negative_prompt,generator = generator,nc_flag = True if real_prompts_ind in nc_indexs else False).images[0])
         else:
             raise NotImplementedError("You should choice between original and Photomaker!",f"But you choice {_model_type}")
-
+        yield [results_dict[ind] for ind in results_dict.keys()]
     total_results = [results_dict[ind] for ind in range(len(prompts))]
     if _comic_type != "No typesetting (default)":
         captions= prompt_array.splitlines()
@@ -684,6 +681,8 @@ def process_generation(_sd_type,_model_type,_upload_images, _num_steps,style_nam
         print(f"Attempting to load font from path: {font_path}")
         font = ImageFont.truetype(font_path, int(45))
         total_results = get_comic(total_results, _comic_type, captions=captions, font=font) + total_results
+    save_results(pipe.unet,total_results)
+
     yield total_results

@@ -731,6 +730,8 @@ def array2string(arr):
             negative_prompt = gr.Textbox(value='', label="(2) Negative_prompt", interactive=True)
             style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
             prompt_array = gr.Textbox(lines = 3,value='', label="(3) Comic Description (each line corresponds to a frame).", interactive=True)
+            char_path = gr.Textbox(lines = 2,value='', visible = False,label="(Optional) Character files", interactive=True)
+            char_btn = gr.Button("Load Character files",visible = False)
             with gr.Accordion("(4) Tune the hyperparameters", open=True):
                 font_choice = gr.Dropdown(label="Select Font", choices=[f for f in os.listdir("./fonts") if f.endswith('.ttf')], value="Inkfree.ttf", info="Select font for the final slide.", interactive=True)
                 sa32_ = gr.Slider(label=" (The degree of Paired Attention at 32 x 32 self-attention layers) ", minimum=0, maximum=1., value=0.5, step=0.1)
@@ -792,9 +793,9 @@ def array2string(arr):
     model_type.change(fn = change_visiale_by_model_type , inputs = model_type, outputs=[control_image_input,style_strength_ratio,Ip_Adapter_Strength])
     files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
     remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
-
+    char_btn.click(fn=load_character_files,inputs=char_path,outputs=[general_prompt])
     final_run_btn.click(fn=set_text_unfinished, outputs=generated_information
-    ).then(process_generation, inputs=[sd_type,model_type,files, num_steps,style, Ip_Adapter_Strength,style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array,G_height,G_width,comic_type, font_choice], outputs=out_image
+    ).then(process_generation, inputs=[sd_type,model_type,files, num_steps,style, Ip_Adapter_Strength,style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array,G_height,G_width,comic_type, font_choice,char_path], outputs=out_image
     ).then(fn=set_text_finished,outputs=generated_information)
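Note: the hidden character-file controls follow a standard Gradio pattern, an invisible `Textbox` plus a `Button` wired with `.click`. A minimal standalone sketch with a stand-in handler (the real app routes through `load_character_files` / `array2string`):

```python
import gradio as gr

def load_character_files(character_files: str):
    if character_files == "":
        raise gr.Error("Please set a character file!")
    # Stand-in for the app's array2string over character/description pairs.
    return " ".join(character_files.splitlines())

with gr.Blocks() as demo:
    general_prompt = gr.Textbox(label="(1) Textual Description for Character")
    char_path = gr.Textbox(lines=2, value="", visible=False, label="(Optional) Character files")
    char_btn = gr.Button("Load Character files", visible=False)
    # Clicking the (hidden) button rewrites the general prompt from the files.
    char_btn.click(fn=load_character_files, inputs=char_path, outputs=[general_prompt])

# demo.launch()
```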