mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Remove unnecessary is_train changes and use apply_debiased_estimation to calculate the validation loss. This balances the influence of different timesteps on the reported validation loss (without affecting actual training results).
This commit is contained in:
@@ -135,7 +135,7 @@ class NetworkTrainer:
|
|||||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||||
|
|
||||||
def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
|
def process_val_batch(self, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
|
||||||
|
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
timesteps_list = [10, 350, 500, 650, 990]
|
timesteps_list = [10, 350, 500, 650, 990]
|
||||||
@@ -153,7 +153,7 @@ class NetworkTrainer:
|
|||||||
latents = latents * self.vae_scale_factor
|
latents = latents * self.vae_scale_factor
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
text_encoder_conds = get_weighted_text_embeddings(
|
text_encoder_conds = get_weighted_text_embeddings(
|
||||||
@@ -173,7 +173,7 @@ class NetworkTrainer:
|
|||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
|
|
||||||
for fixed_timesteps in timesteps_list:
|
for fixed_timesteps in timesteps_list:
|
||||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
|
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
|
||||||
@@ -189,6 +189,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
|
|
||||||
@@ -886,7 +887,6 @@ class NetworkTrainer:
|
|||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(training_model):
|
with accelerator.accumulate(training_model):
|
||||||
on_step_start(text_encoder, unet)
|
on_step_start(text_encoder, unet)
|
||||||
is_train = True
|
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||||
else:
|
else:
|
||||||
@@ -911,7 +911,7 @@ class NetworkTrainer:
|
|||||||
# print(f"set multiplier: {multipliers}")
|
# print(f"set multiplier: {multipliers}")
|
||||||
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
accelerator.unwrap_model(network).set_multiplier(multipliers)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
|
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
if args.weighted_captions:
|
if args.weighted_captions:
|
||||||
text_encoder_conds = get_weighted_text_embeddings(
|
text_encoder_conds = get_weighted_text_embeddings(
|
||||||
@@ -941,7 +941,7 @@ class NetworkTrainer:
|
|||||||
t.requires_grad_(True)
|
t.requires_grad_(True)
|
||||||
|
|
||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
with accelerator.autocast():
|
||||||
noise_pred = self.call_unet(
|
noise_pred = self.call_unet(
|
||||||
args,
|
args,
|
||||||
accelerator,
|
accelerator,
|
||||||
@@ -1041,9 +1041,8 @@ class NetworkTrainer:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||||
is_train = False
|
|
||||||
batch = next(cyclic_val_dataloader)
|
batch = next(cyclic_val_dataloader)
|
||||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||||
total_loss += loss.detach().item()
|
total_loss += loss.detach().item()
|
||||||
current_loss = total_loss / validation_steps
|
current_loss = total_loss / validation_steps
|
||||||
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
|
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
|
||||||
|
|||||||
Reference in New Issue
Block a user