mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Cleanup unused code and formatting
This commit is contained in:
@@ -318,8 +318,27 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True) -> torch.Tensor:
|
def process_batch(
|
||||||
|
self,
|
||||||
|
batch,
|
||||||
|
text_encoders,
|
||||||
|
unet,
|
||||||
|
network,
|
||||||
|
vae,
|
||||||
|
noise_scheduler,
|
||||||
|
vae_dtype,
|
||||||
|
weight_dtype,
|
||||||
|
accelerator,
|
||||||
|
args,
|
||||||
|
text_encoding_strategy: strategy_sd.SdTextEncodingStrategy,
|
||||||
|
tokenize_strategy: strategy_sd.SdTokenizeStrategy,
|
||||||
|
is_train=True,
|
||||||
|
train_text_encoder=True,
|
||||||
|
train_unet=True
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Process a batch for the network
|
||||||
|
"""
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||||
@@ -334,7 +353,6 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
latents = self.shift_scale_latents(args, latents)
|
latents = self.shift_scale_latents(args, latents)
|
||||||
|
|
||||||
|
|
||||||
text_encoder_conds = []
|
text_encoder_conds = []
|
||||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
if text_encoder_outputs_list is not None:
|
if text_encoder_outputs_list is not None:
|
||||||
@@ -371,13 +389,6 @@ class NetworkTrainer:
|
|||||||
if encoded_text_encoder_conds[i] is not None:
|
if encoded_text_encoder_conds[i] is not None:
|
||||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||||
|
|
||||||
batch_size = latents.shape[0]
|
|
||||||
|
|
||||||
|
|
||||||
# Predict the noise residual
|
|
||||||
# and add noise to the latents
|
|
||||||
# with noise offset and/or multires noise if specified
|
|
||||||
|
|
||||||
# sample noise, call unet, get target
|
# sample noise, call unet, get target
|
||||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||||
args,
|
args,
|
||||||
@@ -1288,7 +1299,23 @@ class NetworkTrainer:
|
|||||||
# temporary, for batch processing
|
# temporary, for batch processing
|
||||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
|
|
||||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
loss = self.process_batch(batch,
|
||||||
|
text_encoders,
|
||||||
|
unet,
|
||||||
|
network,
|
||||||
|
vae,
|
||||||
|
noise_scheduler,
|
||||||
|
vae_dtype,
|
||||||
|
weight_dtype,
|
||||||
|
accelerator,
|
||||||
|
args,
|
||||||
|
text_encoding_strategy,
|
||||||
|
tokenize_strategy,
|
||||||
|
is_train=True,
|
||||||
|
train_text_encoder=train_text_encoder,
|
||||||
|
train_unet=train_unet
|
||||||
|
)
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||||
@@ -1366,12 +1393,26 @@ class NetworkTrainer:
|
|||||||
if val_step >= validation_steps:
|
if val_step >= validation_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
|
loss = self.process_batch(
|
||||||
|
batch,
|
||||||
|
text_encoders,
|
||||||
|
unet,
|
||||||
|
network,
|
||||||
|
vae,
|
||||||
|
noise_scheduler,
|
||||||
|
vae_dtype,
|
||||||
|
weight_dtype,
|
||||||
|
accelerator,
|
||||||
|
args,
|
||||||
|
text_encoding_strategy,
|
||||||
|
tokenize_strategy,
|
||||||
|
is_train=False
|
||||||
|
)
|
||||||
|
|
||||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
|
val_loss_recorder.add(epoch=epoch, step=val_step, loss=loss.detach().item())
|
||||||
val_progress_bar.update(1)
|
val_progress_bar.update(1)
|
||||||
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
|
val_progress_bar.set_postfix({ "val_avg_loss": val_loss_recorder.moving_average })
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {"loss/current_val_loss": loss.detach().item()}
|
logs = {"loss/current_val_loss": loss.detach().item()}
|
||||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
||||||
@@ -1397,7 +1438,21 @@ class NetworkTrainer:
|
|||||||
if val_step >= validation_steps:
|
if val_step >= validation_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=False)
|
loss = self.process_batch(
|
||||||
|
batch,
|
||||||
|
text_encoders,
|
||||||
|
unet,
|
||||||
|
network,
|
||||||
|
vae,
|
||||||
|
noise_scheduler,
|
||||||
|
vae_dtype,
|
||||||
|
weight_dtype,
|
||||||
|
accelerator,
|
||||||
|
args,
|
||||||
|
text_encoding_strategy,
|
||||||
|
tokenize_strategy,
|
||||||
|
is_train=False
|
||||||
|
)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||||
|
|||||||
Reference in New Issue
Block a user