mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Add process_batch for train_network
This commit is contained in:
211
train_network.py
211
train_network.py
@@ -130,6 +130,75 @@ class NetworkTrainer:
|
|||||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||||
|
|
||||||
|
def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
|
||||||
|
with torch.no_grad():
|
||||||
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
latents = batch["latents"].to(accelerator.device)
|
||||||
|
else:
|
||||||
|
# latentに変換
|
||||||
|
latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample()
|
||||||
|
|
||||||
|
# NaNが含まれていれば警告を表示し0に置き換える
|
||||||
|
if torch.any(torch.isnan(latents)):
|
||||||
|
accelerator.print("NaN found in latents, replacing with zeros")
|
||||||
|
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
||||||
|
latents = latents * self.vae_scale_factor
|
||||||
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
||||||
|
# Get the text embedding for conditioning
|
||||||
|
if args.weighted_captions:
|
||||||
|
text_encoder_conds = get_weighted_text_embeddings(
|
||||||
|
tokenizers[0],
|
||||||
|
text_encoders[0],
|
||||||
|
batch["captions"],
|
||||||
|
accelerator.device,
|
||||||
|
args.max_token_length // 75 if args.max_token_length else 1,
|
||||||
|
clip_skip=args.clip_skip,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
text_encoder_conds = self.get_text_cond(
|
||||||
|
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
|
# with noise offset and/or multires noise if specified
|
||||||
|
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
||||||
|
args, noise_scheduler, latents
|
||||||
|
)
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
|
noise_pred = self.call_unet(
|
||||||
|
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.v_parameterization:
|
||||||
|
# v-parameterization training
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
target = noise
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||||
|
loss = loss * loss_weights
|
||||||
|
|
||||||
|
if args.min_snr_gamma:
|
||||||
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
if args.scale_v_pred_loss_like_noise_pred:
|
||||||
|
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||||
|
if args.v_pred_like_loss:
|
||||||
|
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||||
|
if args.debiased_estimation_loss:
|
||||||
|
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||||
|
|
||||||
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def train(self, args):
|
def train(self, args):
|
||||||
session_id = random.randint(0, 2**32)
|
session_id = random.randint(0, 2**32)
|
||||||
training_started_at = time.time()
|
training_started_at = time.time()
|
||||||
@@ -777,71 +846,8 @@ class NetworkTrainer:
|
|||||||
current_step.value = global_step
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(network):
|
||||||
on_step_start(text_encoder, unet)
|
on_step_start(text_encoder, unet)
|
||||||
|
is_train = True
|
||||||
with torch.no_grad():
|
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder)
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
|
||||||
latents = batch["latents"].to(accelerator.device)
|
|
||||||
else:
|
|
||||||
# latentに変換
|
|
||||||
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
|
|
||||||
|
|
||||||
# NaNが含まれていれば警告を表示し0に置き換える
|
|
||||||
if torch.any(torch.isnan(latents)):
|
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
|
||||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
|
||||||
latents = latents * self.vae_scale_factor
|
|
||||||
b_size = latents.shape[0]
|
|
||||||
|
|
||||||
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
|
|
||||||
# Get the text embedding for conditioning
|
|
||||||
if args.weighted_captions:
|
|
||||||
text_encoder_conds = get_weighted_text_embeddings(
|
|
||||||
tokenizer,
|
|
||||||
text_encoder,
|
|
||||||
batch["captions"],
|
|
||||||
accelerator.device,
|
|
||||||
args.max_token_length // 75 if args.max_token_length else 1,
|
|
||||||
clip_skip=args.clip_skip,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
text_encoder_conds = self.get_text_cond(
|
|
||||||
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
|
||||||
# with noise offset and/or multires noise if specified
|
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
|
||||||
args, noise_scheduler, latents
|
|
||||||
)
|
|
||||||
|
|
||||||
# Predict the noise residual
|
|
||||||
with accelerator.autocast():
|
|
||||||
noise_pred = self.call_unet(
|
|
||||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.v_parameterization:
|
|
||||||
# v-parameterization training
|
|
||||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
|
||||||
else:
|
|
||||||
target = noise
|
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
|
||||||
loss = loss.mean([1, 2, 3])
|
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
|
||||||
loss = loss * loss_weights
|
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
|
||||||
if args.v_pred_like_loss:
|
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
|
||||||
if args.debiased_estimation_loss:
|
|
||||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||||
@@ -893,7 +899,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
||||||
accelerator.log(logs)
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@@ -905,80 +911,27 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for val_step, batch in enumerate(val_dataloader):
|
for val_step, batch in enumerate(val_dataloader):
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
is_train = False
|
||||||
latents = batch["latents"].to(accelerator.device)
|
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||||
else:
|
|
||||||
# latentに変換
|
|
||||||
latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample()
|
|
||||||
|
|
||||||
# NaNが含まれていれば警告を表示し0に置き換える
|
|
||||||
if torch.any(torch.isnan(latents)):
|
|
||||||
accelerator.print("NaN found in latents, replacing with zeros")
|
|
||||||
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
|
|
||||||
latents = latents * self.vae_scale_factor
|
|
||||||
b_size = latents.shape[0]
|
|
||||||
|
|
||||||
# Get the text embedding for conditioning
|
|
||||||
if args.weighted_captions:
|
|
||||||
text_encoder_conds = get_weighted_text_embeddings(
|
|
||||||
tokenizer,
|
|
||||||
text_encoder,
|
|
||||||
batch["captions"],
|
|
||||||
accelerator.device,
|
|
||||||
args.max_token_length // 75 if args.max_token_length else 1,
|
|
||||||
clip_skip=args.clip_skip,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
text_encoder_conds = self.get_text_cond(
|
|
||||||
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
|
||||||
# with noise offset and/or multires noise if specified
|
|
||||||
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
|
|
||||||
args, noise_scheduler, latents
|
|
||||||
)
|
|
||||||
|
|
||||||
# Predict the noise residual
|
|
||||||
with accelerator.autocast():
|
|
||||||
noise_pred = self.call_unet(
|
|
||||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.v_parameterization:
|
|
||||||
# v-parameterization training
|
|
||||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
|
||||||
else:
|
|
||||||
target = noise
|
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
|
||||||
loss = loss.mean([1, 2, 3])
|
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
|
||||||
|
|
||||||
loss = loss * loss_weights
|
|
||||||
|
|
||||||
if args.min_snr_gamma:
|
|
||||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
|
||||||
if args.scale_v_pred_loss_like_noise_pred:
|
|
||||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
|
||||||
if args.v_pred_like_loss:
|
|
||||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
|
||||||
|
|
||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||||
|
|
||||||
|
if args.logging_dir is not None:
|
||||||
|
avr_loss: float = val_loss_recorder.moving_average
|
||||||
|
logs = {"loss/validation_current": current_loss}
|
||||||
|
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||||
|
|
||||||
if len(val_dataloader) > 0:
|
if len(val_dataloader) > 0:
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
avr_loss: float = val_loss_recorder.moving_average
|
avr_loss: float = val_loss_recorder.moving_average
|
||||||
logs = {"loss/validation": avr_loss}
|
logs = {"loss/validation_average": avr_loss}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = {"loss/epoch": loss_recorder.moving_average}
|
# logs = {"loss/epoch": loss_recorder.moving_average}
|
||||||
|
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|||||||
Reference in New Issue
Block a user