mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Remove the unnecessary is_train changes and use apply_debiased_estimation when calculating the validation loss. This balances the influence of the different timesteps on the measured training performance (without affecting the actual training results).
This commit is contained in:
@@ -135,7 +135,7 @@ class NetworkTrainer:
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
|
||||
def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
|
||||
|
||||
total_loss = 0.0
|
||||
timesteps_list = [10, 350, 500, 650, 990]
|
||||
@@ -153,7 +153,7 @@ class NetworkTrainer:
|
||||
latents = latents * self.vae_scale_factor
|
||||
b_size = latents.shape[0]
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
text_encoder_conds = get_weighted_text_embeddings(
|
||||
@@ -173,7 +173,7 @@ class NetworkTrainer:
|
||||
# with noise offset and/or multires noise if specified
|
||||
|
||||
for fixed_timesteps in timesteps_list:
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
b_size = latents.shape[0]
|
||||
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
|
||||
@@ -189,6 +189,7 @@ class NetworkTrainer:
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
total_loss += loss
|
||||
|
||||
@@ -885,8 +886,7 @@ class NetworkTrainer:
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start(text_encoder, unet)
|
||||
is_train = True
|
||||
on_step_start(text_encoder, unet)
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
@@ -911,7 +911,7 @@ class NetworkTrainer:
|
||||
# print(f"set multiplier: {multipliers}")
|
||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||
|
||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||
# Get the text embedding for conditioning
|
||||
if args.weighted_captions:
|
||||
text_encoder_conds = get_weighted_text_embeddings(
|
||||
@@ -941,7 +941,7 @@ class NetworkTrainer:
|
||||
t.requires_grad_(True)
|
||||
|
||||
# Predict the noise residual
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
with accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args,
|
||||
accelerator,
|
||||
@@ -1040,10 +1040,9 @@ class NetworkTrainer:
|
||||
total_loss = 0.0
|
||||
with torch.no_grad():
|
||||
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||
is_train = False
|
||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||
batch = next(cyclic_val_dataloader)
|
||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
total_loss += loss.detach().item()
|
||||
current_loss = total_loss / validation_steps
|
||||
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
|
||||
|
||||
Reference in New Issue
Block a user