mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Replace print with logger if they are logs (#905)
* Add get_my_logger() * Use logger instead of print * Fix log level * Removed line-breaks for readability * Use setup_logging() * Add rich to requirements.txt * Make simple * Use logger instead of print --------- Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,10 @@ import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
import library.model_util as model_util
|
||||
|
||||
from library.utils import setup_logging
|
||||
setup_logging()
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def convert(args):
|
||||
# 引数を確認する
|
||||
@@ -30,7 +33,7 @@ def convert(args):
|
||||
|
||||
# モデルを読み込む
|
||||
msg = "checkpoint" if is_load_ckpt else ("Diffusers" + (" as fp16" if args.fp16 else ""))
|
||||
print(f"loading {msg}: {args.model_to_load}")
|
||||
logger.info(f"loading {msg}: {args.model_to_load}")
|
||||
|
||||
if is_load_ckpt:
|
||||
v2_model = args.v2
|
||||
@@ -48,13 +51,13 @@ def convert(args):
|
||||
if args.v1 == args.v2:
|
||||
# 自動判定する
|
||||
v2_model = unet.config.cross_attention_dim == 1024
|
||||
print("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||
logger.info("checking model version: model is " + ("v2" if v2_model else "v1"))
|
||||
else:
|
||||
v2_model = not args.v1
|
||||
|
||||
# 変換して保存する
|
||||
msg = ("checkpoint" + ("" if save_dtype is None else f" in {save_dtype}")) if is_save_ckpt else "Diffusers"
|
||||
print(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
logger.info(f"converting and saving as {msg}: {args.model_to_save}")
|
||||
|
||||
if is_save_ckpt:
|
||||
original_model = args.model_to_load if is_load_ckpt else None
|
||||
@@ -70,15 +73,15 @@ def convert(args):
|
||||
save_dtype=save_dtype,
|
||||
vae=vae,
|
||||
)
|
||||
print(f"model saved. total converted state_dict keys: {key_count}")
|
||||
logger.info(f"model saved. total converted state_dict keys: {key_count}")
|
||||
else:
|
||||
print(
|
||||
logger.info(
|
||||
f"copy scheduler/tokenizer config from: {args.reference_model if args.reference_model is not None else 'default model'}"
|
||||
)
|
||||
model_util.save_diffusers_checkpoint(
|
||||
v2_model, args.model_to_save, text_encoder, unet, args.reference_model, vae, args.use_safetensors
|
||||
)
|
||||
print("model saved.")
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
Reference in New Issue
Block a user