Replace print with logger if they are logs (#905)

* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
Yuta Hayashibe
2024-02-04 16:14:34 +07:00
committed by GitHub
parent 7f948db158
commit 5f6bf29e52
62 changed files with 1195 additions and 961 deletions

View File

@@ -8,6 +8,10 @@ from safetensors.torch import load_file, save_file, safe_open
from tqdm import tqdm
from library import train_util, model_util
import numpy as np
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
MIN_SV = 1e-6
@@ -206,7 +210,7 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
scale = network_alpha/network_dim
if dynamic_method:
print(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
logger.info(f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}")
lora_down_weight = None
lora_up_weight = None
@@ -275,10 +279,10 @@ def resize_lora_model(lora_sd, new_rank, save_dtype, device, dynamic_method, dyn
del param_dict
if verbose:
print(verbose_str)
logger.info(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
print("resizing complete")
logger.info(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
logger.info("resizing complete")
return o_lora_sd, network_dim, new_alpha
@@ -304,10 +308,10 @@ def resize(args):
if save_dtype is None:
save_dtype = merge_dtype
print("loading Model...")
logger.info("loading Model...")
lora_sd, metadata = load_state_dict(args.model, merge_dtype)
print("Resizing Lora...")
logger.info("Resizing Lora...")
state_dict, old_dim, new_alpha = resize_lora_model(lora_sd, args.new_rank, save_dtype, args.device, args.dynamic_method, args.dynamic_param, args.verbose)
# update metadata
@@ -329,7 +333,7 @@ def resize(args):
metadata["sshs_model_hash"] = model_hash
metadata["sshs_legacy_hash"] = legacy_hash
print(f"saving model to: {args.save_to}")
logger.info(f"saving model to: {args.save_to}")
save_to_file(args.save_to, state_dict, state_dict, save_dtype, metadata)