         AttnProcessor2_0 as AttnProcessor
 else:
     from utils.gradio_utils import AttnProcessor
-
+import datetime
 import diffusers
 from diffusers import StableDiffusionXLPipeline
 from utils import PhotoMakerStableDiffusionXLPipeline
@@ -181,83 +181,6 @@ def __call__(
                 cur_step += 1
                 indices1024, indices4096 = cal_attn_indice_xl_effcient_memory(self.total_length, self.id_length, sa32, sa64, height, width, device=self.device, dtype=self.dtype)
 
-        return hidden_states
-    def __call1__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        attn_indices=None,
-    ):
-        # print("hidden state shape", hidden_states.shape, self.id_length)
-        residual = hidden_states
-        # if encoder_hidden_states is not None:
-        #     raise Exception("not implement")
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            total_batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2)
-        total_batch_size, nums_token, channel = hidden_states.shape
-        img_nums = total_batch_size // 2
-        hidden_states = hidden_states.view(-1, img_nums, nums_token, channel).reshape(-1, img_nums * nums_token, channel)
-        batch_size, sequence_length, _ = hidden_states.shape
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states  # B, N, C
-        else:
-            encoder_hidden_states = encoder_hidden_states.view(-1, self.id_length + 1, nums_token, channel).reshape(-1, (self.id_length + 1) * nums_token, channel)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        # print(key.shape, value.shape, query.shape, attention_mask.shape)
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        # print(query.shape, key.shape, value.shape, attention_mask.shape)
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        # if input_ndim == 4:
-        #     tile_hidden_states = tile_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        # if attn.residual_connection:
-        #     tile_hidden_states = tile_hidden_states + residual
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width)
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-        hidden_states = hidden_states / attn.rescale_output_factor
-        # print(hidden_states.shape)
         return hidden_states
     def __call2__(
         self,
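For orientation: both the removed `__call1__` and the surviving `__call2__` rely on the same trick of folding the tokens of several frames into one long sequence, so a single `F.scaled_dot_product_attention` call lets every frame attend to every other frame's tokens. A minimal, self-contained sketch of that fold/attend/unfold round trip (the toy shapes and identity q/k/v projections are my own illustration, not the repo's code):

```python
import torch
import torch.nn.functional as F

# Toy shapes: 2 CFG branches x 3 frames = 6, with 64 tokens of 8 channels each.
total_batch_size, nums_token, channel = 6, 64, 8
heads, head_dim = 2, 4  # heads * head_dim == channel
hidden_states = torch.randn(total_batch_size, nums_token, channel)

# Fold: merge the frames of each CFG branch into one long token sequence.
img_nums = total_batch_size // 2
folded = hidden_states.view(-1, img_nums, nums_token, channel).reshape(
    -1, img_nums * nums_token, channel
)
batch_size = folded.shape[0]  # now 2

# Attend: plain self-attention over the folded sequence.
q = k = v = folded.view(batch_size, -1, heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Unfold: back to per-frame token sequences.
out = out.transpose(1, 2).reshape(total_batch_size, -1, heads * head_dim)
print(out.shape)  # torch.Size([6, 64, 8])
```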
@@ -393,6 +316,60 @@ def set_attention_processor(unet, id_length, is_ipadapter=False):
 <style>
 '''
 
+def save_single_character_weights(unet, character, description, filepath):
+    """
+    Save the id_bank lists of GPU tensors held by the attention processors to a file.
+    Args:
+        unet: the model that contains the attention-processor instances.
+        filepath: the file path the weights are saved to.
+    """
+    weights_to_save = {}
+    weights_to_save["description"] = description
+    weights_to_save["character"] = character
+    for attn_name, attn_processor in unet.attn_processors.items():
+        if isinstance(attn_processor, SpatialAttnProcessor2_0):
+            # Move every tensor to the CPU so that it can be serialized
+            weights_to_save[attn_name] = {}
+            for step_key in attn_processor.id_bank[character].keys():
+                weights_to_save[attn_name][step_key] = [tensor.cpu() for tensor in attn_processor.id_bank[character][step_key]]
+    # Persist the collected weights with torch.save
+    torch.save(weights_to_save, filepath)
+
+def load_single_character_weights(unet, filepath):
+    """
+    Load weights from a file into the attention processors' id_bank.
+    Args:
+        unet: the model that contains the attention-processor instances.
+        filepath: the path of the weight file.
+    """
+    # Read the weights back with torch.load
+    weights_to_load = torch.load(filepath, map_location=torch.device('cpu'))
+    character = weights_to_load['character']
+    description = weights_to_load['description']
+    for attn_name, attn_processor in unet.attn_processors.items():
+        if isinstance(attn_processor, SpatialAttnProcessor2_0):
+            # Move the weights to the GPU (if one is available) and assign them to id_bank
+            attn_processor.id_bank[character] = {}
+            for step_key in weights_to_load[attn_name].keys():
+                attn_processor.id_bank[character][step_key] = [tensor.to(unet.device) for tensor in weights_to_load[attn_name][step_key]]
+
+def save_results(unet, img_list):
+
+    timestamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
+    folder_name = f'results/{timestamp}'
+    weight_folder_name = f'{folder_name}/weights'
+    # Create the output folders
+    if not os.path.exists(folder_name):
+        os.makedirs(folder_name)
+        os.makedirs(weight_folder_name)
+
+    for idx, img in enumerate(img_list):
+        file_path = os.path.join(folder_name, f'image_{idx}.png')  # image file name
+        img.save(file_path)
+    global character_dict
+    # for char in character_dict:
+    #     description = character_dict[char]
+    #     save_single_character_weights(unet, char, description, os.path.join(weight_folder_name, f'{char}.pt'))
 
 #################################################
 title = r"""
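The three helpers added above give the cached per-character attention features (`id_bank`) a save/load round trip. A hypothetical usage sketch — the character key, description, and file path below are illustrative, not values used by the repo:

```python
# Sketch: persist a character after its reference frames have been generated...
save_single_character_weights(
    pipe.unet,
    character="[Taylor]",  # illustrative character key
    description="a woman img, wearing a white T-shirt",
    filepath="results/weights/Taylor.pt",
)

# ...and restore it into a freshly loaded pipeline later.
load_single_character_weights(pipe.unet, "results/weights/Taylor.pt")
```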
@@ -426,14 +403,14 @@ def set_attention_processor(unet, id_length, is_ipadapter=False):
 ```
 📋 **License**
 <br>
-The Contents you create are under the Apache-2.0 LICENSE. The Code is under the Attribution-NonCommercial 4.0 International license.
+Apache-2.0 LICENSE.
 
 📧 **Contact**
 <br>
 If you have any questions, please feel free to reach out to me at <b>[email protected]</b>.
 """
 version = r"""
-<h3 align="center">StoryDiffusion Version 0.01 (test version)</h3>
+<h3 align="center">StoryDiffusion Version 0.02 (test version)</h3>
 
 <h5>1. Support reference images. (Cartoon reference images are not supported yet.)</h5>
 <h5>2. Support typesetting style and captioning. (By default, the prompt is used as the caption for each image. If you need to change the caption, add a # at the end of each line; only the part after the # will be added as a caption to the image.)</h5>
@@ -528,9 +505,29 @@ def change_visiale_by_model_type(_model_type):
     else:
         raise ValueError("Invalid model type", _model_type)
 
+def load_character_files(character_files: str):
+    if character_files == "":
+        raise gr.Error("Please set a character file!")
+    character_files_arr = character_files.splitlines()
+    primarytext = []
+    for character_file_name in character_files_arr:
+        character_file = torch.load(character_file_name, map_location=torch.device('cpu'))
+        primarytext.append(character_file["character"] + character_file["description"])
+    return array2string(primarytext)
+
+
+def load_character_files_on_running(unet, character_files: str):
+    if character_files == "":
+        return False
+    character_files_arr = character_files.splitlines()
+    for character_file in character_files_arr:
+        load_single_character_weights(unet, character_file)
+    return True
 
 ######### Image Generation ##############
-def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_name, _Ip_Adapter_Strength, _style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, _comic_type, font_choice):  # Corrected font_choice usage
+def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_name, _Ip_Adapter_Strength, _style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, _comic_type, font_choice, _char_files):  # Corrected font_choice usage
+    if len(general_prompt.splitlines()) >= 3:
+        raise gr.Error("Support for more than three characters is temporarily unavailable due to VRAM limitations, but this issue will be resolved soon.")
     _model_type = "Photomaker" if _model_type == "Using Ref Images" else "original"
     if _model_type == "Photomaker" and "img" not in general_prompt:
         raise gr.Error("Please add the trigger word \"img\" behind the class word you want to customize, such as: man img or woman img")
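`load_character_files` expects the textbox to hold one weight-file path per line, each written by `save_single_character_weights`; it rebuilds the general prompt from the stored `character` and `description` fields, while `load_character_files_on_running` primes the attention cache at generation time. A sketch with illustrative paths:

```python
char_files = "results/weights/Taylor.pt\nresults/weights/Lecun.pt"  # illustrative

# UI path: fills the general-prompt textbox from the saved descriptions.
general_prompt = load_character_files(char_files)

# Generation path: loads the cached features into the UNet's processors.
load_chars = load_character_files_on_running(pipe.unet, character_files=char_files)
```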
@@ -574,17 +571,13 @@ def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_nam
     unet = pipe.unet
     # unet.set_attn_processor(copy.deepcopy(attn_procs))
 
-
+    load_chars = load_character_files_on_running(unet, character_files=_char_files)
 
     prompts = prompt_array.splitlines()
     global character_dict, character_index_dict, invert_character_index_dict, ref_indexs_dict, ref_totals
     character_dict, character_list = character_to_dict(general_prompt)
 
-
-
-
-
     start_merge_step = int(float(_style_strength_ratio) / 100 * _num_steps)
     if start_merge_step > 30:
         start_merge_step = 30
@@ -627,33 +620,37 @@ def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_nam
     id_images = []
     results_dict = {}
     global cur_character
-    for character_key in character_dict.keys():
-        cur_character = [character_key]
-        ref_indexs = ref_indexs_dict[character_key]
-        print(character_key, ref_indexs)
-        current_prompts = [replace_prompts[ref_ind] for ref_ind in ref_indexs]
-        print(current_prompts)
-        setup_seed(seed_)
-        generator = torch.Generator(device="cuda").manual_seed(seed_)
-        cur_step = 0
-        cur_positive_prompts, negative_prompt = apply_style(style_name, current_prompts, negative_prompt)
-        if _model_type == "original":
-            id_images = pipe(cur_positive_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
-        elif _model_type == "Photomaker":
-            id_images = pipe(cur_positive_prompts, input_id_images=input_id_images_dict[character_key], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
-        else:
-            raise NotImplementedError("You should choose between original and Photomaker!", f"But you chose {_model_type}")
-
-        # total_results = id_images + total_results
-        # yield total_results
-        print(id_images)
-        for ind, img in enumerate(id_images):
-            print(ref_indexs[ind])
-            results_dict[ref_indexs[ind]] = img
-    # real_images = []
+    if not load_chars:
+        for character_key in character_dict.keys():
+            cur_character = [character_key]
+            ref_indexs = ref_indexs_dict[character_key]
+            print(character_key, ref_indexs)
+            current_prompts = [replace_prompts[ref_ind] for ref_ind in ref_indexs]
+            print(current_prompts)
+            setup_seed(seed_)
+            generator = torch.Generator(device="cuda").manual_seed(seed_)
+            cur_step = 0
+            cur_positive_prompts, negative_prompt = apply_style(style_name, current_prompts, negative_prompt)
+            if _model_type == "original":
+                id_images = pipe(cur_positive_prompts, num_inference_steps=_num_steps, guidance_scale=guidance_scale, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
+            elif _model_type == "Photomaker":
+                id_images = pipe(cur_positive_prompts, input_id_images=input_id_images_dict[character_key], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator).images
+            else:
+                raise NotImplementedError("You should choose between original and Photomaker!", f"But you chose {_model_type}")
+
+            # total_results = id_images + total_results
+            # yield total_results
+            print(id_images)
+            for ind, img in enumerate(id_images):
+                print(ref_indexs[ind])
+                results_dict[ref_indexs[ind]] = img
+        # real_images = []
+        yield [results_dict[ind] for ind in results_dict.keys()]
     write = False
-
-    real_prompts_inds = [ind for ind in range(len(prompts)) if ind not in ref_totals]
+    if not load_chars:
+        real_prompts_inds = [ind for ind in range(len(prompts)) if ind not in ref_totals]
+    else:
+        real_prompts_inds = [ind for ind in range(len(prompts))]
     print(real_prompts_inds)
 
     for real_prompts_ind in real_prompts_inds:
@@ -672,7 +669,7 @@ def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_nam
             results_dict[real_prompts_ind] = (pipe(real_prompt, input_id_images=input_id_images_dict[cur_character[0]] if real_prompts_ind not in nc_indexs else input_id_images_dict[character_list[0]], num_inference_steps=_num_steps, guidance_scale=guidance_scale, start_merge_step=start_merge_step, height=height, width=width, negative_prompt=negative_prompt, generator=generator, nc_flag=True if real_prompts_ind in nc_indexs else False).images[0])
         else:
             raise NotImplementedError("You should choose between original and Photomaker!", f"But you chose {_model_type}")
-
+        yield [results_dict[ind] for ind in results_dict.keys()]
     total_results = [results_dict[ind] for ind in range(len(prompts))]
     if _comic_type != "No typesetting (default)":
         captions = prompt_array.splitlines()
@@ -684,6 +681,8 @@ def process_generation(_sd_type, _model_type, _upload_images, _num_steps, style_nam
         print(f"Attempting to load font from path: {font_path}")
         font = ImageFont.truetype(font_path, int(45))
         total_results = get_comic(total_results, _comic_type, captions=captions, font=font) + total_results
+    save_results(pipe.unet, total_results)
+
     yield total_results
 
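With this hook, every run now writes its images to a timestamped folder. The resulting layout would look roughly like the sketch below (when typesetting is enabled, the composed page(s) are prepended to `total_results`, so they come first; the per-character weight dump inside `save_results` is still commented out, so `weights/` is created but stays empty):

```
results/20240101-120000/   # timestamp from datetime.datetime.now()
├── image_0.png            # typeset comic page(s) first, then raw frames
├── image_1.png
└── weights/               # created, but empty while the dump is disabled
```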
@@ -731,6 +730,8 @@ def array2string(arr):
         negative_prompt = gr.Textbox(value='', label="(2) Negative_prompt", interactive=True)
         style = gr.Dropdown(label="Style template", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)
         prompt_array = gr.Textbox(lines=3, value='', label="(3) Comic Description (each line corresponds to a frame).", interactive=True)
+        char_path = gr.Textbox(lines=2, value='', visible=False, label="(Optional) Character files", interactive=True)
+        char_btn = gr.Button("Load Character files", visible=False)
         with gr.Accordion("(4) Tune the hyperparameters", open=True):
             font_choice = gr.Dropdown(label="Select Font", choices=[f for f in os.listdir("./fonts") if f.endswith('.ttf')], value="Inkfree.ttf", info="Select font for the final slide.", interactive=True)
             sa32_ = gr.Slider(label="(The degree of Paired Attention at 32 x 32 self-attention layers)", minimum=0, maximum=1.0, value=0.5, step=0.1)
@@ -792,9 +793,9 @@ def array2string(arr):
         model_type.change(fn=change_visiale_by_model_type, inputs=model_type, outputs=[control_image_input, style_strength_ratio, Ip_Adapter_Strength])
         files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
         remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
-
+        char_btn.click(fn=load_character_files, inputs=char_path, outputs=[general_prompt])
         final_run_btn.click(fn=set_text_unfinished, outputs=generated_information
-        ).then(process_generation, inputs=[sd_type, model_type, files, num_steps, style, Ip_Adapter_Strength, style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, comic_type, font_choice], outputs=out_image
+        ).then(process_generation, inputs=[sd_type, model_type, files, num_steps, style, Ip_Adapter_Strength, style_strength_ratio, guidance_scale, seed_, sa32_, sa64_, id_length_, general_prompt, negative_prompt, prompt_array, G_height, G_width, comic_type, font_choice, char_path], outputs=out_image
         ).then(fn=set_text_finished, outputs=generated_information)
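Both widgets are created with `visible=False`, so the character-file workflow is fully wired up but hidden by default. To experiment with it, one could simply flip the flags — my suggestion, not part of this commit:

```python
char_path = gr.Textbox(lines=2, value='', visible=True,
                       label="(Optional) Character files", interactive=True)
char_btn = gr.Button("Load Character files", visible=True)
```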