update loraplus on dylora/lofa_fa

This commit is contained in:
Kohya S
2024-05-06 11:09:32 +09:00
parent 52e64c69cf
commit 7fe81502d0
3 changed files with 71 additions and 34 deletions

View File

@@ -18,10 +18,13 @@ from transformers import CLIPTextModel
import torch import torch
from torch import nn from torch import nn
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module): class DyLoRAModule(torch.nn.Module):
""" """
replaces forward method of the original Linear, instead of replacing the original Linear module. replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -211,6 +214,16 @@ def create_network(
unit=unit, unit=unit,
varbose=True, varbose=True,
) )
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network return network
@@ -280,6 +293,10 @@ class DyLoRANetwork(torch.nn.Module):
self.alpha = alpha self.alpha = alpha
self.apply_to_conv = apply_to_conv self.apply_to_conv = apply_to_conv
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None: if modules_dim is not None:
logger.info("create LoRA network from weights") logger.info("create LoRA network from weights")
else: else:
@@ -346,6 +363,11 @@ class DyLoRANetwork(torch.nn.Module):
self.unet_loras = create_modules(True, unet, target_modules) self.unet_loras = create_modules(True, unet, target_modules)
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
def set_multiplier(self, multiplier): def set_multiplier(self, multiplier):
self.multiplier = multiplier self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras: for lora in self.text_encoder_loras + self.unet_loras:
@@ -407,15 +429,7 @@ class DyLoRANetwork(torch.nn.Module):
""" """
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params( def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True) self.requires_grad_(True)
all_params = [] all_params = []
@@ -452,15 +466,13 @@ class DyLoRANetwork(torch.nn.Module):
params = assemble_params( params = assemble_params(
self.text_encoder_loras, self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr, text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
) )
all_params.extend(params) all_params.extend(params)
if self.unet_loras: if self.unet_loras:
params = assemble_params( params = assemble_params(
self.unet_loras, self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_ratio
default_lr if unet_lr is None else unet_lr,
unet_loraplus_ratio or loraplus_ratio
) )
all_params.extend(params) all_params.extend(params)

View File

@@ -499,6 +499,7 @@ def create_network(
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
if block_lr_weight is not None: if block_lr_weight is not None:
@@ -855,6 +856,10 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout self.rank_dropout = rank_dropout
self.module_dropout = module_dropout self.module_dropout = module_dropout
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None: if modules_dim is not None:
logger.info(f"create LoRA network from weights") logger.info(f"create LoRA network from weights")
elif block_dims is not None: elif block_dims is not None:

View File

@@ -15,8 +15,10 @@ import numpy as np
import torch import torch
import re import re
from library.utils import setup_logging from library.utils import setup_logging
setup_logging() setup_logging()
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_")
@@ -504,6 +506,15 @@ def create_network(
if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None:
network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight)
loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)
return network return network
@@ -529,7 +540,9 @@ def get_block_dims_and_alphas(
len(block_dims) == num_total_blocks len(block_dims) == num_total_blocks
), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください"
else: else:
logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") logger.warning(
f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります"
)
block_dims = [network_dim] * num_total_blocks block_dims = [network_dim] * num_total_blocks
if block_alphas is not None: if block_alphas is not None:
@@ -803,11 +816,17 @@ class LoRANetwork(torch.nn.Module):
self.rank_dropout = rank_dropout self.rank_dropout = rank_dropout
self.module_dropout = module_dropout self.module_dropout = module_dropout
self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
self.loraplus_text_encoder_lr_ratio = None
if modules_dim is not None: if modules_dim is not None:
logger.info(f"create LoRA network from weights") logger.info(f"create LoRA network from weights")
elif block_dims is not None: elif block_dims is not None:
logger.info(f"create LoRA network from block_dims") logger.info(f"create LoRA network from block_dims")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
logger.info(f"block_dims: {block_dims}") logger.info(f"block_dims: {block_dims}")
logger.info(f"block_alphas: {block_alphas}") logger.info(f"block_alphas: {block_alphas}")
if conv_block_dims is not None: if conv_block_dims is not None:
@@ -815,9 +834,13 @@ class LoRANetwork(torch.nn.Module):
logger.info(f"conv_block_alphas: {conv_block_alphas}") logger.info(f"conv_block_alphas: {conv_block_alphas}")
else: else:
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(
f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
)
if self.conv_lora_dim is not None: if self.conv_lora_dim is not None:
logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") logger.info(
f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
)
# create module instances # create module instances
def create_modules( def create_modules(
@@ -939,6 +962,11 @@ class LoRANetwork(torch.nn.Module):
assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
names.add(lora.lora_name) names.add(lora.lora_name)
def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
self.loraplus_lr_ratio = loraplus_lr_ratio
self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio
def set_multiplier(self, multiplier): def set_multiplier(self, multiplier):
self.multiplier = multiplier self.multiplier = multiplier
for lora in self.text_encoder_loras + self.unet_loras: for lora in self.text_encoder_loras + self.unet_loras:
@@ -1033,15 +1061,7 @@ class LoRANetwork(torch.nn.Module):
return lr_weight return lr_weight
# 二つのText Encoderに別々の学習率を設定できるようにするといいかも # 二つのText Encoderに別々の学習率を設定できるようにするといいかも
def prepare_optimizer_params( def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
self,
text_encoder_lr,
unet_lr,
default_lr,
text_encoder_loraplus_ratio=None,
unet_loraplus_ratio=None,
loraplus_ratio=None
):
self.requires_grad_(True) self.requires_grad_(True)
all_params = [] all_params = []
@@ -1078,7 +1098,7 @@ class LoRANetwork(torch.nn.Module):
params = assemble_params( params = assemble_params(
self.text_encoder_loras, self.text_encoder_loras,
text_encoder_lr if text_encoder_lr is not None else default_lr, text_encoder_lr if text_encoder_lr is not None else default_lr,
text_encoder_loraplus_ratio or loraplus_ratio self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio,
) )
all_params.extend(params) all_params.extend(params)
@@ -1097,7 +1117,7 @@ class LoRANetwork(torch.nn.Module):
params = assemble_params( params = assemble_params(
block_loras, block_loras,
(unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]),
unet_loraplus_ratio or loraplus_ratio self.loraplus_unet_lr_ratio or self.loraplus_ratio,
) )
all_params.extend(params) all_params.extend(params)
@@ -1105,7 +1125,7 @@ class LoRANetwork(torch.nn.Module):
params = assemble_params( params = assemble_params(
self.unet_loras, self.unet_loras,
unet_lr if unet_lr is not None else default_lr, unet_lr if unet_lr is not None else default_lr,
unet_loraplus_ratio or loraplus_ratio self.loraplus_unet_lr_ratio or self.loraplus_ratio,
) )
all_params.extend(params) all_params.extend(params)