revert lora+ for lora_fa

Kohya S
2024-05-12 17:00:51 +09:00
parent c6a437054a
commit 3c8193f642


@@ -15,10 +15,8 @@ import numpy as np
 import torch
 import re
 from library.utils import setup_logging
 
 setup_logging()
 import logging
 
 logger = logging.getLogger(__name__)
 
 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -506,15 +504,6 @@ def create_network(
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
 
-    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
-    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
-    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
-    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
-    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
-    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
-
-    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
-        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
-
     return network
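Note: the block removed above read optional LoRA+ ratios out of the free-form network_args kwargs and handed them to the network. As a standalone sketch of that parsing pattern (the kwargs dict here is hypothetical; --network_args values arrive as strings and need an explicit float cast):

    # Hedged sketch of the removed kwargs parsing; "16" mimics a
    # --network_args "loraplus_lr_ratio=16" value arriving as a string.
    kwargs = {"loraplus_lr_ratio": "16"}
    raw = kwargs.get("loraplus_lr_ratio", None)
    loraplus_lr_ratio = float(raw) if raw is not None else None
    assert loraplus_lr_ratio == 16.0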
@@ -540,9 +529,7 @@ def get_block_dims_and_alphas(
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        logger.warning(
-            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
-        )
+        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
         block_dims = [network_dim] * num_total_blocks
 
     if block_alphas is not None:
@@ -816,17 +803,11 @@ class LoRANetwork(torch.nn.Module):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout
 
-        self.loraplus_lr_ratio = None
-        self.loraplus_unet_lr_ratio = None
-        self.loraplus_text_encoder_lr_ratio = None
-
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
             logger.info(f"create LoRA network from block_dims")
-            logger.info(
-                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
-            )
+            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             logger.info(f"block_dims: {block_dims}")
             logger.info(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
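Note: the three attributes deleted from __init__ were the per-network LoRA+ state: a general ratio plus optional U-Net and text-encoder overrides. Elsewhere in the removed code the override wins via `a or b`, a pattern worth noting because a falsy override (0 or 0.0) silently falls back to the general ratio:

    # Precedence pattern used by the removed code (illustrative values).
    loraplus_lr_ratio = 4.0        # general ratio
    loraplus_unet_lr_ratio = None  # no U-Net-specific override
    effective = loraplus_unet_lr_ratio or loraplus_lr_ratio  # -> 4.0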
@@ -834,13 +815,9 @@ class LoRANetwork(torch.nn.Module):
                 logger.info(f"conv_block_alphas: {conv_block_alphas}")
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            logger.info(
-                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
-            )
+            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
             if self.conv_lora_dim is not None:
-                logger.info(
-                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
-                )
+                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
 
         # create module instances
         def create_modules(
@@ -962,11 +939,6 @@ class LoRANetwork(torch.nn.Module):
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)
 
-    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
-        self.loraplus_lr_ratio = loraplus_lr_ratio
-        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
-        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
-
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
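Note: set_loraplus_lr_ratio was the public entry point for LoRA+, which trains the lora_up ("B") matrices at a higher learning rate than the lora_down ("A") matrices. A minimal, self-contained sketch of the effect (module names are stand-ins, not this file's API):

    import torch

    # LoRA+ in miniature: "up" weights get base_lr * ratio, the rest base_lr.
    base_lr, ratio = 1e-4, 16.0
    lora_down = torch.nn.Linear(8, 4, bias=False)  # stands in for lora_down
    lora_up = torch.nn.Linear(4, 8, bias=False)    # stands in for lora_up
    optimizer = torch.optim.AdamW([
        {"params": lora_down.parameters(), "lr": base_lr},
        {"params": lora_up.parameters(), "lr": base_lr * ratio},
    ])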
@@ -1065,42 +1037,18 @@ class LoRANetwork(torch.nn.Module):
         self.requires_grad_(True)
         all_params = []
 
-        def assemble_params(loras, lr, ratio):
-            param_groups = {"lora": {}, "plus": {}}
-            for lora in loras:
-                for name, param in lora.named_parameters():
-                    if ratio is not None and "lora_up" in name:
-                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
-                    else:
-                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param
-
+        def enumerate_params(loras: List[LoRAModule]):
             params = []
-            for key in param_groups.keys():
-                param_data = {"params": param_groups[key].values()}
-
-                if len(param_data["params"]) == 0:
-                    continue
-
-                if lr is not None:
-                    if key == "plus":
-                        param_data["lr"] = lr * ratio
-                    else:
-                        param_data["lr"] = lr
-
-                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
-                    continue
-
-                params.append(param_data)
-
+            for lora in loras:
+                # params.extend(lora.parameters())
+                params.extend(lora.get_trainable_params())
             return params
 
         if self.text_encoder_loras:
-            params = assemble_params(
-                self.text_encoder_loras,
-                text_encoder_lr if text_encoder_lr is not None else default_lr,
-                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
-            )
-            all_params.extend(params)
+            param_data = {"params": enumerate_params(self.text_encoder_loras)}
+            if text_encoder_lr is not None:
+                param_data["lr"] = text_encoder_lr
+            all_params.append(param_data)
 
         if self.unet_loras:
             if self.block_lr:
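Note: the revert swaps assemble_params, which split each component into a base-rate "lora" group and an up-weighted "plus" group, for the simpler enumerate_params: one flat parameter list per component, with "lr" set only when a rate was given. An illustrative sketch of the post-revert structure handed to the optimizer (names and sizes are hypothetical):

    import torch

    # Post-revert shape: one optimizer group per component.
    te_loras = [torch.nn.Linear(4, 4) for _ in range(2)]  # stands in for text_encoder_loras
    flat = [p for m in te_loras for p in m.parameters()]  # what enumerate_params returns
    all_params = [{"params": flat, "lr": 5e-5}]
    optimizer = torch.optim.AdamW(all_params)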
@@ -1114,20 +1062,21 @@ class LoRANetwork(torch.nn.Module):
             # blockごとにパラメータを設定する (set the parameters per block)
             for idx, block_loras in block_idx_to_lora.items():
-                params = assemble_params(
-                    block_loras,
-                    (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                    self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
-                )
-                all_params.extend(params)
+                param_data = {"params": enumerate_params(block_loras)}
+
+                if unet_lr is not None:
+                    param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
+                elif default_lr is not None:
+                    param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
+                if ("lr" in param_data) and (param_data["lr"] == 0):
+                    continue
+                all_params.append(param_data)
 
         else:
-            params = assemble_params(
-                self.unet_loras,
-                unet_lr if unet_lr is not None else default_lr,
-                self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
-            )
-            all_params.extend(params)
+            param_data = {"params": enumerate_params(self.unet_loras)}
+            if unet_lr is not None:
+                param_data["lr"] = unet_lr
+            all_params.append(param_data)
 
         return all_params
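Note: in the restored block-lr path each block becomes its own group, with the rate scaled by get_lr_weight, and groups whose scaled rate is exactly 0 are skipped so the optimizer never receives a frozen block. The skip logic in isolation (block weights are hypothetical):

    # Zero-lr blocks are dropped, mirroring the `continue` above.
    unet_lr = 1e-4
    block_weights = [1.0, 0.0, 0.5]  # hypothetical get_lr_weight outputs
    groups = [{"lr": unet_lr * w} for w in block_weights if unet_lr * w != 0]
    assert [g["lr"] for g in groups] == [1e-4, 5e-5]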
@@ -1144,9 +1093,6 @@ class LoRANetwork(torch.nn.Module):
     def get_trainable_params(self):
         return self.parameters()
 
-    def get_trainable_named_params(self):
-        return self.named_parameters()
-
     def save_weights(self, file, dtype, metadata):
         if metadata is not None and len(metadata) == 0:
             metadata = None
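Note: get_trainable_named_params existed only for LoRA+: parameter names are the sole way to tell an up matrix from a down one, so the caller needed named_parameters() rather than parameters(). A one-liner shows the filtering it enabled (module layout is illustrative, not this file's):

    import torch

    m = torch.nn.ModuleDict({"lora_up": torch.nn.Linear(2, 4), "lora_down": torch.nn.Linear(4, 2)})
    plus = [p for n, p in m.named_parameters() if "lora_up" in n]  # the "B" side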