mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Fix training, validation split, revert to using upstream implemenation
This commit is contained in:
@@ -205,10 +205,10 @@ class NetworkTrainer:
|
||||
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
|
||||
return noise_scheduler
|
||||
|
||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
||||
def encode_images_to_latents(self, args, vae: AutoencoderKL, images: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return vae.encode(images).latent_dist.sample()
|
||||
|
||||
def shift_scale_latents(self, args, latents):
|
||||
def shift_scale_latents(self, args, latents: torch.FloatTensor) -> torch.FloatTensor:
|
||||
return latents * self.vae_scale_factor
|
||||
|
||||
def get_noise_pred_and_target(
|
||||
@@ -280,7 +280,7 @@ class NetworkTrainer:
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
@@ -317,20 +317,21 @@ class NetworkTrainer:
|
||||
|
||||
# endregion
|
||||
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents: torch.Tensor = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
else:
|
||||
# latentに変換
|
||||
latents: torch.Tensor = typing.cast(torch.FloatTensor, typing.cast(AutoencoderKLOutput, vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype))).latent_dist.sample())
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
accelerator.print("NaN found in latents, replacing with zeros")
|
||||
latents = typing.cast(torch.FloatTensor, torch.where(torch.isnan(latents), torch.zeros_like(latents), latents))
|
||||
latents = typing.cast(torch.FloatTensor, latents * self.vae_scale_factor)
|
||||
latents = typing.cast(torch.FloatTensor, torch.nan_to_num(latents, 0, out=latents))
|
||||
|
||||
latents = self.shift_scale_latents(args, latents)
|
||||
|
||||
|
||||
text_encoder_conds = []
|
||||
@@ -384,22 +385,36 @@ class NetworkTrainer:
|
||||
total_loss = torch.zeros((batch_size, 1)).to(latents.device)
|
||||
|
||||
# Use input timesteps_list or use described timesteps above
|
||||
for fixed_timestep in chosen_timesteps_list:
|
||||
fixed_timestep = typing.cast(torch.IntTensor, fixed_timestep)
|
||||
for fixed_timesteps in chosen_timesteps_list:
|
||||
fixed_timesteps = typing.cast(torch.IntTensor, fixed_timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
# and add noise to the latents
|
||||
# with noise offset and/or multires noise if specified
|
||||
noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timestep)
|
||||
noisy_latents = train_util.get_noisy_latents(args, noise, noise_scheduler, latents, fixed_timesteps)
|
||||
|
||||
# ensure the hidden state will require grad
|
||||
if args.gradient_checkpointing:
|
||||
for x in noisy_latents:
|
||||
x.requires_grad_(True)
|
||||
for t in text_encoder_conds:
|
||||
t.requires_grad_(True)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_unet), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents.requires_grad_(train_unet), fixed_timestep, text_encoder_conds, batch, weight_dtype
|
||||
args,
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents.requires_grad_(train_unet),
|
||||
fixed_timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, fixed_timestep)
|
||||
target = noise_scheduler.get_velocity(latents, noise, fixed_timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
@@ -418,7 +433,7 @@ class NetworkTrainer:
|
||||
accelerator,
|
||||
unet,
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
fixed_timesteps,
|
||||
text_encoder_conds,
|
||||
batch,
|
||||
weight_dtype,
|
||||
@@ -427,7 +442,8 @@ class NetworkTrainer:
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, fixed_timesteps, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
loss = loss.mean([1, 2, 3]) # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
@@ -436,14 +452,7 @@ class NetworkTrainer:
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, fixed_timestep, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, fixed_timestep, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, fixed_timestep, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, fixed_timestep, noise_scheduler)
|
||||
loss = self.post_process_loss(loss, args, fixed_timesteps, noise_scheduler)
|
||||
|
||||
total_loss += loss
|
||||
|
||||
@@ -526,8 +535,12 @@ class NetworkTrainer:
|
||||
collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)
|
||||
|
||||
if args.debug_dataset:
|
||||
train_dataset_group.set_current_strategies() # dasaset needs to know the strategies explicitly
|
||||
train_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(train_dataset_group)
|
||||
|
||||
if val_dataset_group is not None:
|
||||
val_dataset_group.set_current_strategies() # dataset needs to know the strategies explicitly
|
||||
train_util.debug_dataset(val_dataset_group)
|
||||
return
|
||||
if len(train_dataset_group) == 0:
|
||||
logger.error(
|
||||
@@ -753,10 +766,6 @@ class NetworkTrainer:
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# Not for sure here.
|
||||
# if val_dataset_group is not None:
|
||||
# val_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
|
||||
# lr schedulerを用意する
|
||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||
|
||||
@@ -1304,7 +1313,7 @@ class NetworkTrainer:
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
@@ -1324,7 +1333,7 @@ class NetworkTrainer:
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
@@ -1384,7 +1393,8 @@ class NetworkTrainer:
|
||||
logs = self.generate_step_logs(
|
||||
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
||||
)
|
||||
accelerator.log(logs, step=global_step)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
# VALIDATION PER STEP
|
||||
should_validate = (args.validation_every_n_step is not None
|
||||
@@ -1401,7 +1411,7 @@ class NetworkTrainer:
|
||||
if val_step >= validation_steps:
|
||||
break
|
||||
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
|
||||
val_progress_bar.update(1)
|
||||
@@ -1409,10 +1419,12 @@ class NetworkTrainer:
|
||||
|
||||
if is_tracking:
|
||||
logs = {"loss/current_val_loss": loss.detach().item()}
|
||||
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
logs = {"loss/average_val_loss": val_loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=global_step)
|
||||
# accelerator.log(logs, step=global_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
@@ -1427,7 +1439,7 @@ class NetworkTrainer:
|
||||
)
|
||||
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False, timesteps_list=[10, 350, 500, 650, 990])
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
@@ -1437,22 +1449,26 @@ class NetworkTrainer:
|
||||
if is_tracking:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_current": current_loss}
|
||||
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||
accelerator.log(logs)
|
||||
|
||||
if is_tracking:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_average": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
# END OF EPOCH
|
||||
if is_tracking:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
if len(val_dataloader) > 0 and is_tracking:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_epoch_average": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
# accelerator.log(logs, step=epoch + 1)
|
||||
accelerator.log(logs)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user