Add process_batch for train_network

This commit is contained in:
rockerBOO
2023-11-05 16:27:36 -05:00
parent a93c524b3a
commit c89252101e

View File

@@ -130,6 +130,75 @@ class NetworkTrainer:
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet) train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample()
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
tokenizers[0],
text_encoders[0],
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with torch.set_grad_enabled(is_train), accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
return loss
def train(self, args): def train(self, args):
session_id = random.randint(0, 2**32) session_id = random.randint(0, 2**32)
training_started_at = time.time() training_started_at = time.time()
@@ -777,71 +846,8 @@ class NetworkTrainer:
current_step.value = global_step current_step.value = global_step
with accelerator.accumulate(network): with accelerator.accumulate(network):
on_step_start(text_encoder, unet) on_step_start(text_encoder, unet)
is_train = True
with torch.no_grad(): loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder)
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"] # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
if args.debiased_estimation_loss:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss) accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0: if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -893,7 +899,7 @@ class NetworkTrainer:
if args.logging_dir is not None: if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs) accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
@@ -905,80 +911,27 @@ class NetworkTrainer:
with torch.no_grad(): with torch.no_grad():
for val_step, batch in enumerate(val_dataloader): for val_step, batch in enumerate(val_dataloader):
if "latents" in batch and batch["latents"] is not None: is_train = False
latents = batch["latents"].to(accelerator.device) loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
else:
# latentに変換
latents = vae.encode(batch["images"].to(device=accelerator.device, dtype=vae_dtype)).latent_dist.sample()
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
# Predict the noise residual
with accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
loss = loss * loss_weights
if args.min_snr_gamma:
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
if args.scale_v_pred_loss_like_noise_pred:
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
if args.v_pred_like_loss:
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
current_loss = loss.detach().item() current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss}
accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
if len(val_dataloader) > 0: if len(val_dataloader) > 0:
if args.logging_dir is not None: if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation": avr_loss} logs = {"loss/validation_average": avr_loss}
accelerator.log(logs, step=epoch + 1) accelerator.log(logs, step=epoch + 1)
if args.logging_dir is not None: if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average} # logs = {"loss/epoch": loss_recorder.moving_average}
logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1) accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()