mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 14:45:19 +00:00
revert lora+ for lora_fa
This commit is contained in:
@@ -15,10 +15,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import re
|
import re
|
||||||
from library.utils import setup_logging
|
from library.utils import setup_logging
|
||||||
|
|
||||||
setup_logging()
|
setup_logging()
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
|
||||||
@@ -506,15 +504,6 @@ def create_network(
|
|||||||
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
|
||||||
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
|
||||||
|
|
||||||
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
|
|
||||||
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
|
|
||||||
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
|
|
||||||
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
|
|
||||||
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
|
|
||||||
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
|
|
||||||
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
|
|
||||||
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
|
|
||||||
|
|
||||||
return network
|
return network
|
||||||
|
|
||||||
|
|
||||||
@@ -540,9 +529,7 @@ def get_block_dims_and_alphas(
|
|||||||
len(block_dims) == num_total_blocks
|
len(block_dims) == num_total_blocks
|
||||||
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
|
||||||
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
|
|
||||||
)
|
|
||||||
block_dims = [network_dim] * num_total_blocks
|
block_dims = [network_dim] * num_total_blocks
|
||||||
|
|
||||||
if block_alphas is not None:
|
if block_alphas is not None:
|
||||||
@@ -816,17 +803,11 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.rank_dropout = rank_dropout
|
self.rank_dropout = rank_dropout
|
||||||
self.module_dropout = module_dropout
|
self.module_dropout = module_dropout
|
||||||
|
|
||||||
self.loraplus_lr_ratio = None
|
|
||||||
self.loraplus_unet_lr_ratio = None
|
|
||||||
self.loraplus_text_encoder_lr_ratio = None
|
|
||||||
|
|
||||||
if modules_dim is not None:
|
if modules_dim is not None:
|
||||||
logger.info(f"create LoRA network from weights")
|
logger.info(f"create LoRA network from weights")
|
||||||
elif block_dims is not None:
|
elif block_dims is not None:
|
||||||
logger.info(f"create LoRA network from block_dims")
|
logger.info(f"create LoRA network from block_dims")
|
||||||
logger.info(
|
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
|
||||||
)
|
|
||||||
logger.info(f"block_dims: {block_dims}")
|
logger.info(f"block_dims: {block_dims}")
|
||||||
logger.info(f"block_alphas: {block_alphas}")
|
logger.info(f"block_alphas: {block_alphas}")
|
||||||
if conv_block_dims is not None:
|
if conv_block_dims is not None:
|
||||||
@@ -834,13 +815,9 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
logger.info(f"conv_block_alphas: {conv_block_alphas}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
|
||||||
logger.info(
|
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
|
||||||
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
|
|
||||||
)
|
|
||||||
if self.conv_lora_dim is not None:
|
if self.conv_lora_dim is not None:
|
||||||
logger.info(
|
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
|
||||||
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# create module instances
|
# create module instances
|
||||||
def create_modules(
|
def create_modules(
|
||||||
@@ -962,11 +939,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
|
||||||
names.add(lora.lora_name)
|
names.add(lora.lora_name)
|
||||||
|
|
||||||
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
|
|
||||||
self.loraplus_lr_ratio = loraplus_lr_ratio
|
|
||||||
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
|
|
||||||
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
|
|
||||||
|
|
||||||
def set_multiplier(self, multiplier):
|
def set_multiplier(self, multiplier):
|
||||||
self.multiplier = multiplier
|
self.multiplier = multiplier
|
||||||
for lora in self.text_encoder_loras + self.unet_loras:
|
for lora in self.text_encoder_loras + self.unet_loras:
|
||||||
@@ -1065,42 +1037,18 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
self.requires_grad_(True)
|
self.requires_grad_(True)
|
||||||
all_params = []
|
all_params = []
|
||||||
|
|
||||||
def assemble_params(loras, lr, ratio):
|
def enumerate_params(loras: List[LoRAModule]):
|
||||||
param_groups = {"lora": {}, "plus": {}}
|
|
||||||
for lora in loras:
|
|
||||||
for name, param in lora.named_parameters():
|
|
||||||
if ratio is not None and "lora_up" in name:
|
|
||||||
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
|
|
||||||
else:
|
|
||||||
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
|
|
||||||
|
|
||||||
params = []
|
params = []
|
||||||
for key in param_groups.keys():
|
for lora in loras:
|
||||||
param_data = {"params": param_groups[key].values()}
|
# params.extend(lora.parameters())
|
||||||
|
params.extend(lora.get_trainable_params())
|
||||||
if len(param_data["params"]) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if lr is not None:
|
|
||||||
if key == "plus":
|
|
||||||
param_data["lr"] = lr * ratio
|
|
||||||
else:
|
|
||||||
param_data["lr"] = lr
|
|
||||||
|
|
||||||
if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
params.append(param_data)
|
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
if self.text_encoder_loras:
|
if self.text_encoder_loras:
|
||||||
params = assemble_params(
|
param_data = {"params": enumerate_params(self.text_encoder_loras)}
|
||||||
self.text_encoder_loras,
|
if text_encoder_lr is not None:
|
||||||
text_encoder_lr if text_encoder_lr is not None else default_lr,
|
param_data["lr"] = text_encoder_lr
|
||||||
self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
|
all_params.append(param_data)
|
||||||
)
|
|
||||||
all_params.extend(params)
|
|
||||||
|
|
||||||
if self.unet_loras:
|
if self.unet_loras:
|
||||||
if self.block_lr:
|
if self.block_lr:
|
||||||
@@ -1114,20 +1062,21 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
|
|
||||||
# blockごとにパラメータを設定する
|
# blockごとにパラメータを設定する
|
||||||
for idx, block_loras in block_idx_to_lora.items():
|
for idx, block_loras in block_idx_to_lora.items():
|
||||||
params = assemble_params(
|
param_data = {"params": enumerate_params(block_loras)}
|
||||||
block_loras,
|
|
||||||
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
|
if unet_lr is not None:
|
||||||
self.loraplus_unet_lr_ratio or self.loraplus_ratio,
|
param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0])
|
||||||
)
|
elif default_lr is not None:
|
||||||
all_params.extend(params)
|
param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0])
|
||||||
|
if ("lr" in param_data) and (param_data["lr"] == 0):
|
||||||
|
continue
|
||||||
|
all_params.append(param_data)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
params = assemble_params(
|
param_data = {"params": enumerate_params(self.unet_loras)}
|
||||||
self.unet_loras,
|
if unet_lr is not None:
|
||||||
unet_lr if unet_lr is not None else default_lr,
|
param_data["lr"] = unet_lr
|
||||||
self.loraplus_unet_lr_ratio or self.loraplus_ratio,
|
all_params.append(param_data)
|
||||||
)
|
|
||||||
all_params.extend(params)
|
|
||||||
|
|
||||||
return all_params
|
return all_params
|
||||||
|
|
||||||
@@ -1144,9 +1093,6 @@ class LoRANetwork(torch.nn.Module):
|
|||||||
def get_trainable_params(self):
|
def get_trainable_params(self):
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
def get_trainable_named_params(self):
|
|
||||||
return self.named_parameters()
|
|
||||||
|
|
||||||
def save_weights(self, file, dtype, metadata):
|
def save_weights(self, file, dtype, metadata):
|
||||||
if metadata is not None and len(metadata) == 0:
|
if metadata is not None and len(metadata) == 0:
|
||||||
metadata = None
|
metadata = None
|
||||||
|
|||||||
Reference in New Issue
Block a user