mirror of https://github.com/kohya-ss/sd-scripts.git (synced 2026-04-08 22:35:09 +00:00)
add sample image generation during training
@@ -232,6 +232,9 @@ class NetworkTrainer:
    def update_metadata(self, metadata, args):
        pass

    def is_text_encoder_not_needed_for_training(self, args):
        return False  # use for sample images

    # endregion

    def train(self, args):
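
The new is_text_encoder_not_needed_for_training hook defaults to False so the text encoder stays loaded for sample image generation. A minimal sketch of how a trainer subclass might override it, assuming a hypothetical cache_text_encoder_outputs attribute on args (the subclass name and flag are illustrative, not part of this commit):

    # Sketch only: MyNetworkTrainer and cache_text_encoder_outputs are assumptions.
    from train_network import NetworkTrainer

    class MyNetworkTrainer(NetworkTrainer):
        def is_text_encoder_not_needed_for_training(self, args):
            # Once text encoder outputs are cached up front, the encoder
            # itself is unused during training steps and can be freed.
            return getattr(args, "cache_text_encoder_outputs", False)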
@@ -529,7 +532,7 @@ class NetworkTrainer:
        # unet.to(accelerator.device)  # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
        # unet.to(dtype=unet_weight_dtype)  # without moving to gpu, this takes a lot of time and main memory

        unet.to(accelerator.device, dtype=unet_weight_dtype)  # this seems to be safer than above
        unet.requires_grad_(False)
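
The two commented-out lines record the tradeoff: calling .to(device) first makes the later dtype cast fast but briefly holds full-precision weights in VRAM, while casting on the CPU is slow and uses main memory. The kept line moves and casts in one call. A standalone PyTorch sketch of the same pattern (the Linear module stands in for the U-Net; sizes are illustrative):

    import torch

    model = torch.nn.Linear(4096, 4096)  # stand-in for the U-Net

    # Move to the accelerator and cast to the training weight dtype in a
    # single call, mirroring unet.to(accelerator.device, dtype=unet_weight_dtype).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device, dtype=torch.float16)
    model.requires_grad_(False)  # freeze, as unet.requires_grad_(False) does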
@@ -989,6 +992,14 @@ class NetworkTrainer:
                accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
                os.remove(old_ckpt_file)

        # if text_encoder is not needed for training, delete it to save memory.
        # TODO this can be automated after SDXL sample prompt cache is implemented
        if self.is_text_encoder_not_needed_for_training(args):
            logger.info("text_encoder is not needed for training. deleting to save memory.")
            for t_enc in text_encoders:
                del t_enc
            text_encoders = []

        # For --sample_at_first
        self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
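
Note that del t_enc only drops the loop variable's reference; rebinding text_encoders = [] is what releases the list's references to the modules. Even then, Python and the CUDA caching allocator reclaim memory lazily, so a caller that needs the VRAM back immediately would typically follow up along these lines (a sketch, not part of this commit):

    import gc
    import torch

    gc.collect()  # collect the now-unreferenced text encoder modules
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver

The final sample_images(...) call runs once at epoch 0 with the current global step, which is what implements --sample_at_first: baseline samples are generated before the first training step.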