Mirror of https://github.com/kohya-ss/sd-scripts.git
Synced 2026-04-08 22:35:09 +00:00

Commit: rename and update

Changed file: sdxl_gen_img.py (145 changes)
@@ -47,10 +47,9 @@ import library.train_util as train_util
 import library.sdxl_model_util as sdxl_model_util
 import library.sdxl_train_util as sdxl_train_util
 from networks.lora import LoRANetwork
-import tools.original_control_net as original_control_net
-from tools.original_control_net import ControlNetInfo
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
 from library.original_unet import FlashAttentionFunction
+from networks.control_net_lllite import ControlNetLLLite

 # scheduler:
 SCHEDULER_LINEAR_START = 0.00085
@@ -327,7 +326,7 @@ class PipelineLike:
         self.token_replacements_list.append({})

         # ControlNet  # not supported yet
-        self.control_nets: List[ControlNetInfo] = []
+        self.control_nets: List[ControlNetLLLite] = []
         self.control_net_enabled = True  # if control_nets is empty, ControlNet is inactive regardless of True/False

         # Textual Inversion
@@ -392,6 +391,7 @@ class PipelineLike:
         is_cancelled_callback: Optional[Callable[[], bool]] = None,
         callback_steps: Optional[int] = 1,
         img2img_noise=None,
         clip_guide_images=None,
         **kwargs,
     ):
+        # TODO support secondary prompt
@@ -496,11 +496,16 @@ class PipelineLike:
                 text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])

         if self.control_nets:
             # reuse the guide image as the ControlNet hint
             if isinstance(clip_guide_images, PIL.Image.Image):
                 clip_guide_images = [clip_guide_images]
             if isinstance(clip_guide_images[0], PIL.Image.Image):
                 clip_guide_images = [preprocess_image(im) for im in clip_guide_images]
                 clip_guide_images = torch.cat(clip_guide_images)
+            if isinstance(clip_guide_images, list):
+                clip_guide_images = torch.stack(clip_guide_images)
+
+            # preprocessing is done on the ControlNet side
+            clip_guide_images = clip_guide_images.to(self.device, dtype=text_embeddings.dtype)

         # create size embs
         if original_height is None:
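Note: the updated block normalizes clip_guide_images from a single PIL image, a list of PIL images, or a list of tensors into one batched tensor on the pipeline's device and dtype. A standalone sketch of that normalization, with a stand-in for the script's preprocess_image (assumed here to return a 1x3xHxW float tensor):

import numpy as np
import torch
import PIL.Image

def preprocess_image(im: PIL.Image.Image) -> torch.Tensor:
    # stand-in for the script's preprocess_image: HWC uint8 -> 1CHW float in [0, 1]
    arr = np.array(im.convert("RGB"), dtype=np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)

images = PIL.Image.new("RGB", (64, 64))
if isinstance(images, PIL.Image.Image):
    images = [images]                                  # single image -> list
if isinstance(images[0], PIL.Image.Image):
    images = [preprocess_image(im) for im in images]   # PIL -> tensors
    images = torch.cat(images)                         # list of 1CHW -> NCHW
if isinstance(images, list):
    images = torch.stack(images)                       # list of CHW tensors -> NCHW
print(images.shape)  # torch.Size([1, 3, 64, 64])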
@@ -654,35 +659,47 @@ class PipelineLike:
         num_latent_input = (3 if negative_scale is not None else 2) if do_classifier_free_guidance else 1

         if self.control_nets:
-            guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+            # guided_hints = original_control_net.get_guided_hints(self.control_nets, num_latent_input, batch_size, clip_guide_images)
+            if self.control_net_enabled:
+                for control_net in self.control_nets:
+                    with torch.no_grad():
+                        control_net.set_cond_image(clip_guide_images)
+            else:
+                for control_net in self.control_nets:
+                    control_net.set_cond_image(None)

         for i, t in enumerate(tqdm(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-            # predict the noise residual
-            if self.control_nets and self.control_net_enabled:
-                if reginonal_network:
-                    num_sub_and_neg_prompts = len(text_embeddings) // batch_size
-                    text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
-                else:
-                    text_emb_last = text_embeddings
+            # # disable control net if ratio is set
+            # if self.control_nets and self.control_net_enabled:
+            #     pass  # TODO

-                # not working yet
-                noise_pred = original_control_net.call_unet_and_control_net(
-                    i,
-                    num_latent_input,
-                    self.unet,
-                    self.control_nets,
-                    guided_hints,
-                    i / len(timesteps),
-                    latent_model_input,
-                    t,
-                    text_emb_last,
-                ).sample
-            else:
-                noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)
+            # predict the noise residual
+            # TODO Diffusers' ControlNet
+            # if self.control_nets and self.control_net_enabled:
+            #     if reginonal_network:
+            #         num_sub_and_neg_prompts = len(text_embeddings) // batch_size
+            #         text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2 :: num_sub_and_neg_prompts]  # last subprompt
+            #     else:
+            #         text_emb_last = text_embeddings
+
+            #     # not working yet
+            #     noise_pred = original_control_net.call_unet_and_control_net(
+            #         i,
+            #         num_latent_input,
+            #         self.unet,
+            #         self.control_nets,
+            #         guided_hints,
+            #         i / len(timesteps),
+            #         latent_model_input,
+            #         t,
+            #         text_emb_last,
+            #     ).sample
+            # else:
+            noise_pred = self.unet(latent_model_input, t, text_embeddings, vector_embeddings)

             # perform guidance
             if do_classifier_free_guidance:
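Note: the new flow pushes the hint into each LLLite module once, before the sampling loop, and the loop then calls the plain self.unet(...); the modules are attached to the U-Net elsewhere (see apply_to() in the loader below), so the conditioning happens inside the U-Net forward pass. A minimal sketch of that set-once/cache pattern, using a hypothetical module rather than the actual networks.control_net_lllite implementation:

import torch

class CondImageCachingModule(torch.nn.Module):
    """Hypothetical stand-in for an LLLite-style module hooked into a U-Net."""

    def __init__(self, channels: int):
        super().__init__()
        self.proj = torch.nn.Conv2d(3, channels, 1)
        self.cond_image = None  # cached hint; None means "disabled"

    def set_cond_image(self, cond_image):
        # called once per generation, outside the denoising loop
        self.cond_image = cond_image

    def forward(self, x):
        if self.cond_image is None:
            return x  # disabled: behave as identity
        hint = torch.nn.functional.interpolate(self.cond_image, size=x.shape[-2:])
        return x + self.proj(hint)  # mix the hint into the activation

mod = CondImageCachingModule(channels=8)
mod.set_cond_image(torch.randn(1, 3, 64, 64))
print(mod(torch.randn(1, 8, 32, 32)).shape)  # torch.Size([1, 8, 32, 32])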
@@ -1550,16 +1567,40 @@ def main(args):
         upscaler.to(dtype).to(device)

     # ControlNet handling
-    control_nets: List[ControlNetInfo] = []
-    if args.control_net_models:
-        for i, model in enumerate(args.control_net_models):
-            prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
-            weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
-            ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]
+    control_nets: List[ControlNetLLLite] = []
+    # if args.control_net_models:
+    #     for i, model in enumerate(args.control_net_models):
+    #         prep_type = None if not args.control_net_preps or len(args.control_net_preps) <= i else args.control_net_preps[i]
+    #         weight = 1.0 if not args.control_net_weights or len(args.control_net_weights) <= i else args.control_net_weights[i]
+    #         ratio = 1.0 if not args.control_net_ratios or len(args.control_net_ratios) <= i else args.control_net_ratios[i]

-            ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model)
-            prep = original_control_net.load_preprocess(prep_type)
-            control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+    #         ctrl_unet, ctrl_net = original_control_net.load_control_net(False, unet, model)
+    #         prep = original_control_net.load_preprocess(prep_type)
+    #         control_nets.append(ControlNetInfo(ctrl_unet, ctrl_net, prep, weight, ratio))
+    if args.control_net_lllite_models:
+        for i, model_file in enumerate(args.control_net_lllite_models):
+            print(f"loading ControlNet-LLLite: {model_file}")
+
+            from safetensors.torch import load_file
+
+            state_dict = load_file(model_file)
+            mlp_dim = None
+            cond_emb_dim = None
+            for key, value in state_dict.items():
+                if mlp_dim is None and "down.0.weight" in key:
+                    mlp_dim = value.shape[0]
+                elif cond_emb_dim is None and "conditioning1.0" in key:
+                    cond_emb_dim = value.shape[0] * 2
+                if mlp_dim is not None and cond_emb_dim is not None:
+                    break
+            assert mlp_dim is not None and cond_emb_dim is not None, f"invalid control net: {model_file}"
+
+            control_net = ControlNetLLLite(unet, cond_emb_dim, mlp_dim)
+            control_net.apply_to()
+            control_net.load_state_dict(state_dict)
+            control_net.to(dtype).to(device)
+            control_net.set_batch_cond_only(False, False)
+            control_nets.append(control_net)

     if args.opt_channels_last:
         print(f"set optimizing: channels last")
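Note: the loader infers mlp_dim and cond_emb_dim from tensor shapes instead of stored metadata: the row count of the first "down.0.weight" tensor gives the MLP width, and the output channels of the first "conditioning1.0" tensor, times two, give the conditioning embedding size. A self-contained sketch of the same detection loop against a dummy state dict (key names and shapes are illustrative, not the real checkpoint layout):

import torch

state_dict = {
    "lllite_unet.blocks.0.down.0.weight": torch.zeros(128, 320),        # 128 rows -> mlp_dim
    "lllite_unet.blocks.0.conditioning1.0.weight": torch.zeros(16, 3, 1, 1),  # 16 out channels -> cond_emb_dim 32
}

mlp_dim = None
cond_emb_dim = None
for key, value in state_dict.items():
    if mlp_dim is None and "down.0.weight" in key:
        mlp_dim = value.shape[0]
    elif cond_emb_dim is None and "conditioning1.0" in key:
        cond_emb_dim = value.shape[0] * 2
    if mlp_dim is not None and cond_emb_dim is not None:
        break

print(mlp_dim, cond_emb_dim)  # 128 32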
@@ -1572,8 +1613,9 @@ def main(args):
         network.to(memory_format=torch.channels_last)

     for cn in control_nets:
-        cn.unet.to(memory_format=torch.channels_last)
-        cn.net.to(memory_format=torch.channels_last)
+        cn.to(memory_format=torch.channels_last)
+        # cn.unet.to(memory_format=torch.channels_last)
+        # cn.net.to(memory_format=torch.channels_last)

     pipe = PipelineLike(
         device,
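Note: ControlNetLLLite is a single nn.Module, so one .to(memory_format=...) call replaces the per-submodule conversion that the old ControlNetInfo needed for its separate unet and net members. For reference, a minimal demonstration of the channels_last conversion (shapes are illustrative):

import torch

# channels_last keeps NCHW tensor semantics but NHWC strides, which can
# speed up convolutions on recent GPUs; Module.to converts all parameters.
conv = torch.nn.Conv2d(3, 8, 3)
conv.to(memory_format=torch.channels_last)

x = torch.randn(1, 3, 64, 64).to(memory_format=torch.channels_last)
y = conv(x)
print(y.is_contiguous(memory_format=torch.channels_last))  # True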
@@ -2573,20 +2615,23 @@ def setup_parser() -> argparse.ArgumentParser:
     )

     parser.add_argument(
-        "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
+        "--control_net_lllite_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
     )
-    parser.add_argument(
-        "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
-    )
-    parser.add_argument("--control_net_weights", type=float, default=None, nargs="*", help="ControlNet weights / ControlNetの重み")
-    parser.add_argument(
-        "--control_net_ratios",
-        type=float,
-        default=None,
-        nargs="*",
-        help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
-    )
+    # parser.add_argument(
+    #     "--control_net_models", type=str, default=None, nargs="*", help="ControlNet models to use / 使用するControlNetのモデル名"
+    # )
+    # parser.add_argument(
+    #     "--control_net_preps", type=str, default=None, nargs="*", help="ControlNet preprocess to use / 使用するControlNetのプリプロセス名"
+    # )
+    # parser.add_argument("--control_net_multiplier", type=float, default=None, nargs="*", help="ControlNet multiplier / ControlNetの適用率")
+    # parser.add_argument(
+    #     "--control_net_ratios",
+    #     type=float,
+    #     default=None,
+    #     nargs="*",
+    #     help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
+    # )
+    # # parser.add_argument(
+    #     "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
+    # )
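Note: after this rename, LLLite models are passed via --control_net_lllite_models; the hint image still comes in through the script's guide-image option. A usage sketch (file names are placeholders, and the flags other than --control_net_lllite_models are assumed from the script's pre-existing argument parser):

python sdxl_gen_img.py \
    --ckpt sd_xl_base_1.0.safetensors \
    --prompt "a photo of a cat" \
    --guide_image_path hint.png \
    --control_net_lllite_models controllllite_canny.safetensors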