diff --git a/flux_train_network.py b/flux_train_network.py index a6e57eed..65b121e7 100644 --- a/flux_train_network.py +++ b/flux_train_network.py @@ -300,6 +300,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer): self.flux_lower = flux_lower self.target_device = device + def prepare_block_swap_before_forward(self): + pass + def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None): self.flux_lower.to("cpu") clean_memory_on_device(self.target_device) diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 1d1eb9d2..b3c9184f 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -196,7 +196,6 @@ def sample_image_inference( tokens_and_masks = tokenize_strategy.tokenize(prompt) # strategy has apply_t5_attn_mask option encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks) - print([x.shape if x is not None else None for x in encoded_text_encoder_conds]) # if text_encoder_conds is not cached, use encoded_text_encoder_conds if len(text_encoder_conds) == 0: