mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
val
This commit is contained in:
102
train_network.py
102
train_network.py
@@ -130,7 +130,9 @@ class NetworkTrainer:
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
|
||||
def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True, timesteps_list=None):
|
||||
total_loss = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
@@ -167,37 +169,40 @@ class NetworkTrainer:
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
)
|
||||
# Use input timesteps_list or use described timesteps above
|
||||
timesteps_list = timesteps_list or [timesteps]
|
||||
for timesteps in timesteps_list:
|
||||
# Predict the noise residual
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
return loss
|
||||
total_loss += loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
average_loss = total_loss / len(timesteps_list)
|
||||
return average_loss
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -283,10 +288,10 @@ class NetworkTrainer:
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
if val_dataset_group is not None:
|
||||
assert (
|
||||
val_dataset_group.is_latent_cacheable()
|
||||
), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
assert (
|
||||
val_dataset_group.is_latent_cacheable()
|
||||
), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
|
||||
# acceleratorを準備する
|
||||
@@ -430,6 +435,15 @@ class NetworkTrainer:
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset_group if val_dataset_group is not None else [],
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset_group if val_dataset_group is not None else [],
|
||||
@@ -798,7 +812,6 @@ class NetworkTrainer:
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
val_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
del train_dataset_group
|
||||
|
||||
# callback for step start
|
||||
@@ -848,7 +861,6 @@ class NetworkTrainer:
|
||||
on_step_start(text_encoder, unet)
|
||||
is_train = True
|
||||
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = network.get_trainable_params()
|
||||
@@ -900,7 +912,25 @@ class NetworkTrainer:
|
||||
if args.logging_dir is not None:
|
||||
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step % 25 == 0:
|
||||
if len(val_dataloader) > 0:
|
||||
print("Validating バリデーション処理...")
|
||||
|
||||
with torch.no_grad():
|
||||
val_dataloader_iter = iter(val_dataloader)
|
||||
batch = next(val_dataloader_iter)
|
||||
is_train = False
|
||||
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990])
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_current": current_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
@@ -912,7 +942,7 @@ class NetworkTrainer:
|
||||
with torch.no_grad():
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
is_train = False
|
||||
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990])
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
@@ -933,6 +963,12 @@ class NetworkTrainer:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
if len(val_dataloader) > 0:
|
||||
if args.logging_dir is not None:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_epoch_average": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
|
||||
Reference in New Issue
Block a user