Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-09 06:45:09 +00:00
Replace print with logger if they are logs (#905)

* Add get_my_logger()
* Use logger instead of print
* Fix log level
* Removed line-breaks for readability
* Use setup_logging()
* Add rich to requirements.txt
* Make simple
* Use logger instead of print

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
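The change relies on a shared setup_logging() helper in library/utils.py, which each script calls once before taking a module-level logger. That helper's body is not part of this diff; the sketch below is a plausible minimal version, assuming rich (newly added to requirements.txt) is used for console formatting when available. The fallback branch and parameter names are illustrative, not the repository's exact code.

    import logging


    def setup_logging(log_level=logging.INFO):
        # Guard against duplicate handlers when several modules call this.
        if logging.root.handlers:
            return
        try:
            # rich is in requirements.txt; RichHandler gives colored console output.
            from rich.logging import RichHandler

            handler = RichHandler()
        except ImportError:
            # Plain stderr handler if rich is unavailable.
            handler = logging.StreamHandler()
            handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))
        logging.root.setLevel(log_level)
        logging.root.addHandler(handler)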
@@ -6,7 +6,10 @@ import re
 from typing import Optional, List, Type
 import torch
 from library import sdxl_original_unet
-
+from library.utils import setup_logging
+setup_logging()
+import logging
+logger = logging.getLogger(__name__)

 # input_blocksに適用するかどうか / if True, input_blocks are not applied
 SKIP_INPUT_BLOCKS = False
@@ -270,7 +273,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond

         # create module instances
         self.lllite_modules = apply_to_modules(self, target_modules)
-        print(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")
+        logger.info(f"enable ControlNet LLLite for U-Net: {len(self.lllite_modules)} modules.")

     # def prepare_optimizer_params(self):
     def prepare_params(self):
@@ -281,8 +284,8 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
                 train_params.append(p)
             else:
                 non_train_params.append(p)
-        print(f"count of trainable parameters: {len(train_params)}")
-        print(f"count of non-trainable parameters: {len(non_train_params)}")
+        logger.info(f"count of trainable parameters: {len(train_params)}")
+        logger.info(f"count of non-trainable parameters: {len(non_train_params)}")

         for p in non_train_params:
             p.requires_grad_(False)
@@ -388,7 +391,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond
         matches = pattern.findall(module_name)
         if matches is not None:
             for m in matches:
-                print(module_name, m)
+                logger.info(f"{module_name} {m}")
                 module_name = module_name.replace(m, m.replace("_", "@"))
         module_name = module_name.replace("_", ".")
         module_name = module_name.replace("@", "_")
@@ -407,7 +410,7 @@ class SdxlUNet2DConditionModelControlNetLLLite(sdxl_original_unet.SdxlUNet2DCond


 def replace_unet_linear_and_conv2d():
-    print("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
+    logger.info("replace torch.nn.Linear and torch.nn.Conv2d to LLLiteLinear and LLLiteConv2d in U-Net")
     sdxl_original_unet.torch.nn.Linear = LLLiteLinear
     sdxl_original_unet.torch.nn.Conv2d = LLLiteConv2d

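The replace_unet_linear_and_conv2d() hunk above works by monkey-patching: because sdxl_original_unet imports torch, rebinding torch.nn.Linear and torch.nn.Conv2d before the U-Net is instantiated makes every layer the model builds an LLLite subclass. A self-contained sketch of the same idea, with a hypothetical stand-in class:

    import torch


    class TracedLinear(torch.nn.Linear):  # hypothetical stand-in for LLLiteLinear
        pass


    _original = torch.nn.Linear
    torch.nn.Linear = TracedLinear  # models built after this line get TracedLinear layers
    model = torch.nn.Sequential(torch.nn.Linear(4, 2))
    torch.nn.Linear = _original  # restore so the patch does not leak elsewhere
    assert isinstance(model[0], TracedLinear)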
@@ -419,10 +422,10 @@ if __name__ == "__main__":
     replace_unet_linear_and_conv2d()

     # test shape etc
-    print("create unet")
+    logger.info("create unet")
     unet = SdxlUNet2DConditionModelControlNetLLLite()

-    print("enable ControlNet-LLLite")
+    logger.info("enable ControlNet-LLLite")
     unet.apply_lllite(32, 64, None, False, 1.0)
     unet.to("cuda")  # .to(torch.float16)

@@ -439,14 +442,14 @@ if __name__ == "__main__":
     # unet_sd[converted_key] = model_sd[key]

     # info = unet.load_lllite_weights("r:/lllite_from_unet.safetensors", unet_sd)
-    # print(info)
+    # logger.info(info)

-    # print(unet)
+    # logger.info(unet)

-    # print number of parameters
+    # logger.info number of parameters
     params = unet.prepare_params()
-    print("number of parameters", sum(p.numel() for p in params))
-    # print("type any key to continue")
+    logger.info(f"number of parameters {sum(p.numel() for p in params)}")
+    # logger.info("type any key to continue")
     # input()

     unet.set_use_memory_efficient_attention(True, False)
@@ -455,12 +458,12 @@ if __name__ == "__main__":

     # # visualize
     # import torchviz
-    # print("run visualize")
+    # logger.info("run visualize")
     # controlnet.set_control(conditioning_image)
     # output = unet(x, t, ctx, y)
-    # print("make_dot")
+    # logger.info("make_dot")
     # image = torchviz.make_dot(output, params=dict(controlnet.named_parameters()))
-    # print("render")
+    # logger.info("render")
     # image.format = "svg"  # "png"
     # image.render("NeuralNet")  # すごく時間がかかるので注意 / be careful because it takes a long time
     # input()
@@ -471,13 +474,13 @@ if __name__ == "__main__":

     scaler = torch.cuda.amp.GradScaler(enabled=True)

-    print("start training")
+    logger.info("start training")
     steps = 10
     batch_size = 1

     sample_param = [p for p in unet.named_parameters() if ".lllite_up." in p[0]][0]
     for step in range(steps):
-        print(f"step {step}")
+        logger.info(f"step {step}")

         conditioning_image = torch.rand(batch_size, 3, 1024, 1024).cuda() * 2.0 - 1.0
         x = torch.randn(batch_size, 4, 128, 128).cuda()
@@ -494,9 +497,9 @@ if __name__ == "__main__":
         scaler.step(optimizer)
         scaler.update()
         optimizer.zero_grad(set_to_none=True)
-        print(sample_param)
+        logger.info(sample_param)

     # from safetensors.torch import save_file

-    # print("save weights")
+    # logger.info("save weights")
     # unet.save_lllite_weights("r:/lllite_from_unet.safetensors", torch.float16, None)
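Taken together, the pattern applied throughout the file is: configure logging once, take a logger named after the module, and fold print's positional arguments into a single f-string message (so calls like print(module_name, m) become logger.info(f"{module_name} {m}")). A minimal self-contained equivalent, using logging.basicConfig as a stand-in for library.utils.setup_logging():

    import logging

    logging.basicConfig(level=logging.INFO)  # stand-in for library.utils.setup_logging()
    logger = logging.getLogger(__name__)

    count = 42
    print("count of trainable parameters:", count)          # before: bare print
    logger.info(f"count of trainable parameters: {count}")  # after: one f-string message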