mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add missing functions for training batch
This commit is contained in:
@@ -1333,6 +1333,11 @@ class NetworkTrainer:
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(training_model):
|
||||
on_step_start_for_network(text_encoder, unet)
|
||||
|
||||
# temporary, for batch processing
|
||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||
|
||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
|
||||
Reference in New Issue
Block a user