Mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-09 06:45:09 +00:00)
Update train_network.py
This commit is contained in:
@@ -174,7 +174,7 @@ class NetworkTrainer:
|
|||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
|
|
||||||
for fixed_timesteps in timesteps_list:
|
for fixed_timesteps in tqdm(timesteps_list, desc='Training Progress'):
|
||||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
@@ -184,16 +184,16 @@ class NetworkTrainer:
|
|||||||
noise_pred = self.call_unet(
|
noise_pred = self.call_unet(
|
||||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||||
)
|
)
|
||||||
if args.v_parameterization:
|
if args.v_parameterization:
|
||||||
# v-parameterization training
|
# v-parameterization training
|
||||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
else:
|
else:
|
||||||
target = noise
|
target = noise
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
|
|
||||||
average_loss = total_loss / len(timesteps_list)
|
average_loss = total_loss / len(timesteps_list)
|
||||||
return average_loss
|
return average_loss
|
||||||
@@ -985,7 +985,7 @@ class NetworkTrainer:
|
|||||||
if args.validation_every_n_step is not None:
|
if args.validation_every_n_step is not None:
|
||||||
if global_step % (args.validation_every_n_step) == 0:
|
if global_step % (args.validation_every_n_step) == 0:
|
||||||
if len(val_dataloader) > 0:
|
if len(val_dataloader) > 0:
|
||||||
print("Validating バリデーション処理...")
|
print(f"\nValidating バリデーション処理...")
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
||||||
@@ -998,6 +998,8 @@ class NetworkTrainer:
|
|||||||
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
|
logs = {"loss/current_val_loss": current_loss}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
avr_loss: float = val_loss_recorder.moving_average
|
avr_loss: float = val_loss_recorder.moving_average
|
||||||
logs = {"loss/average_val_loss": avr_loss}
|
logs = {"loss/average_val_loss": avr_loss}
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
@@ -1011,7 +1013,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if args.validation_every_n_step is None:
|
if args.validation_every_n_step is None:
|
||||||
if len(val_dataloader) > 0:
|
if len(val_dataloader) > 0:
|
||||||
print("Validating バリデーション処理...")
|
print(f"\nValidating バリデーション処理...")
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
validation_steps = min(args.validation_batches, len(val_dataloader)) if args.validation_batches is not None else len(val_dataloader)
|
||||||
@@ -1025,7 +1027,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
avr_loss: float = val_loss_recorder.moving_average
|
avr_loss: float = val_loss_recorder.moving_average
|
||||||
logs = {"loss/val_epoch_average": avr_loss}
|
logs = {"loss/epoch_val_average": avr_loss}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
Reference in New Issue
Block a user