diff --git a/flux_train_network.py b/flux_train_network.py index daa65c85..59b9d84b 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -23,6 +23,7 @@ logger = logging.getLogger(__name__) class FluxNetworkTrainer(train_network.NetworkTrainer): def __init__(self): super().__init__() + self.sample_prompts_te_outputs = None def assert_extra_args(self, args, train_dataset_group): super().assert_extra_args(args, train_dataset_group) @@ -147,7 +148,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process) # cache sample prompts - self.sample_prompts_te_outputs = None if args.sample_prompts is not None: logger.info(f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}")