Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-10 06:54:17 +00:00
Merge pull request #222 from kohya-ss/dev
fix training instability issue, add metadata
@@ -124,6 +124,20 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History

- 23 Feb. 2023, 2023/2/23:
  - Fix training instability issue in ``train_network.py``.
    - Training with ``float`` for SD2.x models will work now. Also training with ``bf16`` might be improved.
    - This issue seems to have occurred in [PR#190](https://github.com/kohya-ss/sd-scripts/pull/190).
  - Add some metadata to LoRA model. Thanks to space-nuko!
  - Raise an error if optimizer options conflict (e.g. ``--optimizer_type`` and ``--use_8bit_adam``).
  - Support ControlNet in ``gen_img_diffusers.py`` (no documentation yet).
  - ``train_network.py`` で学習が不安定になる不具合を修正しました。
    - ``float`` 精度での SD2.x モデルの学習が正しく動作するようになりました。また ``bf16`` 精度の学習も改善する可能性があります。
    - この問題は [PR#190](https://github.com/kohya-ss/sd-scripts/pull/190) から起きていたようです。
  - いくつかのメタデータを LoRA モデルに追加しました。 space-nuko 氏に感謝します。
  - オプティマイザ関係のオプションが矛盾していた場合、エラーとするように修正しました(例: ``--optimizer_type`` と ``--use_8bit_adam``)。
  - ``gen_img_diffusers.py`` で ControlNet をサポートしました(ドキュメントはのちほど追加します)。

- 22 Feb. 2023, 2023/2/22:
  - Refactor optimizer options. Thanks to mgz-dev!
    - Add ``--optimizer_type`` option for each training script. Please see help. Japanese documentation is [here](https://github.com/kohya-ss/sd-scripts/blob/main/train_network_README-ja.md#%E3%82%AA%E3%83%97%E3%83%86%E3%82%A3%E3%83%9E%E3%82%A4%E3%82%B6%E3%81%AE%E6%8C%87%E5%AE%9A%E3%81%AB%E3%81%A4%E3%81%84%E3%81%A6).
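Editor's note on the "Add some metadata to LoRA model" entry: the ss_* fields written by ``train_network.py`` (see the hunks further down) can be read back from a saved LoRA file with the safetensors API. This is an illustrative sketch, not part of the commit; the file name is hypothetical, and the json-encoded fields need an extra ``json.loads()``.

```python
# Sketch: read the ss_* metadata back out of a saved LoRA file.
import json
from safetensors import safe_open

with safe_open("my_lora.safetensors", framework="pt", device="cpu") as f:  # hypothetical path
    metadata = f.metadata() or {}

print(metadata.get("ss_optimizer"))
print(metadata.get("ss_sd_scripts_commit_hash"))
if "ss_bucket_info" in metadata:
    print(json.loads(metadata["ss_bucket_info"]))  # stored as JSON text
```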
@@ -1372,8 +1372,8 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser):


def add_optimizer_arguments(parser: argparse.ArgumentParser):
-  parser.add_argument("--optimizer_type", type=str, default="AdamW",
-                      help="Optimizer to use / オプティマイザの種類: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")
+  parser.add_argument("--optimizer_type", type=str, default="",
+                      help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation, AdaFactor")

  # backward compatibility
  parser.add_argument("--use_8bit_adam", action="store_true",
@@ -1532,11 +1532,16 @@ def get_optimizer(args, trainable_params):

  optimizer_type = args.optimizer_type
  if args.use_8bit_adam:
    print(f"*** use_8bit_adam option is specified. optimizer_type is ignored / use_8bit_adamオプションが指定されているためoptimizer_typeは無視されます")
    assert not args.use_lion_optimizer, "both option use_8bit_adam and use_lion_optimizer are specified / use_8bit_adamとuse_lion_optimizerの両方のオプションが指定されています"
    assert optimizer_type is None or optimizer_type == "", "both option use_8bit_adam and optimizer_type are specified / use_8bit_adamとoptimizer_typeの両方のオプションが指定されています"
    optimizer_type = "AdamW8bit"

  elif args.use_lion_optimizer:
    print(f"*** use_lion_optimizer option is specified. optimizer_type is ignored / use_lion_optimizerオプションが指定されているためoptimizer_typeは無視されます")
    assert optimizer_type is None or optimizer_type == "", "both option use_lion_optimizer and optimizer_type are specified / use_lion_optimizerとoptimizer_typeの両方のオプションが指定されています"
    optimizer_type = "Lion"

  if optimizer_type is None or optimizer_type == "":
    optimizer_type = "AdamW"
  optimizer_type = optimizer_type.lower()

  # 引数を分解する:boolとfloat、tupleのみ対応
@@ -1557,7 +1562,7 @@ def get_optimizer(args, trainable_params):
        value = tuple(value)

      optimizer_kwargs[key] = value
-  print("optkwargs:", optimizer_kwargs)
+  # print("optkwargs:", optimizer_kwargs)

  lr = args.learning_rate
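Editor's note: the Japanese comment in the hunk above says the ``--optimizer_args`` values are split into keyword arguments, supporting only bool, float and tuple values. A rough sketch of that idea, with a hypothetical helper name and made-up arguments, not the repository's exact code:

```python
# Sketch: split "key=value[,value...]" strings into bool / float / tuple kwargs.
def parse_optimizer_args(arg_strings):
    optimizer_kwargs = {}
    for arg in arg_strings:                      # e.g. "betas=0.9,0.99"
        key, raw = arg.split("=", 1)
        values = []
        for part in raw.split(","):
            if part.lower() in ("true", "false"):
                values.append(part.lower() == "true")
            else:
                values.append(float(part))       # only bool and float supported
        optimizer_kwargs[key] = values[0] if len(values) == 1 else tuple(values)
    return optimizer_kwargs

print(parse_optimizer_args(["weight_decay=0.01", "betas=0.9,0.99", "amsgrad=True"]))
# {'weight_decay': 0.01, 'betas': (0.9, 0.99), 'amsgrad': True}
```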
@@ -275,9 +275,11 @@ def train(args):
    "ss_shuffle_caption": bool(args.shuffle_caption),
    "ss_cache_latents": bool(args.cache_latents),
    "ss_enable_bucket": bool(train_dataset.enable_bucket),
    "ss_bucket_no_upscale": bool(train_dataset.bucket_no_upscale),
    "ss_min_bucket_reso": train_dataset.min_bucket_reso,
    "ss_max_bucket_reso": train_dataset.max_bucket_reso,
    "ss_seed": args.seed,
    "ss_lowram": args.lowram,
    "ss_keep_tokens": args.keep_tokens,
    "ss_noise_offset": args.noise_offset,
    "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
@@ -286,7 +288,13 @@ def train(args):
    "ss_bucket_info": json.dumps(train_dataset.bucket_info),
    "ss_training_comment": args.training_comment,  # will not be updated after training
    "ss_sd_scripts_commit_hash": train_util.get_git_revision_hash(),
-   "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else "")
+   "ss_optimizer": optimizer_name + (f"({optimizer_args})" if len(optimizer_args) > 0 else ""),
    "ss_max_grad_norm": args.max_grad_norm,
    "ss_caption_dropout_rate": args.caption_dropout_rate,
    "ss_caption_dropout_every_n_epochs": args.caption_dropout_every_n_epochs,
    "ss_caption_tag_dropout_rate": args.caption_tag_dropout_rate,
    "ss_face_crop_aug_range": args.face_crop_aug_range,
    "ss_prior_loss_weight": args.prior_loss_weight,
  }

  # uncomment if another network is added
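Editor's note: safetensors can only attach metadata as a string-to-string dictionary, so values like the ones above have to end up as strings before saving (how exactly the script does the conversion is outside the hunks shown), and nested structures go through ``json.dumps``. A minimal, self-contained sketch with made-up values:

```python
# Sketch: stringify metadata values and store nested structures as JSON text,
# matching the json.dumps() calls in the dict above.
import json
import torch
from safetensors.torch import save_file

metadata = {
    "ss_shuffle_caption": True,
    "ss_seed": 42,
    "ss_dataset_dirs": json.dumps({"imgs": {"n_repeats": 10, "img_count": 25}}),
}
metadata = {k: str(v) for k, v in metadata.items()}  # safetensors wants Dict[str, str]

save_file({"dummy": torch.zeros(1)}, "metadata_demo.safetensors", metadata=metadata)
```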
@@ -361,7 +369,7 @@ def train(args):
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Predict the noise residual
-       with autocast():
+       with accelerator.autocast():
          noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        if args.v_parameterization:
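Editor's note: this one-line change appears to be the training-instability fix named in the changelog. Assuming the removed ``autocast()`` was torch's CUDA autocast, it always cast the forward pass to fp16, while ``accelerator.autocast()`` follows whatever precision Accelerate was configured with, so full-float and bf16 runs are no longer forced into fp16. A small, CPU-sized sketch of the difference, not repository code:

```python
# Sketch: with mixed_precision="no", accelerator.autocast() is effectively a
# no-op, so activations keep their full float32 dtype instead of being cast
# down by an unconditional autocast.
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="no")   # full-float training run

layer = torch.nn.Linear(4, 4)
x = torch.ones(1, 4)

with accelerator.autocast():
    y = layer(x)

print(y.dtype)  # torch.float32 -- the dtype the configured precision implies
```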
@@ -422,6 +430,7 @@ def train(args):
        def save_func():
          ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
          ckpt_file = os.path.join(args.output_dir, ckpt_name)
          metadata["ss_training_finished_at"] = str(time.time())
          print(f"saving checkpoint: {ckpt_file}")
          unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
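Editor's note: the finish time is written as ``str(time.time())``, i.e. a Unix timestamp in text form, so reading it back only needs ``float()`` and ``datetime``:

```python
# Sketch: convert the stored ss_training_finished_at string back into a
# human-readable timestamp.
import time
from datetime import datetime

metadata = {"ss_training_finished_at": str(time.time())}  # as written above
finished = datetime.fromtimestamp(float(metadata["ss_training_finished_at"]))
print(finished.isoformat())
```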
@@ -439,6 +448,7 @@ def train(args):
      # end of epoch

  metadata["ss_epoch"] = str(num_train_epochs)
  metadata["ss_training_finished_at"] = str(time.time())

  is_main_process = accelerator.is_main_process
  if is_main_process:
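Editor's note: the final save (and the metadata set just above) is written only on the main process; in a multi-GPU accelerate launch every process runs ``train()``, so without this guard each rank would write the same file. A minimal sketch of the pattern:

```python
# Sketch of the main-process guard used above; wait_for_everyone() is an
# assumed addition showing the usual way to keep other ranks in step.
from accelerate import Accelerator

accelerator = Accelerator()
if accelerator.is_main_process:
    print("rank 0: would write the final model and its ss_* metadata here")
accelerator.wait_for_everyone()
```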