Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-09 06:45:09 +00:00
update loraplus on dylora/lora_fa
@@ -18,10 +18,13 @@ from transformers import CLIPTextModel
import torch
from torch import nn
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class DyLoRAModule(torch.nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
@@ -195,7 +198,7 @@ def create_network(
        conv_alpha = 1.0
    else:
        conv_alpha = float(conv_alpha)

    if unit is not None:
        unit = int(unit)
    else:
@@ -211,6 +214,16 @@ def create_network(
        unit=unit,
        varbose=True,
    )

    loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None)
    loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None)
    loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None)
    loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None
    loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None
    loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None
    if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None:
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network
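Note: the float() conversions above exist because network-module kwargs arrive as strings; each key=value pair given to --network_args is forwarded by the trainer to create_network as a string keyword argument. A minimal sketch of that flow (the argument values are made up for illustration):

# Hypothetical illustration of how --network_args values reach create_network.
net_kwargs = {}
for arg in ["unit=4", "loraplus_lr_ratio=4"]:  # e.g. from --network_args
    key, value = arg.split("=")
    net_kwargs[key] = value  # every value is still a string at this point

# the same normalization create_network performs above:
ratio = net_kwargs.get("loraplus_lr_ratio", None)
ratio = float(ratio) if ratio is not None else None
assert ratio == 4.0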
@@ -280,6 +293,10 @@ class DyLoRANetwork(torch.nn.Module):
        self.alpha = alpha
        self.apply_to_conv = apply_to_conv

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if modules_dim is not None:
            logger.info("create LoRA network from weights")
        else:
@@ -320,9 +337,9 @@ class DyLoRANetwork(torch.nn.Module):
                        lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit)
                        loras.append(lora)
            return loras

        text_encoders = text_encoder if type(text_encoder) == list else [text_encoder]

        self.text_encoder_loras = []
        for i, text_encoder in enumerate(text_encoders):
            if len(text_encoders) > 1:
@@ -331,7 +348,7 @@ class DyLoRANetwork(torch.nn.Module):
            else:
                index = None
                logger.info("create LoRA for Text Encoder")

            text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
            self.text_encoder_loras.extend(text_encoder_loras)
@@ -346,6 +363,11 @@ class DyLoRANetwork(torch.nn.Module):
        self.unet_loras = create_modules(True, unet, target_modules)
        logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

    def set_multiplier(self, multiplier):
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
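Note: set_loraplus_lr_ratio only stores the ratios; the effect, following the LoRA+ recipe, is that the B-side (up) matrices train at lr * ratio while everything else stays at lr. A hypothetical sketch of the grouping that the assemble_params helper used below presumably performs (the body and the "lora_B" name check are assumptions, not code from this commit):

def assemble_params_sketch(loras, lr, ratio):
    # split parameters into the base group and the LoRA+ boosted group
    param_groups = {"lora": {}, "plus": {}}
    for lora in loras:
        for name, param in lora.named_parameters():
            # DyLoRA stores A (down) and B (up) matrices; boost only the B side
            is_plus = ratio is not None and "lora_B" in name
            param_groups["plus" if is_plus else "lora"][f"{lora.lora_name}.{name}"] = param

    params = []
    for key, group in param_groups.items():
        if not group:
            continue
        # the boosted group trains at lr * ratio, per LoRA+
        group_lr = lr * ratio if key == "plus" else lr
        params.append({"params": list(group.values()), "lr": group_lr})
    return params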
@@ -407,15 +429,7 @@ class DyLoRANetwork(torch.nn.Module):
    """

    # It might be good to allow setting separate learning rates for the two Text Encoders
-    def prepare_optimizer_params(
-        self,
-        text_encoder_lr,
-        unet_lr,
-        default_lr,
-        text_encoder_loraplus_ratio=None,
-        unet_loraplus_ratio=None,
-        loraplus_ratio=None
-    ):
+    def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):
        self.requires_grad_(True)
        all_params = []
@@ -452,15 +466,13 @@ class DyLoRANetwork(torch.nn.Module):
            params = assemble_params(
                self.text_encoder_loras,
                text_encoder_lr if text_encoder_lr is not None else default_lr,
-                text_encoder_loraplus_ratio or loraplus_ratio
+                self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)

        if self.unet_loras:
            params = assemble_params(
-                self.unet_loras,
-                default_lr if unet_lr is None else unet_lr,
-                unet_loraplus_ratio or loraplus_ratio
+                self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio
            )
            all_params.extend(params)
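Note: after this change the ratios no longer travel through prepare_optimizer_params; the trainer sets them once on the network and the method reads them back from self. One caveat of the `x or y` fallback above: a per-part ratio explicitly set to 0 is falsy and would also fall back to the general ratio. A hypothetical call-site sketch (learning rates and ratios are invented; assumes prepare_optimizer_params returns a list of optimizer param groups, and that network was built by create_network above):

import torch

network.set_loraplus_lr_ratio(None, 16.0, 4.0)  # general=None, U-Net=16, Text Encoder=4
params = network.prepare_optimizer_params(text_encoder_lr=1e-5, unet_lr=1e-4, default_lr=1e-4)
optimizer = torch.optim.AdamW(params)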