mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
pre calc LoRA in generating
This commit is contained in:
@@ -2262,6 +2262,8 @@ def main(args):
|
|||||||
if args.network_module:
|
if args.network_module:
|
||||||
networks = []
|
networks = []
|
||||||
network_default_muls = []
|
network_default_muls = []
|
||||||
|
network_pre_calc=args.network_pre_calc
|
||||||
|
|
||||||
for i, network_module in enumerate(args.network_module):
|
for i, network_module in enumerate(args.network_module):
|
||||||
print("import network module:", network_module)
|
print("import network module:", network_module)
|
||||||
imported_module = importlib.import_module(network_module)
|
imported_module = importlib.import_module(network_module)
|
||||||
@@ -2298,11 +2300,11 @@ def main(args):
|
|||||||
if network is None:
|
if network is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
mergiable = hasattr(network, "merge_to")
|
mergeable = network.is_mergeable()
|
||||||
if args.network_merge and not mergiable:
|
if args.network_merge and not mergeable:
|
||||||
print("network is not mergiable. ignore merge option.")
|
print("network is not mergiable. ignore merge option.")
|
||||||
|
|
||||||
if not args.network_merge or not mergiable:
|
if not args.network_merge or not mergeable:
|
||||||
network.apply_to(text_encoder, unet)
|
network.apply_to(text_encoder, unet)
|
||||||
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
info = network.load_state_dict(weights_sd, False) # network.load_weightsを使うようにするとよい
|
||||||
print(f"weights are loaded: {info}")
|
print(f"weights are loaded: {info}")
|
||||||
@@ -2311,6 +2313,10 @@ def main(args):
|
|||||||
network.to(memory_format=torch.channels_last)
|
network.to(memory_format=torch.channels_last)
|
||||||
network.to(dtype).to(device)
|
network.to(dtype).to(device)
|
||||||
|
|
||||||
|
if network_pre_calc:
|
||||||
|
print("backup original weights")
|
||||||
|
network.backup_weights()
|
||||||
|
|
||||||
networks.append(network)
|
networks.append(network)
|
||||||
else:
|
else:
|
||||||
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
network.merge_to(text_encoder, unet, weights_sd, dtype, device)
|
||||||
@@ -2815,11 +2821,19 @@ def main(args):
|
|||||||
|
|
||||||
# generate
|
# generate
|
||||||
if networks:
|
if networks:
|
||||||
|
# 追加ネットワークの処理
|
||||||
shared = {}
|
shared = {}
|
||||||
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
for n, m in zip(networks, network_muls if network_muls else network_default_muls):
|
||||||
n.set_multiplier(m)
|
n.set_multiplier(m)
|
||||||
if regional_network:
|
if regional_network:
|
||||||
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
n.set_current_generation(batch_size, num_sub_prompts, width, height, shared)
|
||||||
|
|
||||||
|
if not regional_network and network_pre_calc:
|
||||||
|
for n in networks:
|
||||||
|
n.restore_weights()
|
||||||
|
for n in networks:
|
||||||
|
n.pre_calculation()
|
||||||
|
print("pre-calculation... done")
|
||||||
|
|
||||||
images = pipe(
|
images = pipe(
|
||||||
prompts,
|
prompts,
|
||||||
@@ -3204,6 +3218,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
)
|
)
|
||||||
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
parser.add_argument("--network_show_meta", action="store_true", help="show metadata of network model / ネットワークモデルのメタデータを表示する")
|
||||||
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
parser.add_argument("--network_merge", action="store_true", help="merge network weights to original model / ネットワークの重みをマージする")
|
||||||
|
parser.add_argument("--network_pre_calc", action="store_true", help="pre-calculate network for generation / ネットワークのあらかじめ計算して生成する")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--textual_inversion_embeddings",
|
"--textual_inversion_embeddings",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
133
networks/lora.py
133
networks/lora.py
@@ -66,6 +66,39 @@ class LoRAModule(torch.nn.Module):
|
|||||||
self.org_module.forward = self.forward
|
self.org_module.forward = self.forward
|
||||||
del self.org_module
|
del self.org_module
|
||||||
|
|
||||||
|
def forward(self, x):
    """Return the original module's output plus the scaled LoRA contribution.

    The LoRA branch is ``lora_up(lora_down(x))`` multiplied by
    ``self.multiplier`` and ``self.scale``.
    """
    base_out = self.org_forward(x)
    lora_out = self.lora_up(self.lora_down(x))
    return base_out + lora_out * self.multiplier * self.scale
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAInfModule(LoRAModule):
|
||||||
|
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
    """Create an inference-time LoRA module and classify it from its name.

    All arguments are forwarded unchanged to the parent constructor. In
    addition this sets:

    * ``org_module_ref`` - one-element list holding ``org_module`` so it can
      still be reached later,
    * ``enabled`` - starts as True,
    * ``text_encoder`` / ``regional`` / ``use_sub_prompt`` - flags derived
      from ``lora_name`` (see the branch below),
    * ``network`` - starts as None; assigned through ``set_network``.
    """
    super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)

    # Kept so the original module can be referenced later.
    # NOTE(review): presumably wrapped in a list so torch does not register
    # it as a submodule - confirm.
    self.org_module_ref = [org_module]
    self.enabled = True

    # Decide (regional, use_sub_prompt) from the LoRA name.
    is_text_encoder = lora_name.startswith("lora_te_")
    if is_text_encoder:
        flags = (False, True)
    elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
        flags = (False, True)
    elif "time_emb" in lora_name:
        flags = (False, False)
    else:
        flags = (True, False)

    self.text_encoder = is_text_encoder
    self.regional, self.use_sub_prompt = flags

    self.network: LoRANetwork = None
|
||||||
|
|
||||||
|
def set_network(self, network):
    """Store ``network`` on this module as ``self.network``."""
    self.network = network
|
||||||
|
|
||||||
|
# freezeしてマージする
|
||||||
def merge_to(self, sd, dtype, device):
|
def merge_to(self, sd, dtype, device):
|
||||||
# get up/down weight
|
# get up/down weight
|
||||||
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
|
||||||
@@ -97,44 +130,45 @@ class LoRAModule(torch.nn.Module):
|
|||||||
org_sd["weight"] = weight.to(dtype)
|
org_sd["weight"] = weight.to(dtype)
|
||||||
self.org_module.load_state_dict(org_sd)
|
self.org_module.load_state_dict(org_sd)
|
||||||
|
|
||||||
|
# 復元できるマージのため、このモジュールのweightを返す
|
||||||
|
def get_weight(self, multiplier=None):
    """Return this module's LoRA delta weight, ready to be added to the original weight.

    Args:
        multiplier: strength applied to the delta. When None,
            ``self.multiplier`` is used.

    Returns:
        ``multiplier * (up . down) * self.scale`` computed in float32. The
        up/down product is a matrix product for 2-D (linear) weights, a
        matrix product on the squeezed kernels for 1x1 convolutions, and a
        convolution of the two kernels otherwise.
    """
    if multiplier is None:
        multiplier = self.multiplier

    # get up/down weight from module
    up_weight = self.lora_up.weight.to(torch.float)
    down_weight = self.lora_down.weight.to(torch.float)

    # pre-calculated weight
    if len(down_weight.size()) == 2:
        # linear
        delta = up_weight @ down_weight
    elif down_weight.size()[2:4] == (1, 1):
        # conv2d 1x1: drop the two trailing kernel dims, multiply, put them back
        delta = (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
    else:
        # conv2d 3x3: compose the down and up kernels with a convolution
        delta = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)

    # Use the resolved multiplier so an explicit argument is honoured
    # (previously self.multiplier was always used and the argument ignored).
    return multiplier * delta * self.scale
|
||||||
|
|
||||||
def set_region(self, region):
    """Store ``region`` on the module and reset ``region_mask`` to None."""
    self.region, self.region_mask = region, None
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
|
|
||||||
|
|
||||||
|
|
||||||
class LoRAInfModule(LoRAModule):
|
|
||||||
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
|
|
||||||
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
|
|
||||||
|
|
||||||
# check regional or not by lora_name
|
|
||||||
self.text_encoder = False
|
|
||||||
if lora_name.startswith("lora_te_"):
|
|
||||||
self.regional = False
|
|
||||||
self.use_sub_prompt = True
|
|
||||||
self.text_encoder = True
|
|
||||||
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
|
|
||||||
self.regional = False
|
|
||||||
self.use_sub_prompt = True
|
|
||||||
elif "time_emb" in lora_name:
|
|
||||||
self.regional = False
|
|
||||||
self.use_sub_prompt = False
|
|
||||||
else:
|
|
||||||
self.regional = True
|
|
||||||
self.use_sub_prompt = False
|
|
||||||
|
|
||||||
self.network: LoRANetwork = None
|
|
||||||
|
|
||||||
def set_network(self, network):
|
|
||||||
self.network = network
|
|
||||||
|
|
||||||
def default_forward(self, x):
    """Plain LoRA forward: original output plus the scaled ``lora_up(lora_down(x))`` term."""
    original = self.org_forward(x)
    delta = self.lora_up(self.lora_down(x))
    return original + delta * self.multiplier * self.scale
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
if not self.enabled:
|
||||||
|
return self.org_forward(x)
|
||||||
|
|
||||||
if self.network is None or self.network.sub_prompt_index is None:
|
if self.network is None or self.network.sub_prompt_index is None:
|
||||||
return self.default_forward(x)
|
return self.default_forward(x)
|
||||||
if not self.regional and not self.use_sub_prompt:
|
if not self.regional and not self.use_sub_prompt:
|
||||||
@@ -769,6 +803,10 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
lora.apply_to()
|
lora.apply_to()
|
||||||
self.add_module(lora.lora_name, lora)
|
self.add_module(lora.lora_name, lora)
|
||||||
|
|
||||||
|
# マージできるかどうかを返す
|
||||||
|
def is_mergeable(self):
    """Report whether this network can be merged; this implementation always answers True."""
    return True
|
||||||
|
|
||||||
# TODO refactor to common function with apply_to
|
# TODO refactor to common function with apply_to
|
||||||
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
|
||||||
apply_text_encoder = apply_unet = False
|
apply_text_encoder = apply_unet = False
|
||||||
@@ -955,3 +993,40 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
w = (w + 1) // 2
|
w = (w + 1) // 2
|
||||||
|
|
||||||
self.mask_dic = mask_dic
|
self.mask_dic = mask_dic
|
||||||
|
|
||||||
|
def backup_weights(self):
    """Snapshot the original ``weight`` of every module wrapped by this network's LoRAs.

    The copy is stored on the original module as ``_lora_org_weight`` and
    ``_lora_restored`` is set to True. Modules that already carry a
    ``_lora_org_weight`` attribute are left untouched, so the first snapshot
    is the one that is kept.
    """
    for lora in self.text_encoder_loras + self.unet_loras:
        target = lora.org_module_ref[0]
        if hasattr(target, "_lora_org_weight"):
            continue  # a backup already exists for this module
        target._lora_org_weight = target.state_dict()["weight"].detach().clone()
        target._lora_restored = True
|
||||||
|
|
||||||
|
def restore_weights(self):
    """Put the backed-up ``weight`` back into every wrapped module that is not already restored.

    A module is skipped when its ``_lora_restored`` flag is true. Otherwise
    its state dict is reloaded with ``_lora_org_weight`` as ``weight`` and the
    flag is set to True. Expects ``backup_weights`` to have set both
    attributes beforehand.
    """
    for lora in self.text_encoder_loras + self.unet_loras:
        target = lora.org_module_ref[0]
        if target._lora_restored:
            continue
        state = target.state_dict()
        state["weight"] = target._lora_org_weight
        target.load_state_dict(state)
        target._lora_restored = True
|
||||||
|
|
||||||
|
def pre_calculation(self):
    """Fold every LoRA's weight into its original module and switch the LoRA off.

    For each LoRA, ``lora.get_weight()`` is cast to the original weight's
    device and dtype, added to that weight, and loaded back into the original
    module. The module is then flagged ``_lora_restored = False`` and the LoRA
    gets ``enabled = False``.
    """
    for lora in self.text_encoder_loras + self.unet_loras:
        target = lora.org_module_ref[0]
        state = target.state_dict()

        base_weight = state["weight"]
        delta = lora.get_weight().to(base_weight.device, dtype=base_weight.dtype)
        state["weight"] = base_weight + delta
        # Guard against the addition broadcasting to a different shape.
        assert state["weight"].shape == base_weight.shape
        target.load_state_dict(state)

        target._lora_restored = False
        lora.enabled = False
|
||||||
|
|||||||
Reference in New Issue
Block a user