Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-09 06:45:09 +00:00

update loraplus on dylora/lofa_fa
@@ -18,10 +18,13 @@ from transformers import CLIPTextModel
 import torch
 from torch import nn
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)

+
 class DyLoRAModule(torch.nn.Module):
     """
     replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -211,6 +214,16 @@ def create_network(
         unit=unit,
         varbose=True,
     )

+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network

+
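In sd-scripts, extra "key=value" pairs passed to the training scripts through --network_args are forwarded to create_network as keyword arguments, which is how the three new loraplus_* options reach the hunk above; they arrive as strings, are cast to float only when present, and None means LoRA+ stays disabled for that group. A minimal, self-contained sketch of that parsing (the helper name and sample values are illustrative, not part of the commit):

# Sketch only: mirrors the kwargs handling added in the hunk above.
def parse_loraplus_kwargs(**kwargs):
    def to_float(value):
        # --network_args values come in as strings, e.g. "4", so cast only when given
        return float(value) if value is not None else None

    loraplus_lr_ratio = to_float(kwargs.get("loraplus_lr_ratio", None))
    loraplus_unet_lr_ratio = to_float(kwargs.get("loraplus_unet_lr_ratio", None))
    loraplus_text_encoder_lr_ratio = to_float(kwargs.get("loraplus_text_encoder_lr_ratio", None))
    return loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio


print(parse_loraplus_kwargs(loraplus_lr_ratio="4"))        # (4.0, None, None)
print(parse_loraplus_kwargs(loraplus_unet_lr_ratio="16"))  # (None, 16.0, None)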
@@ -280,6 +293,10 @@ class DyLoRANetwork(torch.nn.Module):
         self.alpha = alpha
         self.apply_to_conv = apply_to_conv

+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info("create LoRA network from weights")
         else:
@@ -346,6 +363,11 @@ class DyLoRANetwork(torch.nn.Module):
         self.unet_loras = create_modules(True, unet, target_modules)
         logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -407,15 +429,7 @@ class DyLoRANetwork(torch.nn.Module):
     """

    # 二つのText Encoderに別々の学習率を設定できるようにするといいかも (it might be good to allow separate learning rates for the two Text Encoders)
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []

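The LoRA+ ratios are no longer threaded through prepare_optimizer_params as arguments; the method now takes only the three learning rates and reads the ratios from the attributes populated by set_loraplus_lr_ratio when the network was created. An illustrative call with the new signature, assuming network is a network built as above and the learning-rate values are placeholders:

# Placeholder values; in the training scripts these come from the CLI arguments.
trainable_params = network.prepare_optimizer_params(
    text_encoder_lr=5e-5,
    unet_lr=1e-4,
    default_lr=1e-4,
)

The returned parameter groups carry their own "lr" entries where one was resolved, so they can be handed directly to the optimizer constructor.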
@@ -452,15 +466,13 @@ class DyLoRANetwork(torch.nn.Module):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
             )
             all_params.extend(params)

         if self.unet_loras:
             params = assemble_params(
-                self.unet_loras,
-                default_lr if unet_lr is None else unet_lr,
-                unet_loraplus_ratio or loraplus_ratio
+                self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_ratio
             )
             all_params.extend(params)

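The hunks above cover the DyLoRA network; judging by the class names and the commit title, the remaining hunks apply the same plumbing to the standard LoRA implementation and to LoRA-FA (networks/lora.py and networks/lora_fa.py), although the file headers were not preserved in this view. In every case the ratio is ultimately consumed by assemble_params, which this diff does not touch: LoRA+ trains the up (B) matrices with a learning rate scaled by the ratio while the down (A) matrices keep the base rate. A rough sketch of that grouping, written as an assumption about how such a helper behaves rather than a copy of the repository's implementation:

# Assumed behaviour of an assemble_params-style helper, not the repo's exact code.
def assemble_params_sketch(loras, lr, loraplus_ratio):
    groups = {"lora": [], "plus": []}
    for lora in loras:
        for name, param in lora.named_parameters():
            # LoRA+ rule: "up" (B) weights get lr * ratio, everything else gets lr
            if loraplus_ratio is not None and "lora_up" in name:
                groups["plus"].append(param)
            else:
                groups["lora"].append(param)

    params = []
    for key, group in groups.items():
        if not group:
            continue
        param_data = {"params": group}
        if lr is not None:
            param_data["lr"] = lr * loraplus_ratio if key == "plus" else lr
        params.append(param_data)
    return params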
@@ -499,6 +499,7 @@ def create_network(
     loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
     loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
     loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
-    network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

     if block_lr_weight is not None:
@@ -855,6 +856,10 @@ class LoRANetwork(torch.nn.Module):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout

+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
@@ -15,8 +15,10 @@ import numpy as np
 import torch
 import re
 from library.utils import setup_logging
+
 setup_logging()
 import logging
+
 logger = logging.getLogger(__name__)

 RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -504,6 +506,15 @@ def create_network(
     if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
         network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)

+    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
+    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
+    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
+    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
+    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
+    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
+    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
+        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
+
     return network


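This hunk gives LoRA-FA's create_network the same three options. At optimizer-build time the component-specific ratio takes precedence over the shared one through Python's or, as in the prepare_optimizer_params hunks, so a U-Net or Text Encoder value wins whenever it is set and the shared ratio is only a fallback. A small illustration with made-up values:

loraplus_lr_ratio = 4.0            # shared ratio
loraplus_unet_lr_ratio = 16.0      # U-Net-specific override
loraplus_text_encoder_lr_ratio = None

# mirrors the "per-component or shared" fallback used in prepare_optimizer_params
print(loraplus_unet_lr_ratio or loraplus_lr_ratio)          # 16.0 -> override wins
print(loraplus_text_encoder_lr_ratio or loraplus_lr_ratio)  # 4.0  -> falls back to the shared ratio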
@@ -529,7 +540,9 @@ def get_block_dims_and_alphas(
             len(block_dims) == num_total_blocks
         ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
     else:
-        logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります")
+        logger.warning(
+            f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
+        )
         block_dims = [network_dim] * num_total_blocks

     if block_alphas is not None:
@@ -803,11 +816,17 @@ class LoRANetwork(torch.nn.Module):
         self.rank_dropout = rank_dropout
         self.module_dropout = module_dropout

+        self.loraplus_lr_ratio = None
+        self.loraplus_unet_lr_ratio = None
+        self.loraplus_text_encoder_lr_ratio = None
+
         if modules_dim is not None:
             logger.info(f"create LoRA network from weights")
         elif block_dims is not None:
             logger.info(f"create LoRA network from block_dims")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             logger.info(f"block_dims: {block_dims}")
             logger.info(f"block_alphas: {block_alphas}")
             if conv_block_dims is not None:
@@ -815,9 +834,13 @@ class LoRANetwork(torch.nn.Module):
                 logger.info(f"conv_block_alphas: {conv_block_alphas}")
         else:
             logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
-            logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}")
+            logger.info(
+                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
+            )
             if self.conv_lora_dim is not None:
-                logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}")
+                logger.info(
+                    f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
+                )

         # create module instances
         def create_modules(
@@ -939,6 +962,11 @@ class LoRANetwork(torch.nn.Module):
             assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
             names.add(lora.lora_name)

+    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
+        self.loraplus_lr_ratio = loraplus_lr_ratio
+        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
+        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
+
     def set_multiplier(self, multiplier):
         self.multiplier = multiplier
         for lora in self.text_encoder_loras + self.unet_loras:
@@ -1033,15 +1061,7 @@ class LoRANetwork(torch.nn.Module):
         return lr_weight

    # 二つのText Encoderに別々の学習率を設定できるようにするといいかも (it might be good to allow separate learning rates for the two Text Encoders)
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
         self.requires_grad_(True)
         all_params = []

@@ -1078,7 +1098,7 @@ class LoRANetwork(torch.nn.Module):
             params = assemble_params(
                 self.text_encoder_loras,
                 text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
             )
             all_params.extend(params)

@@ -1097,7 +1117,7 @@ class LoRANetwork(torch.nn.Module):
                     params = assemble_params(
                         block_loras,
                         (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
-                        unet_loraplus_ratio or loraplus_ratio
+                        self.loraplus_unet_lr_ratio or self.loraplus_ratio,
                     )
                     all_params.extend(params)

@@ -1105,7 +1125,7 @@ class LoRANetwork(torch.nn.Module):
                 params = assemble_params(
                     self.unet_loras,
                     unet_lr if unet_lr is not None else default_lr,
-                    unet_loraplus_ratio or loraplus_ratio
+                    self.loraplus_unet_lr_ratio or self.loraplus_ratio,
                 )
                 all_params.extend(params)

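Taken together, the effective learning rates after this commit are easy to work out by hand. With made-up values of unet_lr=1e-4, a per-block lr_weight of 0.5 (the block_lr branch shown above), and loraplus_unet_lr_ratio=16, the down (A) weights of that block train at 5e-5 and the up (B) weights at 8e-4:

unet_lr = 1e-4
lr_weight = 0.5                 # per-block scale from set_block_lr_weight
loraplus_unet_lr_ratio = 16.0

base_lr = unet_lr * lr_weight                # lr for lora_down (A): 5e-05
plus_lr = base_lr * loraplus_unet_lr_ratio   # lr for lora_up (B): 0.0008
print(base_lr, plus_lr)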