pre calc LoRA in generating

This commit is contained in:
Kohya S
2023-05-07 09:57:54 +09:00
parent 165fc43655
commit fdbdb4748a
2 changed files with 122 additions and 32 deletions

View File

@@ -66,6 +66,39 @@ class LoRAModule(torch.nn.Module):
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
class LoRAInfModule(LoRAModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
self.org_module_ref = [org_module] # 後から参照できるように
self.enabled = True
# check regional or not by lora_name
self.text_encoder = False
if lora_name.startswith("lora_te_"):
self.regional = False
self.use_sub_prompt = True
self.text_encoder = True
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
self.regional = False
self.use_sub_prompt = True
elif "time_emb" in lora_name:
self.regional = False
self.use_sub_prompt = False
else:
self.regional = True
self.use_sub_prompt = False
self.network: LoRANetwork = None
def set_network(self, network):
self.network = network
# freezeしてマージする
def merge_to(self, sd, dtype, device):
# get up/down weight
up_weight = sd["lora_up.weight"].to(torch.float).to(device)
@@ -97,44 +130,45 @@ class LoRAModule(torch.nn.Module):
org_sd["weight"] = weight.to(dtype)
self.org_module.load_state_dict(org_sd)
# 復元できるマージのため、このモジュールのweightを返す
def get_weight(self, multiplier=None):
if multiplier is None:
multiplier = self.multiplier
# get up/down weight from module
up_weight = self.lora_up.weight.to(torch.float)
down_weight = self.lora_down.weight.to(torch.float)
# pre-calculated weight
if len(down_weight.size()) == 2:
# linear
weight = self.multiplier * (up_weight @ down_weight) * self.scale
elif down_weight.size()[2:4] == (1, 1):
# conv2d 1x1
weight = (
self.multiplier
* (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
* self.scale
)
else:
# conv2d 3x3
conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
weight = self.multiplier * conved * self.scale
return weight
def set_region(self, region):
self.region = region
self.region_mask = None
def forward(self, x):
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
class LoRAInfModule(LoRAModule):
def __init__(self, lora_name, org_module: torch.nn.Module, multiplier=1.0, lora_dim=4, alpha=1):
super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)
# check regional or not by lora_name
self.text_encoder = False
if lora_name.startswith("lora_te_"):
self.regional = False
self.use_sub_prompt = True
self.text_encoder = True
elif "attn2_to_k" in lora_name or "attn2_to_v" in lora_name:
self.regional = False
self.use_sub_prompt = True
elif "time_emb" in lora_name:
self.regional = False
self.use_sub_prompt = False
else:
self.regional = True
self.use_sub_prompt = False
self.network: LoRANetwork = None
def set_network(self, network):
self.network = network
def default_forward(self, x):
# print("default_forward", self.lora_name, x.size())
return self.org_forward(x) + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
def forward(self, x):
if not self.enabled:
return self.org_forward(x)
if self.network is None or self.network.sub_prompt_index is None:
return self.default_forward(x)
if not self.regional and not self.use_sub_prompt:
@@ -769,6 +803,10 @@ class LoRANetwork(torch.nn.Module):
lora.apply_to()
self.add_module(lora.lora_name, lora)
# マージできるかどうかを返す
def is_mergeable(self):
return True
# TODO refactor to common function with apply_to
def merge_to(self, text_encoder, unet, weights_sd, dtype, device):
apply_text_encoder = apply_unet = False
@@ -955,3 +993,40 @@ class LoRANetwork(torch.nn.Module):
w = (w + 1) // 2
self.mask_dic = mask_dic
def backup_weights(self):
# 重みのバックアップを行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not hasattr(org_module, "_lora_org_weight"):
sd = org_module.state_dict()
org_module._lora_org_weight = sd["weight"].detach().clone()
org_module._lora_restored = True
def restore_weights(self):
# 重みのリストアを行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
if not org_module._lora_restored:
sd = org_module.state_dict()
sd["weight"] = org_module._lora_org_weight
org_module.load_state_dict(sd)
org_module._lora_restored = True
def pre_calculation(self):
# 事前計算を行う
loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
for lora in loras:
org_module = lora.org_module_ref[0]
sd = org_module.state_dict()
org_weight = sd["weight"]
lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
sd["weight"] = org_weight + lora_weight
assert sd["weight"].shape == org_weight.shape
org_module.load_state_dict(sd)
org_module._lora_restored = False
lora.enabled = False