mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Replace print with logger if they are logs (#905)
* Add get_my_logger() * Use logger instead of print * Fix log level * Removed line-breaks for readability * Use setup_logging() * Add rich to requirements.txt * Make simple * Use logger instead of print --------- Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
@@ -10,7 +10,10 @@ import numpy as np
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTextModel
|
||||
import torch
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def make_unet_conversion_map() -> Dict[str, str]:
|
||||
unet_conversion_map_layer = []
|
||||
@@ -248,7 +251,7 @@ def create_network_from_weights(
|
||||
elif "lora_down" in key:
|
||||
dim = value.size()[0]
|
||||
modules_dim[lora_name] = dim
|
||||
# print(lora_name, value.size(), dim)
|
||||
# logger.info(f"{lora_name} {value.size()} {dim}")
|
||||
|
||||
# support old LoRA without alpha
|
||||
for key in modules_dim.keys():
|
||||
@@ -291,12 +294,12 @@ class LoRANetwork(torch.nn.Module):
|
||||
super().__init__()
|
||||
self.multiplier = multiplier
|
||||
|
||||
print(f"create LoRA network from weights")
|
||||
logger.info("create LoRA network from weights")
|
||||
|
||||
# convert SDXL Stability AI's U-Net modules to Diffusers
|
||||
converted = self.convert_unet_modules(modules_dim, modules_alpha)
|
||||
if converted:
|
||||
print(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
||||
logger.info(f"converted {converted} Stability AI's U-Net LoRA modules to Diffusers (SDXL)")
|
||||
|
||||
# create module instances
|
||||
def create_modules(
|
||||
@@ -331,7 +334,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora_name = lora_name.replace(".", "_")
|
||||
|
||||
if lora_name not in modules_dim:
|
||||
# print(f"skipped {lora_name} (not found in modules_dim)")
|
||||
# logger.info(f"skipped {lora_name} (not found in modules_dim)")
|
||||
skipped.append(lora_name)
|
||||
continue
|
||||
|
||||
@@ -362,18 +365,18 @@ class LoRANetwork(torch.nn.Module):
|
||||
text_encoder_loras, skipped = create_modules(False, index, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
|
||||
self.text_encoder_loras.extend(text_encoder_loras)
|
||||
skipped_te += skipped
|
||||
print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
logger.info(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.")
|
||||
if len(skipped_te) > 0:
|
||||
print(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||
logger.warning(f"skipped {len(skipped_te)} modules because of missing weight for text encoder.")
|
||||
|
||||
# extend U-Net target modules to include Conv2d 3x3
|
||||
target_modules = LoRANetwork.UNET_TARGET_REPLACE_MODULE + LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
|
||||
|
||||
self.unet_loras: List[LoRAModule]
|
||||
self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
|
||||
print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
|
||||
if len(skipped_un) > 0:
|
||||
print(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||
logger.warning(f"skipped {len(skipped_un)} modules because of missing weight for U-Net.")
|
||||
|
||||
# assertion
|
||||
names = set()
|
||||
@@ -420,11 +423,11 @@ class LoRANetwork(torch.nn.Module):
|
||||
|
||||
def apply_to(self, multiplier=1.0, apply_text_encoder=True, apply_unet=True):
|
||||
if apply_text_encoder:
|
||||
print("enable LoRA for text encoder")
|
||||
logger.info("enable LoRA for text encoder")
|
||||
for lora in self.text_encoder_loras:
|
||||
lora.apply_to(multiplier)
|
||||
if apply_unet:
|
||||
print("enable LoRA for U-Net")
|
||||
logger.info("enable LoRA for U-Net")
|
||||
for lora in self.unet_loras:
|
||||
lora.apply_to(multiplier)
|
||||
|
||||
@@ -433,16 +436,16 @@ class LoRANetwork(torch.nn.Module):
|
||||
lora.unapply_to()
|
||||
|
||||
def merge_to(self, multiplier=1.0):
|
||||
print("merge LoRA weights to original weights")
|
||||
logger.info("merge LoRA weights to original weights")
|
||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||
lora.merge_to(multiplier)
|
||||
print(f"weights are merged")
|
||||
logger.info(f"weights are merged")
|
||||
|
||||
def restore_from(self, multiplier=1.0):
|
||||
print("restore LoRA weights from original weights")
|
||||
logger.info("restore LoRA weights from original weights")
|
||||
for lora in tqdm(self.text_encoder_loras + self.unet_loras):
|
||||
lora.restore_from(multiplier)
|
||||
print(f"weights are restored")
|
||||
logger.info(f"weights are restored")
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||
# convert SDXL Stability AI's state dict to Diffusers' based state dict
|
||||
@@ -463,7 +466,7 @@ class LoRANetwork(torch.nn.Module):
|
||||
my_state_dict = self.state_dict()
|
||||
for key in state_dict.keys():
|
||||
if state_dict[key].size() != my_state_dict[key].size():
|
||||
# print(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
||||
# logger.info(f"convert {key} from {state_dict[key].size()} to {my_state_dict[key].size()}")
|
||||
state_dict[key] = state_dict[key].view(my_state_dict[key].size())
|
||||
|
||||
return super().load_state_dict(state_dict, strict)
|
||||
@@ -490,7 +493,7 @@ if __name__ == "__main__":
|
||||
image_prefix = args.model_id.replace("/", "_") + "_"
|
||||
|
||||
# load Diffusers model
|
||||
print(f"load model from {args.model_id}")
|
||||
logger.info(f"load model from {args.model_id}")
|
||||
pipe: Union[StableDiffusionPipeline, StableDiffusionXLPipeline]
|
||||
if args.sdxl:
|
||||
# use_safetensors=True does not work with 0.18.2
|
||||
@@ -503,7 +506,7 @@ if __name__ == "__main__":
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2] if args.sdxl else [pipe.text_encoder]
|
||||
|
||||
# load LoRA weights
|
||||
print(f"load LoRA weights from {args.lora_weights}")
|
||||
logger.info(f"load LoRA weights from {args.lora_weights}")
|
||||
if os.path.splitext(args.lora_weights)[1] == ".safetensors":
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -512,10 +515,10 @@ if __name__ == "__main__":
|
||||
lora_sd = torch.load(args.lora_weights)
|
||||
|
||||
# create by LoRA weights and load weights
|
||||
print(f"create LoRA network")
|
||||
logger.info(f"create LoRA network")
|
||||
lora_network: LoRANetwork = create_network_from_weights(text_encoders, pipe.unet, lora_sd, multiplier=1.0)
|
||||
|
||||
print(f"load LoRA network weights")
|
||||
logger.info(f"load LoRA network weights")
|
||||
lora_network.load_state_dict(lora_sd)
|
||||
|
||||
lora_network.to(device, dtype=pipe.unet.dtype) # required to apply_to. merge_to works without this
|
||||
@@ -544,34 +547,34 @@ if __name__ == "__main__":
|
||||
random.seed(seed)
|
||||
|
||||
# create image with original weights
|
||||
print(f"create image with original weights")
|
||||
logger.info(f"create image with original weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "original.png")
|
||||
|
||||
# apply LoRA network to the model: slower than merge_to, but can be reverted easily
|
||||
print(f"apply LoRA network to the model")
|
||||
logger.info(f"apply LoRA network to the model")
|
||||
lora_network.apply_to(multiplier=1.0)
|
||||
|
||||
print(f"create image with applied LoRA")
|
||||
logger.info(f"create image with applied LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "applied_lora.png")
|
||||
|
||||
# unapply LoRA network to the model
|
||||
print(f"unapply LoRA network to the model")
|
||||
logger.info(f"unapply LoRA network to the model")
|
||||
lora_network.unapply_to()
|
||||
|
||||
print(f"create image with unapplied LoRA")
|
||||
logger.info(f"create image with unapplied LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "unapplied_lora.png")
|
||||
|
||||
# merge LoRA network to the model: faster than apply_to, but requires back-up of original weights (or unmerge_to)
|
||||
print(f"merge LoRA network to the model")
|
||||
logger.info(f"merge LoRA network to the model")
|
||||
lora_network.merge_to(multiplier=1.0)
|
||||
|
||||
print(f"create image with LoRA")
|
||||
logger.info(f"create image with LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "merged_lora.png")
|
||||
@@ -579,31 +582,31 @@ if __name__ == "__main__":
|
||||
# restore (unmerge) LoRA weights: numerically unstable
|
||||
# マージされた重みを元に戻す。計算誤差のため、元の重みと完全に一致しないことがあるかもしれない
|
||||
# 保存したstate_dictから元の重みを復元するのが確実
|
||||
print(f"restore (unmerge) LoRA weights")
|
||||
logger.info(f"restore (unmerge) LoRA weights")
|
||||
lora_network.restore_from(multiplier=1.0)
|
||||
|
||||
print(f"create image without LoRA")
|
||||
logger.info(f"create image without LoRA")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "unmerged_lora.png")
|
||||
|
||||
# restore original weights
|
||||
print(f"restore original weights")
|
||||
logger.info(f"restore original weights")
|
||||
pipe.unet.load_state_dict(org_unet_sd)
|
||||
pipe.text_encoder.load_state_dict(org_text_encoder_sd)
|
||||
if args.sdxl:
|
||||
pipe.text_encoder_2.load_state_dict(org_text_encoder_2_sd)
|
||||
|
||||
print(f"create image with restored original weights")
|
||||
logger.info(f"create image with restored original weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "restore_original.png")
|
||||
|
||||
# use convenience function to merge LoRA weights
|
||||
print(f"merge LoRA weights with convenience function")
|
||||
logger.info(f"merge LoRA weights with convenience function")
|
||||
merge_lora_weights(pipe, lora_sd, multiplier=1.0)
|
||||
|
||||
print(f"create image with merged LoRA weights")
|
||||
logger.info(f"create image with merged LoRA weights")
|
||||
seed_everything(args.seed)
|
||||
image = pipe(args.prompt, negative_prompt=args.negative_prompt).images[0]
|
||||
image.save(image_prefix + "convenience_merged_lora.png")
|
||||
|
||||
Reference in New Issue
Block a user