Replace print with logger if they are logs (#905)

* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
This commit is contained in:
Yuta Hayashibe
2024-02-04 16:14:34 +07:00
committed by GitHub
parent 7f948db158
commit 5f6bf29e52
62 changed files with 1195 additions and 961 deletions

View File

@@ -41,7 +41,10 @@ from library.custom_train_functions import (
add_v_prediction_like_loss,
apply_debiased_estimation,
)
from library.utils import setup_logging
setup_logging()
import logging
logger = logging.getLogger(__name__)
class NetworkTrainer:
def __init__(self):
@@ -153,18 +156,18 @@ class NetworkTrainer:
if args.dataset_class is None:
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, False, True))
if use_user_config:
print(f"Loading dataset config from {args.dataset_config}")
logger.info(f"Loading dataset config from {args.dataset_config}")
user_config = config_util.load_user_config(args.dataset_config)
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
if any(getattr(args, attr) is not None for attr in ignored):
print(
logger.warning(
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
", ".join(ignored)
)
)
else:
if use_dreambooth_method:
print("Using DreamBooth method.")
logger.info("Using DreamBooth method.")
user_config = {
"datasets": [
{
@@ -175,7 +178,7 @@ class NetworkTrainer:
]
}
else:
print("Training with captions.")
logger.info("Training with captions.")
user_config = {
"datasets": [
{
@@ -204,7 +207,7 @@ class NetworkTrainer:
train_util.debug_dataset(train_dataset_group)
return
if len(train_dataset_group) == 0:
print(
logger.error(
"No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります"
)
return
@@ -217,7 +220,7 @@ class NetworkTrainer:
self.assert_extra_args(args, train_dataset_group)
# acceleratorを準備する
print("preparing accelerator")
logger.info("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
@@ -310,7 +313,7 @@ class NetworkTrainer:
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
print(
logger.warning(
"warning: scale_weight_norms is specified but the network does not support it / scale_weight_normsが指定されていますが、ネットワークが対応していません"
)
args.scale_weight_norms = False
@@ -938,7 +941,7 @@ class NetworkTrainer:
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
save_model(ckpt_name, network, global_step, num_train_epochs, force_sync_upload=True)
print("model saved.")
logger.info("model saved.")
def setup_parser() -> argparse.ArgumentParser: