mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix: do not average the loss over the batch dimension before applying the SNR weight when min_snr_gamma is set
This commit is contained in:
18
fine_tune.py
18
fine_tune.py
@@ -23,6 +23,7 @@ from library.config_util import (
|
|||||||
import library.custom_train_functions as custom_train_functions
|
import library.custom_train_functions as custom_train_functions
|
||||||
from library.custom_train_functions import apply_snr_weight
|
from library.custom_train_functions import apply_snr_weight
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
train_util.verify_training_args(args)
|
train_util.verify_training_args(args)
|
||||||
train_util.prepare_dataset_args(args, True)
|
train_util.prepare_dataset_args(args, True)
|
||||||
@@ -62,8 +63,8 @@ def train(args):
|
|||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
current_epoch = Value('i',0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value('i',0)
|
current_step = Value("i", 0)
|
||||||
collater = train_util.collater_class(current_epoch, current_step)
|
collater = train_util.collater_class(current_epoch, current_step)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
@@ -196,7 +197,9 @@ def train(args):
|
|||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
args.max_train_steps = args.max_train_epochs * math.ceil(
|
||||||
|
len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
|
||||||
|
)
|
||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# データセット側にも学習ステップを送信
|
# データセット側にも学習ステップを送信
|
||||||
@@ -308,10 +311,14 @@ def train(args):
|
|||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
if args.min_snr_gamma:
|
||||||
|
# do not mean over batch dimension for snr weight
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
loss = loss.mean([1, 2, 3])
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
loss = loss.mean() # mean over batch dimension
|
||||||
|
else:
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
@@ -407,7 +414,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
custom_train_functions.add_custom_train_arguments(parser)
|
custom_train_functions.add_custom_train_arguments(parser)
|
||||||
|
|
||||||
|
|
||||||
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
|
||||||
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user