mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Update train_network.py
This commit is contained in:
@@ -174,7 +174,7 @@ class NetworkTrainer:
|
|||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
|
|
||||||
for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'):
|
for fixed_timesteps in timesteps_list:
|
||||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
@@ -988,7 +988,7 @@ class NetworkTrainer:
|
|||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
||||||
for val_step in range(validation_steps):
|
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||||
is_train = False
|
is_train = False
|
||||||
batch = next(cyclic_val_dataloader)
|
batch = next(cyclic_val_dataloader)
|
||||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||||
@@ -1016,7 +1016,7 @@ class NetworkTrainer:
|
|||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
||||||
for val_step in range(validation_steps):
|
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||||
is_train = False
|
is_train = False
|
||||||
batch = next(cyclic_val_dataloader)
|
batch = next(cyclic_val_dataloader)
|
||||||
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||||
|
|||||||
Reference in New Issue
Block a user