mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add missing functions for training batch
This commit is contained in:
@@ -1333,6 +1333,11 @@ class NetworkTrainer:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
with accelerator.accumulate(training_model):
|
with accelerator.accumulate(training_model):
|
||||||
|
on_step_start_for_network(text_encoder, unet)
|
||||||
|
|
||||||
|
# temporary, for batch processing
|
||||||
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
|
|
||||||
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
loss = self.process_batch(batch, tokenizers, text_encoders, unet, network, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, text_encoding_strategy, tokenize_strategy, is_train=True, train_text_encoder=train_text_encoder, train_unet=train_unet)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
|
|||||||
Reference in New Issue
Block a user