Clean up unused code and formatting

rockerBOO
2025-01-06 11:07:47 -05:00
parent f4840ef29e
commit 1c63e7cc49


@@ -318,8 +318,27 @@ class NetworkTrainer:
     # endregion

-    def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor:
+    def process_batch(
+        self,
+        batch,
+        text_encoders,
+        unet,
+        network,
+        vae,
+        noise_scheduler,
+        vae_dtype,
+        weight_dtype,
+        accelerator,
+        args,
+        text_encoding_strategy: strategy_sd.SdTextEncodingStrategy,
+        tokenize_strategy: strategy_sd.SdTokenizeStrategy,
+        is_train=True,
+        train_text_encoder=True,
+        train_unet=True
+    ) -> torch.Tensor:
+        """
+        Process a batch for the network
+        """
         with torch.no_grad():
             if "latents" in batch and batch["latents"] is not None:
                 latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
@@ -334,7 +353,6 @@ class NetworkTrainer:
         latents = self.shift_scale_latents(args, latents)

         text_encoder_conds = []
         text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
         if text_encoder_outputs_list is not None:
@@ -371,13 +389,6 @@ class NetworkTrainer:
                     if encoded_text_encoder_conds[i] is not None:
                         text_encoder_conds[i] = encoded_text_encoder_conds[i]

         batch_size = latents.shape[0]

-        # Predict the noise residual
-        # and add noise to the latents
-        # with noise offset and/or multires noise if specified
-        # sample noise, call unet, get target
         noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
             args,
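
The four comments removed in this hunk summarized what `self.get_noise_pred_and_target` does: sample noise (optionally shaped by noise offset or multires noise), add it to the latents at randomly drawn timesteps, call the UNet, and derive the training target. A minimal, generic sketch of that pattern, assuming diffusers-style `DDPMScheduler` and `UNet2DConditionModel` objects rather than sd-scripts' actual helper:

import torch

def noise_pred_and_target_sketch(latents, unet, noise_scheduler, encoder_hidden_states):
    # Hypothetical stand-in for get_noise_pred_and_target, for illustration only.
    noise = torch.randn_like(latents)  # sample noise (noise offset / multires noise would modify this)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)  # add noise to the latents
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample  # call unet
    # epsilon-prediction target; for v-parameterization it would instead be
    # noise_scheduler.get_velocity(latents, noise, timesteps)
    target = noise
    return noise_pred, target, timesteps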
@@ -1288,7 +1299,23 @@ class NetworkTrainer:
                 # temporary, for batch processing
                 self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)

-                loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
+                loss = self.process_batch(batch,
+                    text_encoders,
+                    unet,
+                    network,
+                    vae,
+                    noise_scheduler,
+                    vae_dtype,
+                    weight_dtype,
+                    accelerator,
+                    args,
+                    text_encoding_strategy,
+                    tokenize_strategy,
+                    is_train=True,
+                    train_text_encoder=train_text_encoder,
+                    train_unet=train_unet
+                )

                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     self.all_reduce_network(accelerator, network)  # sync DDP grad manually
@@ -1366,12 +1393,26 @@ class NetworkTrainer:
                 if val_step >= validation_steps:
                     break

-                loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
+                loss = self.process_batch(
+                    batch,
+                    text_encoders,
+                    unet,
+                    network,
+                    vae,
+                    noise_scheduler,
+                    vae_dtype,
+                    weight_dtype,
+                    accelerator,
+                    args,
+                    text_encoding_strategy,
+                    tokenize_strategy,
+                    is_train=False
+                )

                 val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
                 val_progress_bar.update(1)
                 val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })

                 if is_tracking:
                     logs = {"loss/current_val_loss": loss.detach().item()}
                     # accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
@@ -1397,7 +1438,21 @@ class NetworkTrainer:
                 if val_step >= validation_steps:
                     break

-                loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
+                loss = self.process_batch(
+                    batch,
+                    text_encoders,
+                    unet,
+                    network,
+                    vae,
+                    noise_scheduler,
+                    vae_dtype,
+                    weight_dtype,
+                    accelerator,
+                    args,
+                    text_encoding_strategy,
+                    tokenize_strategy,
+                    is_train=False
+                )

                 current_loss = loss.detach().item()
                 val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
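
Net effect of the call-site changes: `process_batch` no longer takes a `tokenizers` argument (tokenization is reached through `tokenize_strategy`), so every caller passes one argument fewer. A self-contained sketch of the updated calling convention; `DummyTrainer` and the placeholder values are invented for this illustration and are not part of the commit:

import torch

class DummyTrainer:
    # Hypothetical stand-in mirroring the new process_batch signature.
    def process_batch(self, batch, text_encoders, unet, network, vae,
                      noise_scheduler, vae_dtype, weight_dtype, accelerator,
                      args, text_encoding_strategy, tokenize_strategy,
                      is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor:
        # The real method encodes the batch and returns the training loss;
        # a constant tensor keeps this sketch runnable.
        return torch.tensor(0.0)

trainer = DummyTrainer()
# Note: no tokenizers argument anymore.
loss = trainer.process_batch(
    {}, [], None, None, None,
    None, torch.float32, torch.float16, None,
    None, None, None,
    is_train=False,
)
print(loss.item())  # 0.0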