mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add attension couple+reginal LoRA
This commit is contained in:
@@ -92,6 +92,7 @@ from PIL.PngImagePlugin import PngInfo
|
|||||||
|
|
||||||
import library.model_util as model_util
|
import library.model_util as model_util
|
||||||
import library.train_util as train_util
|
import library.train_util as train_util
|
||||||
|
from networks.lora import LoRANetwork
|
||||||
import tools.original_control_net as original_control_net
|
import tools.original_control_net as original_control_net
|
||||||
from tools.original_control_net import ControlNetInfo
|
from tools.original_control_net import ControlNetInfo
|
||||||
|
|
||||||
@@ -634,6 +635,7 @@ class PipelineLike:
|
|||||||
img2img_noise=None,
|
img2img_noise=None,
|
||||||
clip_prompts=None,
|
clip_prompts=None,
|
||||||
clip_guide_images=None,
|
clip_guide_images=None,
|
||||||
|
networks: Optional[List[LoRANetwork]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
@@ -717,6 +719,7 @@ class PipelineLike:
|
|||||||
batch_size = len(prompt)
|
batch_size = len(prompt)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
reginonal_network = " AND " in prompt[0]
|
||||||
|
|
||||||
vae_batch_size = (
|
vae_batch_size = (
|
||||||
batch_size
|
batch_size
|
||||||
@@ -1010,6 +1013,11 @@ class PipelineLike:
|
|||||||
|
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
if self.control_nets:
|
if self.control_nets:
|
||||||
|
if reginonal_network:
|
||||||
|
num_sub_and_neg_prompts = len(text_embeddings) // batch_size
|
||||||
|
text_emb_last = text_embeddings[num_sub_and_neg_prompts - 2::num_sub_and_neg_prompts] # last subprompt
|
||||||
|
else:
|
||||||
|
text_emb_last = text_embeddings
|
||||||
noise_pred = original_control_net.call_unet_and_control_net(
|
noise_pred = original_control_net.call_unet_and_control_net(
|
||||||
i,
|
i,
|
||||||
num_latent_input,
|
num_latent_input,
|
||||||
@@ -1019,7 +1027,7 @@ class PipelineLike:
|
|||||||
i / len(timesteps),
|
i / len(timesteps),
|
||||||
latent_model_input,
|
latent_model_input,
|
||||||
t,
|
t,
|
||||||
text_embeddings,
|
text_emb_last,
|
||||||
).sample
|
).sample
|
||||||
else:
|
else:
|
||||||
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
|
||||||
@@ -1890,6 +1898,12 @@ def get_weighted_text_embeddings(
|
|||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
prompt = [prompt]
|
prompt = [prompt]
|
||||||
|
|
||||||
|
# split the prompts with "AND". each prompt must have the same number of splits
|
||||||
|
new_prompts = []
|
||||||
|
for p in prompt:
|
||||||
|
new_prompts.extend(p.split(" AND "))
|
||||||
|
prompt = new_prompts
|
||||||
|
|
||||||
if not skip_parsing:
|
if not skip_parsing:
|
||||||
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
|
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2, layer=layer)
|
||||||
if uncond_prompt is not None:
|
if uncond_prompt is not None:
|
||||||
@@ -2059,6 +2073,7 @@ class BatchDataExt(NamedTuple):
|
|||||||
negative_scale: float
|
negative_scale: float
|
||||||
strength: float
|
strength: float
|
||||||
network_muls: Tuple[float]
|
network_muls: Tuple[float]
|
||||||
|
num_sub_prompts: int
|
||||||
|
|
||||||
|
|
||||||
class BatchData(NamedTuple):
|
class BatchData(NamedTuple):
|
||||||
@@ -2276,14 +2291,18 @@ def main(args):
|
|||||||
print(f"metadata for: {network_weight}: {metadata}")
|
print(f"metadata for: {network_weight}: {metadata}")
|
||||||
|
|
||||||
network, weights_sd = imported_module.create_network_from_weights(
|
network, weights_sd = imported_module.create_network_from_weights(
|
||||||
network_mul, network_weight, vae, text_encoder, unet, **net_kwargs
|
network_mul, network_weight, vae, text_encoder, unet, for_inference=True, **net_kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("No weight. Weight is required.")
|
raise ValueError("No weight. Weight is required.")
|
||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not args.network_merge:
|
mergiable = hasattr(network, "merge_to")
|
||||||
|
if args.network_merge and not mergiable:
|
||||||
|
print("network is not mergiable. ignore merge option.")
|
||||||
|
|
||||||
|
if not args.network_merge or not mergiable:
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
print(f"weights are loaded: {info}")
|
||||||
@@ -2349,12 +2368,12 @@ def main(args):
|
|||||||
if args.diffusers_xformers:
|
if args.diffusers_xformers:
|
||||||
pipe.enable_xformers_memory_efficient_attention()
|
pipe.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
|
# Extended Textual Inversion および Textual Inversionを処理する
|
||||||
if args.XTI_embeddings:
|
if args.XTI_embeddings:
|
||||||
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
|
||||||
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
diffusers.models.unet_2d_blocks.CrossAttnDownBlock2D.forward = downblock_forward_XTI
|
||||||
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
diffusers.models.unet_2d_blocks.CrossAttnUpBlock2D.forward = upblock_forward_XTI
|
||||||
|
|
||||||
# Textual Inversionを処理する
|
|
||||||
if args.textual_inversion_embeddings:
|
if args.textual_inversion_embeddings:
|
||||||
token_ids_embeds = []
|
token_ids_embeds = []
|
||||||
for embeds_file in args.textual_inversion_embeddings:
|
for embeds_file in args.textual_inversion_embeddings:
|
||||||
@@ -2558,16 +2577,22 @@ def main(args):
|
|||||||
print(f"resize img2img mask images to {args.W}*{args.H}")
|
print(f"resize img2img mask images to {args.W}*{args.H}")
|
||||||
mask_images = resize_images(mask_images, (args.W, args.H))
|
mask_images = resize_images(mask_images, (args.W, args.H))
|
||||||
|
|
||||||
|
regional_network = False
|
||||||
if networks and mask_images:
|
if networks and mask_images:
|
||||||
# mask を領域情報として流用する、現在は1枚だけ対応
|
# mask を領域情報として流用する、現在は一回のコマンド呼び出しで1枚だけ対応
|
||||||
# TODO 複数のnetwork classの混在時の考慮
|
regional_network = True
|
||||||
print("use mask as region")
|
print("use mask as region")
|
||||||
# import cv2
|
|
||||||
# for i in range(3):
|
size = None
|
||||||
# cv2.imshow("msk", np.array(mask_images[0])[:,:,i])
|
for i, network in enumerate(networks):
|
||||||
# cv2.waitKey()
|
if i < 3:
|
||||||
# cv2.destroyAllWindows()
|
np_mask = np.array(mask_images[0])
|
||||||
networks[0].__class__.set_regions(networks, np.array(mask_images[0]))
|
np_mask = np_mask[:, :, i]
|
||||||
|
size = np_mask.shape
|
||||||
|
else:
|
||||||
|
np_mask = np.full(size, 255, dtype=np.uint8)
|
||||||
|
mask = torch.from_numpy(np_mask.astype(np.float32) / 255.0)
|
||||||
|
network.set_region(i, i == len(networks) - 1, mask)
|
||||||
mask_images = None
|
mask_images = None
|
||||||
|
|
||||||
prev_image = None # for VGG16 guided
|
prev_image = None # for VGG16 guided
|
||||||
@@ -2623,7 +2648,14 @@ def main(args):
|
|||||||
height_1st = height_1st - height_1st % 32
|
height_1st = height_1st - height_1st % 32
|
||||||
|
|
||||||
ext_1st = BatchDataExt(
|
ext_1st = BatchDataExt(
|
||||||
width_1st, height_1st, args.highres_fix_steps, ext.scale, ext.negative_scale, ext.strength, ext.network_muls
|
width_1st,
|
||||||
|
height_1st,
|
||||||
|
args.highres_fix_steps,
|
||||||
|
ext.scale,
|
||||||
|
ext.negative_scale,
|
||||||
|
ext.strength,
|
||||||
|
ext.network_muls,
|
||||||
|
ext.num_sub_prompts,
|
||||||
)
|
)
|
||||||
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
batch_1st.append(BatchData(args.highres_fix_latents_upscaling, base, ext_1st))
|
||||||
images_1st = process_batch(batch_1st, True, True)
|
images_1st = process_batch(batch_1st, True, True)
|
||||||
@@ -2651,7 +2683,7 @@ def main(args):
|
|||||||
(
|
(
|
||||||
return_latents,
|
return_latents,
|
||||||
(step_first, _, _, _, init_image, mask_image, _, guide_image),
|
(step_first, _, _, _, init_image, mask_image, _, guide_image),
|
||||||
(width, height, steps, scale, negative_scale, strength, network_muls),
|
(width, height, steps, scale, negative_scale, strength, network_muls, num_sub_prompts),
|
||||||
) = batch[0]
|
) = batch[0]
|
||||||
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
noise_shape = (LATENT_CHANNELS, height // DOWNSAMPLING_FACTOR, width // DOWNSAMPLING_FACTOR)
|
||||||
|
|
||||||
@@ -2743,8 +2775,11 @@ def main(args):
|
|||||||
|
|
||||||
# generate
|
# generate
|
||||||
if networks:
|
if networks:
|
||||||
|
shared = {}
|
||||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||||
n.set_multiplier(m)
|
n.set_multiplier(m)
|
||||||
|
if regional_network:
|
||||||
|
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||||
|
|
||||||
images = pipe(
|
images = pipe(
|
||||||
prompts,
|
prompts,
|
||||||
@@ -2969,11 +3004,26 @@ def main(args):
|
|||||||
print("Use previous image as guide image.")
|
print("Use previous image as guide image.")
|
||||||
guide_image = prev_image
|
guide_image = prev_image
|
||||||
|
|
||||||
|
if regional_network:
|
||||||
|
num_sub_prompts = len(prompt.split(" AND "))
|
||||||
|
assert (
|
||||||
|
len(networks) <= num_sub_prompts
|
||||||
|
), "Number of networks must be less than or equal to number of sub prompts."
|
||||||
|
else:
|
||||||
|
num_sub_prompts = None
|
||||||
|
|
||||||
b1 = BatchData(
|
b1 = BatchData(
|
||||||
False,
|
False,
|
||||||
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
BatchDataBase(global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image),
|
||||||
BatchDataExt(
|
BatchDataExt(
|
||||||
width, height, steps, scale, negative_scale, strength, tuple(network_muls) if network_muls else None
|
width,
|
||||||
|
height,
|
||||||
|
steps,
|
||||||
|
scale,
|
||||||
|
negative_scale,
|
||||||
|
strength,
|
||||||
|
tuple(network_muls) if network_muls else None,
|
||||||
|
num_sub_prompts,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
if len(batch_data) > 0 and batch_data[-1].ext != b1.ext: # バッチ分割必要?
|
||||||
@@ -3197,6 +3247,9 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
nargs="*",
|
nargs="*",
|
||||||
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
help="ControlNet guidance ratio for steps / ControlNetでガイドするステップ比率",
|
||||||
)
|
)
|
||||||
|
# parser.add_argument(
|
||||||
|
# "--control_net_image_path", type=str, default=None, nargs="*", help="image for ControlNet guidance / ControlNetでガイドに使う画像"
|
||||||
|
# )
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
261
networks/lora.py
261
networks/lora.py
@@ -10,7 +10,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from library import train_util
|
|
||||||
|
|
||||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
|
|
||||||
@@ -61,8 +60,6 @@ class LoRAModule(torch.nn.Module):
|
|||||||
|
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
self.org_module = org_module # remove in applying
|
self.org_module = org_module # remove in applying
|
||||||
self.region = None
|
|
||||||
self.region_mask = None
|
|
||||||
|
|
||||||
def apply_to(self):
|
def apply_to(self):
|
||||||
self.org_forward = self.org_module.forward
|
self.org_forward = self.org_module.forward
|
||||||
@@ -105,39 +102,187 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.region_mask = None
|
self.region_mask = None
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.region is None:
|
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
# regional LoRA FIXME same as additional-network extension
|
|
||||||
if x.size()[1] % 77 == 0:
|
class LoRAInfModule(LoRAModule):
|
||||||
# print(f"LoRA for context: {self.lora_name}")
|
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
||||||
self.region = None
|
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
||||||
|
|
||||||
|
# check regional or not by lora_name
|
||||||
|
self.text_encoder = False
|
||||||
|
if lora_name.startswith("lora_te_"):
|
||||||
|
self.regional = False
|
||||||
|
self.use_sub_prompt = True
|
||||||
|
self.text_encoder = True
|
||||||
|
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
||||||
|
self.regional = False
|
||||||
|
self.use_sub_prompt = True
|
||||||
|
elif "time_emb" in lora_name:
|
||||||
|
self.regional = False
|
||||||
|
self.use_sub_prompt = False
|
||||||
|
else:
|
||||||
|
self.regional = True
|
||||||
|
self.use_sub_prompt = False
|
||||||
|
|
||||||
|
self.network: LoRANetwork = None
|
||||||
|
|
||||||
|
def set_network(self, network):
|
||||||
|
self.network = network
|
||||||
|
|
||||||
|
def default_forward(self, x):
|
||||||
|
# print("default_forward", self.lora_name, x.size())
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
|
||||||
# calculate region mask first time
|
def forward(self, x):
|
||||||
if self.region_mask is None:
|
if self.network is None or self.network.sub_prompt_index is None:
|
||||||
|
return self.default_forward(x)
|
||||||
|
if not self.regional and not self.use_sub_prompt:
|
||||||
|
return self.default_forward(x)
|
||||||
|
|
||||||
|
if self.regional:
|
||||||
|
return self.regional_forward(x)
|
||||||
|
else:
|
||||||
|
return self.sub_prompt_forward(x)
|
||||||
|
|
||||||
|
def get_mask_for_x(self, x):
|
||||||
|
# calculate size from shape of x
|
||||||
if len(x.size()) == 4:
|
if len(x.size()) == 4:
|
||||||
h, w = x.size()[2:4]
|
h, w = x.size()[2:4]
|
||||||
|
area = h * w
|
||||||
else:
|
else:
|
||||||
seq_len = x.size()[1]
|
area = x.size()[1]
|
||||||
ratio = math.sqrt((self.region.size()[0] * self.region.size()[1]) / seq_len)
|
|
||||||
h = int(self.region.size()[0] / ratio + 0.5)
|
|
||||||
w = seq_len // h
|
|
||||||
|
|
||||||
r = self.region.to(x.device)
|
mask = self.network.mask_dic[area]
|
||||||
if r.dtype == torch.bfloat16:
|
if mask is None:
|
||||||
r = r.to(torch.float)
|
raise ValueError(f"mask is None for resolution {area}")
|
||||||
r = r.unsqueeze(0).unsqueeze(1)
|
if len(x.size()) != 4:
|
||||||
# print(self.lora_name, self.region.size(), x.size(), r.size(), h, w)
|
mask = torch.reshape(mask, (1, -1, 1))
|
||||||
r = torch.nn.functional.interpolate(r, (h, w), mode="bilinear")
|
return mask
|
||||||
r = r.to(x.dtype)
|
|
||||||
|
|
||||||
if len(x.size()) == 3:
|
def regional_forward(self, x):
|
||||||
r = torch.reshape(r, (1, x.size()[1], -1))
|
if "attn2_to_out" in self.lora_name:
|
||||||
|
return self.to_out_forward(x)
|
||||||
|
|
||||||
self.region_mask = r
|
if self.network.mask_dic is None: # sub_prompt_index >= 3
|
||||||
|
return self.default_forward(x)
|
||||||
|
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale * self.region_mask
|
# apply mask for LoRA result
|
||||||
|
lx = self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
||||||
|
mask = self.get_mask_for_x(lx)
|
||||||
|
# print("regional", self.lora_name, self.network.sub_prompt_index, lx.size(), mask.size())
|
||||||
|
lx = lx * mask
|
||||||
|
|
||||||
|
x = self.org_forward(x)
|
||||||
|
x = x + lx
|
||||||
|
|
||||||
|
if "attn2_to_q" in self.lora_name and self.network.is_last_network:
|
||||||
|
x = self.postp_to_q(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def postp_to_q(self, x):
|
||||||
|
# repeat x to num_sub_prompts
|
||||||
|
has_real_uncond = x.size()[0] // self.network.batch_size == 3
|
||||||
|
qc = self.network.batch_size # uncond
|
||||||
|
qc += self.network.batch_size * self.network.num_sub_prompts # cond
|
||||||
|
if has_real_uncond:
|
||||||
|
qc += self.network.batch_size # real_uncond
|
||||||
|
|
||||||
|
query = torch.zeros((qc, x.size()[1], x.size()[2]), device=x.device, dtype=x.dtype)
|
||||||
|
query[: self.network.batch_size] = x[: self.network.batch_size]
|
||||||
|
|
||||||
|
for i in range(self.network.batch_size):
|
||||||
|
qi = self.network.batch_size + i * self.network.num_sub_prompts
|
||||||
|
query[qi : qi + self.network.num_sub_prompts] = x[self.network.batch_size + i]
|
||||||
|
|
||||||
|
if has_real_uncond:
|
||||||
|
query[-self.network.batch_size :] = x[-self.network.batch_size :]
|
||||||
|
|
||||||
|
# print("postp_to_q", self.lora_name, x.size(), query.size(), self.network.num_sub_prompts)
|
||||||
|
return query
|
||||||
|
|
||||||
|
def sub_prompt_forward(self, x):
|
||||||
|
if x.size()[0] == self.network.batch_size: # if uncond in text_encoder, do not apply LoRA
|
||||||
|
return self.org_forward(x)
|
||||||
|
|
||||||
|
emb_idx = self.network.sub_prompt_index
|
||||||
|
if not self.text_encoder:
|
||||||
|
emb_idx += self.network.batch_size
|
||||||
|
|
||||||
|
# apply sub prompt of X
|
||||||
|
lx = x[emb_idx :: self.network.num_sub_prompts]
|
||||||
|
lx = self.lora_up(self.lora_down(lx)) * self.multiplier * self.scale
|
||||||
|
|
||||||
|
# print("sub_prompt_forward", self.lora_name, x.size(), lx.size(), emb_idx)
|
||||||
|
|
||||||
|
x = self.org_forward(x)
|
||||||
|
x[emb_idx :: self.network.num_sub_prompts] += lx
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def to_out_forward(self, x):
|
||||||
|
# print("to_out_forward", self.lora_name, x.size(), self.network.is_last_network)
|
||||||
|
|
||||||
|
if self.network.is_last_network:
|
||||||
|
masks = [None] * self.network.num_sub_prompts
|
||||||
|
self.network.shared[self.lora_name] = (None, masks)
|
||||||
|
else:
|
||||||
|
lx, masks = self.network.shared[self.lora_name]
|
||||||
|
|
||||||
|
# call own LoRA
|
||||||
|
x1 = x[self.network.batch_size + self.network.sub_prompt_index :: self.network.num_sub_prompts]
|
||||||
|
lx1 = self.lora_up(self.lora_down(x1)) * self.multiplier * self.scale
|
||||||
|
|
||||||
|
if self.network.is_last_network:
|
||||||
|
lx = torch.zeros(
|
||||||
|
(self.network.num_sub_prompts * self.network.batch_size, *lx1.size()[1:]), device=lx1.device, dtype=lx1.dtype
|
||||||
|
)
|
||||||
|
self.network.shared[self.lora_name] = (lx, masks)
|
||||||
|
|
||||||
|
# print("to_out_forward", lx.size(), lx1.size(), self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||||
|
lx[self.network.sub_prompt_index :: self.network.num_sub_prompts] += lx1
|
||||||
|
masks[self.network.sub_prompt_index] = self.get_mask_for_x(lx1)
|
||||||
|
|
||||||
|
# if not last network, return x and masks
|
||||||
|
x = self.org_forward(x)
|
||||||
|
if not self.network.is_last_network:
|
||||||
|
return x
|
||||||
|
|
||||||
|
lx, masks = self.network.shared.pop(self.lora_name)
|
||||||
|
|
||||||
|
# if last network, combine separated x with mask weighted sum
|
||||||
|
has_real_uncond = x.size()[0] // self.network.batch_size == self.network.num_sub_prompts + 2
|
||||||
|
|
||||||
|
out = torch.zeros((self.network.batch_size * (3 if has_real_uncond else 2), *x.size()[1:]), device=x.device, dtype=x.dtype)
|
||||||
|
out[: self.network.batch_size] = x[: self.network.batch_size] # uncond
|
||||||
|
if has_real_uncond:
|
||||||
|
out[-self.network.batch_size :] = x[-self.network.batch_size :] # real_uncond
|
||||||
|
|
||||||
|
# print("to_out_forward", self.lora_name, self.network.sub_prompt_index, self.network.num_sub_prompts)
|
||||||
|
# for i in range(len(masks)):
|
||||||
|
# if masks[i] is None:
|
||||||
|
# masks[i] = torch.zeros_like(masks[-1])
|
||||||
|
|
||||||
|
mask = torch.cat(masks)
|
||||||
|
mask_sum = torch.sum(mask, dim=0) + 1e-4
|
||||||
|
for i in range(self.network.batch_size):
|
||||||
|
# 1枚の画像ごとに処理する
|
||||||
|
lx1 = lx[i * self.network.num_sub_prompts : (i + 1) * self.network.num_sub_prompts]
|
||||||
|
lx1 = lx1 * mask
|
||||||
|
lx1 = torch.sum(lx1, dim=0)
|
||||||
|
|
||||||
|
xi = self.network.batch_size + i * self.network.num_sub_prompts
|
||||||
|
x1 = x[xi : xi + self.network.num_sub_prompts]
|
||||||
|
x1 = x1 * mask
|
||||||
|
x1 = torch.sum(x1, dim=0)
|
||||||
|
x1 = x1 / mask_sum
|
||||||
|
|
||||||
|
x1 = x1 + lx1
|
||||||
|
out[self.network.batch_size + i] = x1
|
||||||
|
|
||||||
|
# print("to_out_forward", x.size(), out.size(), has_real_uncond)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
def create_network(multiplier, network_dim, network_alpha, vae, text_encoder, unet, **kwargs):
|
||||||
@@ -421,7 +566,7 @@ def get_block_index(lora_name: str) -> int:
|
|||||||
|
|
||||||
|
|
||||||
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
# Create network from weights for inference, weights are not loaded here (because can be merged)
|
||||||
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, **kwargs):
|
def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs):
|
||||||
if weights_sd is None:
|
if weights_sd is None:
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import load_file, safe_open
|
from safetensors.torch import load_file, safe_open
|
||||||
@@ -450,7 +595,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
|
|||||||
if key not in modules_alpha:
|
if key not in modules_alpha:
|
||||||
modules_alpha = modules_dim[key]
|
modules_alpha = modules_dim[key]
|
||||||
|
|
||||||
network = LoRANetwork(text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha)
|
module_class = LoRAInfModule if for_inference else LoRAModule
|
||||||
|
|
||||||
|
network = LoRANetwork(
|
||||||
|
text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class
|
||||||
|
)
|
||||||
return network, weights_sd
|
return network, weights_sd
|
||||||
|
|
||||||
|
|
||||||
@@ -479,6 +628,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
conv_block_alphas=None,
|
conv_block_alphas=None,
|
||||||
modules_dim=None,
|
modules_dim=None,
|
||||||
modules_alpha=None,
|
modules_alpha=None,
|
||||||
|
module_class=LoRAModule,
|
||||||
varbose=False,
|
varbose=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -554,7 +704,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
skipped.append(lora_name)
|
skipped.append(lora_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora = LoRAModule(lora_name, child_module, self.multiplier, dim, alpha)
|
lora = module_class(lora_name, child_module, self.multiplier, dim, alpha)
|
||||||
loras.append(lora)
|
loras.append(lora)
|
||||||
return loras, skipped
|
return loras, skipped
|
||||||
|
|
||||||
@@ -750,6 +900,7 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
if os.path.splitext(file)[1] == ".safetensors":
|
if os.path.splitext(file)[1] == ".safetensors":
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
from library import train_util
|
||||||
|
|
||||||
# Precalculate model hashes to save time on indexing
|
# Precalculate model hashes to save time on indexing
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
@@ -762,17 +913,45 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
torch.save(state_dict, file)
|
torch.save(state_dict, file)
|
||||||
|
|
||||||
@staticmethod
|
# mask is a tensor with values from 0 to 1
|
||||||
def set_regions(networks, image):
|
def set_region(self, sub_prompt_index, is_last_network, mask):
|
||||||
image = image.astype(np.float32) / 255.0
|
if mask.max() == 0:
|
||||||
for i, network in enumerate(networks[:3]):
|
mask = torch.ones_like(mask)
|
||||||
# NOTE: consider averaging overwrapping area
|
|
||||||
region = image[:, :, i]
|
|
||||||
if region.max() == 0:
|
|
||||||
continue
|
|
||||||
region = torch.tensor(region)
|
|
||||||
network.set_region(region)
|
|
||||||
|
|
||||||
def set_region(self, region):
|
self.mask = mask
|
||||||
for lora in self.unet_loras:
|
self.sub_prompt_index = sub_prompt_index
|
||||||
lora.set_region(region)
|
self.is_last_network = is_last_network
|
||||||
|
|
||||||
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
|
lora.set_network(self)
|
||||||
|
|
||||||
|
def set_current_generation(self, batch_size, num_sub_prompts, width, height, shared):
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_sub_prompts = num_sub_prompts
|
||||||
|
self.current_size = (height, width)
|
||||||
|
self.shared = shared
|
||||||
|
|
||||||
|
# create masks
|
||||||
|
mask = self.mask
|
||||||
|
mask_dic = {}
|
||||||
|
mask = mask.unsqueeze(0).unsqueeze(1) # b(1),c(1),h,w
|
||||||
|
ref_weight = self.text_encoder_loras[0].lora_down.weight if self.text_encoder_loras else self.unet_loras[0].lora_down.weight
|
||||||
|
dtype = ref_weight.dtype
|
||||||
|
device = ref_weight.device
|
||||||
|
|
||||||
|
def resize_add(mh, mw):
|
||||||
|
# print(mh, mw, mh * mw)
|
||||||
|
m = torch.nn.functional.interpolate(mask, (mh, mw), mode="bilinear") # doesn't work in bf16
|
||||||
|
m = m.to(device, dtype=dtype)
|
||||||
|
mask_dic[mh * mw] = m
|
||||||
|
|
||||||
|
h = height // 8
|
||||||
|
w = width // 8
|
||||||
|
for _ in range(4):
|
||||||
|
resize_add(h, w)
|
||||||
|
if h % 2 == 1 or w % 2 == 1: # add extra shape if h/w is not divisible by 2
|
||||||
|
resize_add(h + h % 2, w + w % 2)
|
||||||
|
h = (h + 1) // 2
|
||||||
|
w = (w + 1) // 2
|
||||||
|
|
||||||
|
self.mask_dic = mask_dic
|
||||||
|
|||||||
Reference in New Issue
Block a user