mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
fix timesteps
This commit is contained in:
@@ -141,7 +141,6 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
timesteps_list = [10, 350, 500, 650, 990]
|
timesteps_list = [10, 350, 500, 650, 990]
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
@@ -174,16 +173,17 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps(
|
|
||||||
args, noise_scheduler, latents
|
for fixed_timesteps in timesteps_list:
|
||||||
)
|
|
||||||
for timesteps in timesteps_list:
|
|
||||||
# Predict the noise residual
|
|
||||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
|
b_size = latents.shape[0]
|
||||||
|
timesteps = torch.randint(fixed_timesteps, fixed_timesteps, (b_size,), device=latents.device)
|
||||||
|
timesteps = timesteps.long()
|
||||||
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
noise_pred = self.call_unet(
|
noise_pred = self.call_unet(
|
||||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
# v-parameterization training
|
# v-parameterization training
|
||||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
@@ -988,7 +988,7 @@ class NetworkTrainer:
|
|||||||
print("Validating バリデーション処理...")
|
print("Validating バリデーション処理...")
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
|
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
||||||
for val_step in range(validation_steps):
|
for val_step in range(validation_steps):
|
||||||
is_train = False
|
is_train = False
|
||||||
batch = next(cyclic_val_dataloader)
|
batch = next(cyclic_val_dataloader)
|
||||||
@@ -999,7 +999,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
avr_loss: float = val_loss_recorder.moving_average
|
avr_loss: float = val_loss_recorder.moving_average
|
||||||
logs = {"loss/avr_val_loss": avr_loss}
|
logs = {"loss/average_val_loss": avr_loss}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
@@ -1014,7 +1014,7 @@ class NetworkTrainer:
|
|||||||
print("Validating バリデーション処理...")
|
print("Validating バリデーション処理...")
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = args.validation_batches if args.validation_batches is not None else len(val_dataloader)
|
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
||||||
for val_step in range(validation_steps):
|
for val_step in range(validation_steps):
|
||||||
is_train = False
|
is_train = False
|
||||||
batch = next(cyclic_val_dataloader)
|
batch = next(cyclic_val_dataloader)
|
||||||
|
|||||||
Reference in New Issue
Block a user