mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add feature to sample images during sdxl training
This commit is contained in:
@@ -8,7 +8,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR
|
||||
self.sampling_warning_showed = False
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
@@ -65,8 +64,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
args, accelerator, tokenizers, text_encoders, data_loader, weight_dtype
|
||||
)
|
||||
accelerator.wait_for_everyone()
|
||||
text_encoders[0].to("cpu")
|
||||
text_encoders[1].to("cpu")
|
||||
text_encoders[0].to("cpu", dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
|
||||
text_encoders[1].to("cpu", dtype=torch.float32)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -149,9 +148,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
||||
return noise_pred
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
if not self.sampling_warning_showed:
|
||||
print("sample_images is not implemented")
|
||||
self.sampling_warning_showed = True
|
||||
sdxl_train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
|
||||
def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
Reference in New Issue
Block a user