diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 1bd4cb74..306c0f0f 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -8,6 +8,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() self.vae_scale_factor = sdxl_model_util.VAE_SCALE_FACTOR + self.sampling_warning_showed = False def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args) @@ -153,7 +154,9 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): return noise_pred def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet): - print("sample_images is not implemented") + if not self.sampling_warning_showed: + print("sample_images is not implemented") + self.sampling_warning_showed = True def setup_parser() -> argparse.ArgumentParser: