add feature to sample images during sdxl training

This commit is contained in:
Kohya S
2023-07-02 16:42:19 +09:00
parent 227a62e4c4
commit 64cf922841
5 changed files with 1402 additions and 32 deletions

View File

@@ -8,7 +8,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
super().__init__()
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
self.sampling_warning_showed = False
def assert_extra_args(self, args, train_dataset_group):
super().assert_extra_args(args, train_dataset_group)
@@ -65,8 +64,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
)
accelerator.wait_for_everyone()
text_encoders[0].to("cpu")
text_encoders[1].to("cpu")
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
text_encoders[1].to("cpu", dtype=torch.float32)
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -149,9 +148,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
return noise_pred
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
if not self.sampling_warning_showed:
print("sample_images is not implemented")
self.sampling_warning_showed = True
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def setup_parser() -> argparse.ArgumentParser: