Merge branch 'dev' into DyLoRA-xl

This commit is contained in:
Kohya S
2024-02-24 19:18:36 +09:00
committed by GitHub
72 changed files with 6414 additions and 1483 deletions

View File

@@ -17,7 +17,10 @@ from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import torch
from torch import nn
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module):
"""
@@ -234,7 +237,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -278,11 +281,11 @@ class DyLoRANetwork(torch.nn.Module):
self.apply_to_conv = apply_to_conv
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
if self.apply_to_conv:
print(f"apply LoRA to Conv2d with kernel size (3,3).")
logger.info("apply LoRA to Conv2d with kernel size (3,3).")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
@@ -333,7 +336,7 @@ class DyLoRANetwork(torch.nn.Module):
self.text_encoder_loras.extend(text_encoder_loras)
# self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -341,7 +344,7 @@ class DyLoRANetwork(torch.nn.Module):
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_multiplier(self, multiplier):
self.multiplier = multiplier
@@ -361,12 +364,12 @@ class DyLoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -384,12 +387,12 @@ class DyLoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -400,7 +403,7 @@ class DyLoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
"""
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):