mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add missing functions for training batch
This commit is contained in:
@@ -318,7 +318,7 @@ class NetworkTrainer:
|
||||
# endregion
|
||||
|
||||
def process_batch(self, batch, tokenizers, text_encoders, unet, network, vae: AutoencoderKL, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy: strategy_sd.SdTextEncodingStrategy, tokenize_strategy: strategy_sd.SdTokenizeStrategy, is_train=True, train_text_encoder=True, train_unet=True, timesteps_list: Optional[List[Number]]=None) -> torch.Tensor:
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
@@ -1333,6 +1333,11 @@ class NetworkTrainer:
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start_for_network(text_encoder, unet)
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
Reference in New Issue
Block a user