mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add attension couple+reginal LoRA
This commit is contained in:
@@ -92,6 +92,7 @@ from PIL.PngImagePlugin import PngInfo
|
||||
|
||||
import library.model_util as model_util
|
||||
import library.train_util as train_util
|
||||
from networks.lora import LoRANetwork
|
||||
import tools.original_control_net as original_control_net
|
||||
from tools.original_control_net import ControlNetInfo
|
||||
|
||||
@@ -634,6 +635,7 @@ class PipelineLike:
|
||||
img2img_noise=None,
|
||||
clip_prompts=None,
|
||||
clip_guide_images=None,
|
||||
networks: Optional[List[LoRANetwork]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -717,6 +719,7 @@ class PipelineLike:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
reginonal_network = " AND " in prompt[0]
|
||||
|
||||
vae_batch_size = (
|
||||
batch_size
|
||||
@@ -1010,6 +1013,11 @@ class PipelineLike:
|
||||
|
||||
# predict the noise residual
|
||||
if self.control_nets:
|
||||
if reginonal_network:
|
||||
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
|
||||
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
|
||||
else:
|
||||
text_emb_last = text_embeddings
|
||||
noise_pred = original_control_net.call_unet_and_control_net(
|
||||
i,
|
||||
num_latent_input,
|
||||
@@ -1019,7 +1027,7 @@ class PipelineLike:
|
||||
i / len(timesteps),
|
||||
latent_model_input,
|
||||
t,
|
||||
text_embeddings,
|
||||
text_emb_last,
|
||||
).sample
|
||||
else:
|
||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||
@@ -1890,6 +1898,12 @@ def get_weighted_text_embeddings(
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
# split the prompts with "AND". each prompt must have the same number of splits
|
||||
new_prompts = []
|
||||
for p in prompt:
|
||||
new_prompts.extend(p.split(" AND "))
|
||||
prompt = new_prompts
|
||||
|
||||
if not skip_parsing:
|
||||
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
|
||||
if uncond_prompt is not None:
|
||||
@@ -2059,6 +2073,7 @@ class BatchDataExt(NamedTuple):
|
||||
negative_scale: float
|
||||
strength: float
|
||||
network_muls: Tuple[float]
|
||||
num_sub_prompts: int
|
||||
|
||||
|
||||
class BatchData(NamedTuple):
|
||||
@@ -2276,16 +2291,20 @@ def main(args):
|
||||
print(f"metadata for: {network_weight}: {metadata}")
|
||||
|
||||
network, weights_sd = imported_module.create_network_from_weights(
|
||||
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
|
||||
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||
)
|
||||
else:
|
||||
raise ValueError("No weight. Weight is required.")
|
||||
if network is None:
|
||||
return
|
||||
|
||||
if not args.network_merge:
|
||||
mergiable = hasattr(network, "merge_to")
|
||||
if args.network_merge and not mergiable:
|
||||
print("network is not mergiable. ignore merge option.")
|
||||
|
||||
if not args.network_merge or not mergiable:
|
||||
network.apply_to(text_encoder, unet)
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||
print(f"weights are loaded: {info}")
|
||||
|
||||
if args.opt_channels_last:
|
||||
@@ -2349,12 +2368,12 @@ def main(args):
|
||||
if args.diffusers_xformers:
|
||||
pipe.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# Extended Textual Inversion および Textual Inversionを処理する
|
||||
if args.XTI_embeddings:
|
||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||
|
||||
# Textual Inversionを処理する
|
||||
if args.textual_inversion_embeddings:
|
||||
token_ids_embeds = []
|
||||
for embeds_file in args.textual_inversion_embeddings:
|
||||
@@ -2558,16 +2577,22 @@ def main(args):
|
||||
print(f"resize img2img mask images to {args.W}*{args.H}")
|
||||
mask_images = resize_images(mask_images, (args.W, args.H))
|
||||
|
||||
regional_network = False
|
||||
if networks and mask_images:
|
||||
# mask を領域情報として流用する、現在は1枚だけ対応
|
||||
# TODO 複数のnetwork classの混在時の考慮
|
||||
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
||||
regional_network = True
|
||||
print("use mask as region")
|
||||
# import cv2
|
||||
# for i in range(3):
|
||||
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
|
||||
# cv2.waitKey()
|
||||
# cv2.destroyAllWindows()
|
||||
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
|
||||
|
||||
size = None
|
||||
for i, network in enumerate(networks):
|
||||
if i < 3:
|
||||
np_mask = np.array(mask_images[0])
|
||||
np_mask = np_mask[:, :, i]
|
||||
size = np_mask.shape
|
||||
else:
|
||||
np_mask = np.full(size, 255, dtype=np.uint8)
|
||||
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
|
||||
network.set_region(i, i == len(networks) - 1, mask)
|
||||
mask_images = None
|
||||
|
||||
prev_image = None # for VGG16 guided
|
||||
@@ -2623,7 +2648,14 @@ def main(args):
|
||||
height_1st = height_1st - height_1st % 32
|
||||
|
||||
ext_1st = BatchDataExt(
|
||||
width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
|
||||
width_1st,
|
||||
height_1st,
|
||||
args.highres_fix_steps,
|
||||
ext.scale,
|
||||
ext.negative_scale,
|
||||
ext.strength,
|
||||
ext.network_muls,
|
||||
ext.num_sub_prompts,
|
||||
)
|
||||
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
||||
images_1st = process_batch(batch_1st, True, True)
|
||||
@@ -2651,7 +2683,7 @@ def main(args):
|
||||
(
|
||||
return_latents,
|
||||
(step_first, _, _, _, init_image, mask_image, _, guide_image),
|
||||
(width, height, steps, scale, negative_scale, strength, network_muls),
|
||||
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
|
||||
) = batch[0]
|
||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||
|
||||
@@ -2743,8 +2775,11 @@ def main(args):
|
||||
|
||||
# generate
|
||||
if networks:
|
||||
shared = {}
|
||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||
n.set_multiplier(m)
|
||||
if regional_network:
|
||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||
|
||||
images = pipe(
|
||||
prompts,
|
||||
@@ -2969,11 +3004,26 @@ def main(args):
|
||||
print("Use previous image as guide image.")
|
||||
guide_image = prev_image
|
||||
|
||||
if regional_network:
|
||||
num_sub_prompts = len(prompt.split(" AND "))
|
||||
assert (
|
||||
len(networks) <= num_sub_prompts
|
||||
), "Number of networks must be less than or equal to number of sub prompts."
|
||||
else:
|
||||
num_sub_prompts = None
|
||||
|
||||
b1 = BatchData(
|
||||
False,
|
||||
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||
BatchDataExt(
|
||||
width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None
|
||||
width,
|
||||
height,
|
||||
steps,
|
||||
scale,
|
||||
negative_scale,
|
||||
strength,
|
||||
tuple(network_muls) if network_muls else None,
|
||||
num_sub_prompts,
|
||||
),
|
||||
)
|
||||
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
||||
@@ -3197,6 +3247,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
nargs="*",
|
||||
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||
# )
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
Reference in New Issue
Block a user