Replace print with logger if they are logs (#905)

* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
Yuta Hayashibe
2024-02-04 16:14:34 +07:00
committed by GitHub
parent 7f948db158
commit 5f6bf29e52
62 changed files with 1195 additions and 961 deletions

View File

@@ -15,7 +15,10 @@ import random
from typing import List, Tuple, Union
import torch
from torch import nn
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class DyLoRAModule(torch.nn.Module):
"""
@@ -223,7 +226,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
elif "lora_down" in key:
dim = value.size()[0]
modules_dim[lora_name] = dim
# print(lora_name, value.size(), dim)
# logger.info(f"{lora_name} {value.size()} {dim}")
# support old LoRA without alpha
for key in modules_dim.keys():
@@ -267,11 +270,11 @@ class DyLoRANetwork(torch.nn.Module):
self.apply_to_conv = apply_to_conv
if modules_dim is not None:
print(f"create LoRA network from weights")
logger.info("create LoRA network from weights")
else:
print(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}, unit: {unit}")
if self.apply_to_conv:
print(f"apply LoRA to Conv2d with kernel size (3,3).")
logger.info("apply LoRA to Conv2d with kernel size (3,3).")
# create module instances
def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules) -> List[DyLoRAModule]:
@@ -308,7 +311,7 @@ class DyLoRANetwork(torch.nn.Module):
return loras
self.text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
# extend U-Net target modules if conv2d 3x3 is enabled, or load from weights
target_modules = DyLoRANetwork.UNET_TARGET_REPLACE_MODULE
@@ -316,7 +319,7 @@ class DyLoRANetwork(torch.nn.Module):
target_modules += DyLoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
self.unet_loras = create_modules(True, unet, target_modules)
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
def set_multiplier(self, multiplier):
self.multiplier = multiplier
@@ -336,12 +339,12 @@ class DyLoRANetwork(torch.nn.Module):
def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True):
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -359,12 +362,12 @@ class DyLoRANetwork(torch.nn.Module):
apply_unet = True
if apply_text_encoder:
print("enable LoRA for text encoder")
logger.info("enable LoRA for text encoder")
else:
self.text_encoder_loras = []
if apply_unet:
print("enable LoRA for U-Net")
logger.info("enable LoRA for U-Net")
else:
self.unet_loras = []
@@ -375,7 +378,7 @@ class DyLoRANetwork(torch.nn.Module):
sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
lora.merge_to(sd_for_lora, dtype, device)
print(f"weights are merged")
logger.info(f"weights are merged")
"""
def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr):