add sample image generation during training

This commit is contained in:
Kohya S
2024-08-14 22:15:26 +09:00
parent 56d7651f08
commit 7db4222119
4 changed files with 374 additions and 5 deletions

View File

@@ -232,6 +232,9 @@ class NetworkTrainer:
def update_metadata(self, metadata, args):
    """No-op hook for augmenting the training *metadata* dict.

    The base class adds nothing; presumably architecture-specific
    trainers override this to record extra keys — TODO confirm
    against subclasses.
    """
    return None
def is_text_encoder_not_needed_for_training(self, args):
    """Report whether the text encoder may be discarded during training.

    Always ``False`` in the base class; the original comment notes the
    flag is consulted for sample-image generation, and subclasses are
    expected to override it when the encoder is dispensable.
    """
    return False  # use for sample images
# endregion
def train(self, args):
@@ -529,7 +532,7 @@ class NetworkTrainer:
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
unet.requires_grad_(False)
@@ -989,6 +992,14 @@ class NetworkTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# if text_encoder is not needed for training, delete it to save memory.
# TODO this can be automated after SDXL sample prompt cache is implemented
if self.is_text_encoder_not_needed_for_training(args):
logger.info("text_encoder is not needed for training. deleting to save memory.")
for t_enc in text_encoders:
del t_enc
text_encoders = []
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)