Skip to content

Commit

Permalink
Merge pull request #3 from LykosAI/add-experiments
Browse files Browse the repository at this point in the history
Add comfyui_experiments module
  • Loading branch information
ionite34 authored Mar 11, 2024
2 parents 92bb404 + 3d6b3eb commit b996c18
Show file tree
Hide file tree
Showing 10 changed files with 1,040 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/inference_core_nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = ("__version__", "NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS")

# NOTE(review): the duplicated assignment below is a diff-rendering artifact
# (old line followed by new line); the effective package version is "0.3.0".
__version__ = "0.2.1"
__version__ = "0.3.0"


def _get_node_mappings():
Expand Down
674 changes: 674 additions & 0 deletions src/inference_core_nodes/comfyui_experiments/LICENSE

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions src/inference_core_nodes/comfyui_experiments/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
## ComfyUI Experiments

Based on or modified from: [comfyanonymous/ComfyUI_experiments](https://github.com/comfyanonymous/ComfyUI_experiments) @ 934dba9d206e4738e0dac26a09b51f1dffcb4e44

License: GPL-3.0


23 changes: 23 additions & 0 deletions src/inference_core_nodes/comfyui_experiments/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import importlib
import os

# Names of the .py modules (relative to this package) that define nodes.
node_list = [ #Add list of .py files containing nodes here
    "advanced_model_merging",
    "reference_only",
    "sampler_rescalecfg",
    "sampler_tonemap",
    "sampler_tonemap_rescalecfg",
    "sdxl_model_merging"
]

# Aggregated registries re-exported to the node loader.
NODE_CLASS_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS = {}

for module_name in node_list:
    imported_module = importlib.import_module(".{}".format(module_name), __name__)

    # Merge in place with update() instead of rebuilding the whole dict on
    # every iteration (the {**a, **b} form re-copies all prior entries).
    NODE_CLASS_MAPPINGS.update(imported_module.NODE_CLASS_MAPPINGS)
    # Display-name mappings are optional per module.
    if hasattr(imported_module, "NODE_DISPLAY_NAME_MAPPINGS"):
        NODE_DISPLAY_NAME_MAPPINGS.update(imported_module.NODE_DISPLAY_NAME_MAPPINGS)

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import comfy_extras.nodes_model_merging

class ModelMergeBlockNumber(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SD UNet models with one blend-ratio input per block."""

    @classmethod
    def INPUT_TYPES(s):
        # Every ratio input shares the same FLOAT widget configuration.
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}
        arg_dict["time_embed."] = ratio
        arg_dict["label_emb."] = ratio

        # 12 input blocks, 3 middle blocks, 12 output blocks (insertion
        # order defines the order of the node's inputs).
        arg_dict.update({"input_blocks.{}.".format(i): ratio for i in range(12)})
        arg_dict.update({"middle_block.{}.".format(i): ratio for i in range(3)})
        arg_dict.update({"output_blocks.{}.".format(i): ratio for i in range(12)})

        arg_dict["out."] = ratio

        return {"required": arg_dict}


# Node registry picked up by the package __init__ when aggregating mappings.
NODE_CLASS_MAPPINGS = {
    "ModelMergeBlockNumber": ModelMergeBlockNumber,
}
54 changes: 54 additions & 0 deletions src/inference_core_nodes/comfyui_experiments/reference_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch

class ReferenceOnlySimple:
    """Reference-only conditioning node.

    Patches the model's self-attention (attn1) so generated latents can
    attend to a reference latent, and returns a combined latent batch whose
    first entry is the reference followed by ``batch_size`` empty latents.
    """

    @classmethod
    def INPUT_TYPES(s):
        # model: model to patch; reference: latent to attend to;
        # batch_size: number of new latents generated alongside the reference.
        return {"required": { "model": ("MODEL",),
                              "reference": ("LATENT",),
                              "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})
                              }}

    RETURN_TYPES = ("MODEL", "LATENT")
    FUNCTION = "reference_only"

    CATEGORY = "custom_node_experiments"

    def reference_only(self, model, reference, batch_size):
        """Clone *model*, install the attn1 patch, and build the latent batch."""
        model_reference = model.clone()
        # Empty latents shaped like the reference, but with the batch
        # dimension replaced by batch_size.
        size_latent = list(reference["samples"].shape)
        size_latent[0] = batch_size
        latent = {}
        latent["samples"] = torch.zeros(size_latent)

        # Combined batch size: reference sample(s) + generated sample(s).
        batch = latent["samples"].shape[0] + reference["samples"].shape[0]
        def reference_apply(q, k, v, extra_options):
            # Double the key sequence; the second half of every row x > 0 in
            # each batch group is overwritten with row o's queries (row o is
            # the reference entry of that group).
            k = k.clone().repeat(1, 2, 1)
            # NOTE(review): `offset` is computed but never used below —
            # looks like leftover from upstream; confirm before removing.
            offset = 0
            if q.shape[0] > batch:
                offset = batch

            for o in range(0, q.shape[0], batch):
                for x in range(1, batch):
                    k[x + o, q.shape[1]:] = q[o,:]

            # Values mirror the augmented keys.
            return q, k, k

        model_reference.set_model_attn1_patch(reference_apply)
        out_latent = torch.cat((reference["samples"], latent["samples"]))
        # NOTE(review): `latent` was just created with only a "samples" key,
        # so this condition is always False and the else branch always runs.
        if "noise_mask" in latent:
            mask = latent["noise_mask"]
        else:
            mask = torch.ones((64,64), dtype=torch.float32, device="cpu")

        if len(mask.shape) < 3:
            mask = mask.unsqueeze(0)
        if mask.shape[0] < latent["samples"].shape[0]:
            print(latent["samples"].shape, mask.shape)  # debug output kept from upstream
            mask = mask.repeat(latent["samples"].shape[0], 1, 1)

        # Zero mask for the reference entry (index 0) — presumably so the
        # sampler leaves the reference latent untouched; confirm against the
        # sampler's noise_mask semantics.
        out_mask = torch.zeros((1,mask.shape[1],mask.shape[2]), dtype=torch.float32, device="cpu")
        return (model_reference, {"samples": out_latent, "noise_mask": torch.cat((out_mask, mask))})

# Node registry picked up by the package __init__ when aggregating mappings.
NODE_CLASS_MAPPINGS = {
    "ReferenceOnlySimple": ReferenceOnlySimple,
}
38 changes: 38 additions & 0 deletions src/inference_core_nodes/comfyui_experiments/sampler_rescalecfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch


class RescaleClassifierFreeGuidance:
    """Rescale-CFG sampler patch.

    Blends the plain CFG output with a version rescaled so its per-sample
    standard deviation matches that of the positive conditioning.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "multiplier": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "custom_node_experiments"

    def patch(self, model, multiplier):
        """Return a clone of *model* with the rescaling CFG function installed."""

        def apply_rescale(args):
            positive = args["cond"]
            negative = args["uncond"]
            scale = args["cond_scale"]

            # Standard classifier-free guidance.
            guided = negative + scale * (positive - negative)

            # Match the guided output's std to the positive prediction's std.
            std_pos = torch.std(positive, dim=(1,2,3), keepdim=True)
            std_cfg = torch.std(guided, dim=(1,2,3), keepdim=True)
            rescaled = guided * (std_pos / std_cfg)

            # Linear blend between rescaled and plain CFG output.
            return multiplier * rescaled + (1.0 - multiplier) * guided

        patched = model.clone()
        patched.set_model_sampler_cfg_function(apply_rescale)
        return (patched, )


# Node registry picked up by the package __init__. Note the registered key
# carries a "Test" suffix that the class name does not.
NODE_CLASS_MAPPINGS = {
    "RescaleClassifierFreeGuidanceTest": RescaleClassifierFreeGuidance,
}
44 changes: 44 additions & 0 deletions src/inference_core_nodes/comfyui_experiments/sampler_tonemap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch


class ModelSamplerTonemapNoiseTest:
    """Sampler patch that Reinhard-tonemaps the CFG noise-prediction magnitude."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "custom_node_experiments"

    def patch(self, model, multiplier):
        """Return a clone of *model* with the tonemapping CFG function installed."""

        def reinhard_tonemap(args):
            positive = args["cond"]
            negative = args["uncond"]
            scale = args["cond_scale"]

            # Split the noise prediction into direction and per-pixel magnitude
            # (epsilon guards against division by zero).
            noise = positive - negative
            magnitude = (torch.linalg.vector_norm(noise, dim=(1)) + 0.0000000001)[:,None]
            direction = noise / magnitude

            mean = torch.mean(magnitude, dim=(1,2,3), keepdim=True)
            std = torch.std(magnitude, dim=(1,2,3), keepdim=True)
            # Normalization ceiling: three standard deviations above the mean.
            top = (std * 3 + mean) * multiplier

            # Reinhard curve x / (x + 1) applied to the normalized magnitude.
            normalized = magnitude * (1.0 / top)
            new_magnitude = (normalized / (normalized + 1.0)) * top

            return negative + direction * new_magnitude * scale

        patched = model.clone()
        patched.set_model_sampler_cfg_function(reinhard_tonemap)
        return (patched, )


# Node registry picked up by the package __init__ when aggregating mappings.
NODE_CLASS_MAPPINGS = {
    "ModelSamplerTonemapNoiseTest": ModelSamplerTonemapNoiseTest,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch


class TonemapNoiseWithRescaleCFG:
    """Sampler patch combining Reinhard noise tonemapping with Rescale-CFG."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "tonemap_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             "rescale_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "custom_node_experiments"

    def patch(self, model, tonemap_multiplier, rescale_multiplier):
        """Return a clone of *model* whose CFG function tonemaps, then rescales."""

        def tonemap_then_rescale(args):
            positive = args["cond"]
            negative = args["uncond"]
            scale = args["cond_scale"]

            # --- Reinhard tonemap of the noise-prediction magnitude ---
            noise = positive - negative
            magnitude = (torch.linalg.vector_norm(noise, dim=(1)) + 0.0000000001)[:, None]
            direction = noise / magnitude

            mean = torch.mean(magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(magnitude, dim=(1, 2, 3), keepdim=True)
            # Normalization ceiling: three standard deviations above the mean.
            top = (std * 3 + mean) * tonemap_multiplier

            normalized = magnitude * (1.0 / top)
            new_magnitude = (normalized / (normalized + 1.0)) * top

            # --- Rescale-CFG applied to the tonemapped guidance ---
            x_cfg = negative + (direction * new_magnitude * scale)
            ro_pos = torch.std(positive, dim=(1, 2, 3), keepdim=True)
            ro_cfg = torch.std(x_cfg, dim=(1, 2, 3), keepdim=True)

            rescaled = x_cfg * (ro_pos / ro_cfg)
            # Linear blend between rescaled and plain tonemapped output.
            return rescale_multiplier * rescaled + (1.0 - rescale_multiplier) * x_cfg

        patched = model.clone()
        patched.set_model_sampler_cfg_function(tonemap_then_rescale)
        return (patched, )


# Node registry picked up by the package __init__ when aggregating mappings.
NODE_CLASS_MAPPINGS = {
    "TonemapNoiseWithRescaleCFG": TonemapNoiseWithRescaleCFG,
}
114 changes: 114 additions & 0 deletions src/inference_core_nodes/comfyui_experiments/sdxl_model_merging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import comfy_extras.nodes_model_merging

class ModelMergeSDXL(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with one blend-ratio input per UNet block.

    NOTE(review): unlike ModelMergeBlockNumber, the block keys here have no
    trailing dot ("input_blocks.0" vs "input_blocks.0.") — presumably harmless
    since SDXL has at most 9 blocks so no prefix collision; confirm upstream.
    """

    @classmethod
    def INPUT_TYPES(s):
        # Every ratio input shares the same FLOAT widget configuration.
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}
        arg_dict["time_embed."] = ratio
        arg_dict["label_emb."] = ratio

        # 9 input blocks, 3 middle blocks, 9 output blocks (insertion order
        # defines the order of the node's inputs).
        arg_dict.update({"input_blocks.{}".format(i): ratio for i in range(9)})
        arg_dict.update({"middle_block.{}".format(i): ratio for i in range(3)})
        arg_dict.update({"output_blocks.{}".format(i): ratio for i in range(9)})

        arg_dict["out."] = ratio

        return {"required": arg_dict}


class ModelMergeSDXLTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """Merge two SDXL models with ratios down to individual transformer blocks."""

    @classmethod
    def INPUT_TYPES(s):
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}
        arg_dict["time_embed."] = ratio
        arg_dict["label_emb."] = ratio

        # Input blocks 4/5 carry 2 transformer blocks each, 7/8 carry 10.
        depth_by_block = {4: 2, 5: 2, 7: 10, 8: 10}
        for block in range(9):
            arg_dict["input_blocks.{}.0.".format(block)] = ratio
            if block in depth_by_block:
                arg_dict["input_blocks.{}.1.".format(block)] = ratio
                for t in range(depth_by_block[block]):
                    arg_dict["input_blocks.{}.1.transformer_blocks.{}.".format(block, t)] = ratio

        for block in range(3):
            arg_dict["middle_block.{}.".format(block)] = ratio
            # The middle attention block (index 1) holds 10 transformer blocks.
            if block == 1:
                for t in range(10):
                    arg_dict["middle_block.{}.transformer_blocks.{}.".format(block, t)] = ratio

        # Output blocks mirror the input side (mirrored index is 8 - block).
        depth_by_block = {3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}
        for block in range(9):
            arg_dict["output_blocks.{}.0.".format(block)] = ratio
            mirrored = 8 - block
            if mirrored in depth_by_block:
                arg_dict["output_blocks.{}.1.".format(block)] = ratio
                for t in range(depth_by_block[mirrored]):
                    arg_dict["output_blocks.{}.1.transformer_blocks.{}.".format(block, t)] = ratio

        arg_dict["out."] = ratio

        return {"required": arg_dict}

class ModelMergeSDXLDetailedTransformers(comfy_extras.nodes_model_merging.ModelMergeBlocks):
    """SDXL merge with per-layer ratios inside every transformer block."""

    @classmethod
    def INPUT_TYPES(s):
        ratio = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})

        arg_dict = { "model1": ("MODEL",),
                     "model2": ("MODEL",)}
        arg_dict["time_embed."] = ratio
        arg_dict["label_emb."] = ratio

        # Input blocks 4/5 carry 2 transformer blocks each, 7/8 carry 10.
        depth_by_block = {4: 2, 5: 2, 7: 10, 8: 10}
        # Sub-layers exposed individually inside each transformer block.
        layer_names = ["norm1", "attn1.to_q", "attn1.to_k", "attn1.to_v", "attn1.to_out", "ff.net", "norm2", "attn2.to_q", "attn2.to_k", "attn2.to_v", "attn2.to_out", "norm3"]

        for block in range(9):
            arg_dict["input_blocks.{}.0.".format(block)] = ratio
            if block in depth_by_block:
                arg_dict["input_blocks.{}.1.".format(block)] = ratio
                for t in range(depth_by_block[block]):
                    for layer in layer_names:
                        arg_dict["input_blocks.{}.1.transformer_blocks.{}.{}".format(block, t, layer)] = ratio

        for block in range(3):
            arg_dict["middle_block.{}.".format(block)] = ratio
            # The middle attention block (index 1) holds 10 transformer blocks.
            if block == 1:
                for t in range(10):
                    for layer in layer_names:
                        arg_dict["middle_block.{}.transformer_blocks.{}.{}".format(block, t, layer)] = ratio

        # Output blocks mirror the input side (mirrored index is 8 - block).
        depth_by_block = {3: 2, 4: 2, 5: 2, 6: 10, 7: 10, 8: 10}
        for block in range(9):
            arg_dict["output_blocks.{}.0.".format(block)] = ratio
            mirrored = 8 - block
            if mirrored in depth_by_block:
                arg_dict["output_blocks.{}.1.".format(block)] = ratio
                for t in range(depth_by_block[mirrored]):
                    for layer in layer_names:
                        arg_dict["output_blocks.{}.1.transformer_blocks.{}.{}".format(block, t, layer)] = ratio

        arg_dict["out."] = ratio

        return {"required": arg_dict}

# Node registry picked up by the package __init__ when aggregating mappings.
NODE_CLASS_MAPPINGS = {
    "ModelMergeSDXL": ModelMergeSDXL,
    "ModelMergeSDXLTransformers": ModelMergeSDXLTransformers,
    "ModelMergeSDXLDetailedTransformers": ModelMergeSDXLDetailedTransformers,
}

0 comments on commit b996c18

Please sign in to comment.