mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Adjusted English grammar in logs to be more clear (#554)
* Update train_network.py * Update train_network.py * Update train_network.py * Update train_network.py * Update train_network.py * Update train_network.py
This commit is contained in:
@@ -80,25 +80,25 @@ def train(args):
|
|||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, True))
|
||||||
if use_user_config:
|
if use_user_config:
|
||||||
print(f"Load dataset config from {args.dataset_config}")
|
print(f"Loading dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
ignored = ["train_data_dir", "reg_data_dir", "in_json"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
print(
|
print(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignoring the following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
", ".join(ignored)
|
", ".join(ignored)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if use_dreambooth_method:
|
if use_dreambooth_method:
|
||||||
print("Use DreamBooth method.")
|
print("Using DreamBooth method.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
{"subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(args.train_data_dir, args.reg_data_dir)}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
print("Train with captions.")
|
print("Training with captions.")
|
||||||
user_config = {
|
user_config = {
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -135,7 +135,7 @@ def train(args):
|
|||||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
print("prepare accelerator")
|
print("preparing accelerator")
|
||||||
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
accelerator, unwrap_model = train_util.prepare_accelerator(args)
|
||||||
is_main_process = accelerator.is_main_process
|
is_main_process = accelerator.is_main_process
|
||||||
|
|
||||||
@@ -147,7 +147,7 @@ def train(args):
|
|||||||
|
|
||||||
# モデルに xformers とか memory efficient attention を組み込む
|
# モデルに xformers とか memory efficient attention を組み込む
|
||||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||||
|
|
||||||
# 差分追加学習のためにモデルを読み込む
|
# 差分追加学習のためにモデルを読み込む
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -171,7 +171,6 @@ def train(args):
|
|||||||
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
module.merge_to(text_encoder, unet, weights_sd, weight_dtype, accelerator.device if args.lowram else "cpu")
|
||||||
|
|
||||||
print(f"all weights merged: {', '.join(args.base_weights)}")
|
print(f"all weights merged: {', '.join(args.base_weights)}")
|
||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=weight_dtype)
|
||||||
@@ -210,7 +209,7 @@ def train(args):
|
|||||||
|
|
||||||
if args.network_weights is not None:
|
if args.network_weights is not None:
|
||||||
info = network.load_weights(args.network_weights)
|
info = network.load_weights(args.network_weights)
|
||||||
print(f"load network weights from {args.network_weights}: {info}")
|
print(f"loaded network weights from {args.network_weights}: {info}")
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
unet.enable_gradient_checkpointing()
|
unet.enable_gradient_checkpointing()
|
||||||
@@ -218,7 +217,7 @@ def train(args):
|
|||||||
network.enable_gradient_checkpointing() # may have no effect
|
network.enable_gradient_checkpointing() # may have no effect
|
||||||
|
|
||||||
# 学習に必要なクラスを準備する
|
# 学習に必要なクラスを準備する
|
||||||
print("prepare optimizer, data loader etc.")
|
print("preparing optimizer, data loader etc.")
|
||||||
|
|
||||||
# 後方互換性を確保するよ
|
# 後方互換性を確保するよ
|
||||||
try:
|
try:
|
||||||
@@ -263,7 +262,7 @@ def train(args):
|
|||||||
assert (
|
assert (
|
||||||
args.mixed_precision == "fp16"
|
args.mixed_precision == "fp16"
|
||||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||||
print("enable full fp16 training.")
|
print("enabling full fp16 training.")
|
||||||
network.to(weight_dtype)
|
network.to(weight_dtype)
|
||||||
|
|
||||||
# acceleratorがなんかよろしくやってくれるらしい
|
# acceleratorがなんかよろしくやってくれるらしい
|
||||||
|
|||||||
Reference in New Issue
Block a user