diff --git a/flux_minimal_inference.py b/flux_minimal_inference.py
index 56c1b198..1c194e7c 100644
--- a/flux_minimal_inference.py
+++ b/flux_minimal_inference.py
@@ -5,7 +5,7 @@
 import datetime
 import math
 import os
 import random
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional
 import einops
 import numpy as np
@@ -13,6 +13,7 @@
 import torch
 from tqdm import tqdm
 from PIL import Image
 import accelerate
+from transformers import CLIPTextModel
 from library import device_utils
 from library.device_utils import init_ipex, get_preferred_device
@@ -125,7 +126,7 @@ def do_sample(
 
 def generate_image(
     model,
-    clip_l,
+    clip_l: CLIPTextModel,
     t5xxl,
     ae,
     prompt: str,
@@ -141,12 +142,13 @@ def generate_image(
     # make first noise with packed shape
     # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
     packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
+    noise_dtype = torch.float32 if is_fp8(dtype) else dtype
     noise = torch.randn(
         1,
         packed_latent_height * packed_latent_width,
         16 * 2 * 2,
         device=device,
-        dtype=dtype,
+        dtype=noise_dtype,
         generator=torch.Generator(device=device).manual_seed(seed),
     )
 
@@ -166,9 +168,48 @@ def generate_image(
     clip_l = clip_l.to(device)
     t5xxl = t5xxl.to(device)
     with torch.no_grad():
-        if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype):
-            clip_l.to(clip_l_dtype)
-            t5xxl.to(t5xxl_dtype)
+        if is_fp8(clip_l_dtype):
+            param_itr = clip_l.parameters()
+            param_itr.__next__()  # skip first
+            param_2nd = param_itr.__next__()
+            if param_2nd.dtype != clip_l_dtype:
+                logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
+                clip_l.to(clip_l_dtype)  # fp8
+                clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
+
+            with accelerator.autocast():
+                l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
+
+        if is_fp8(t5xxl_dtype):
+            if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
+                logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
+
+                def prepare_fp8(text_encoder, target_dtype):
+                    def forward_hook(module):
+                        def forward(hidden_states):
+                            hidden_gelu = module.act(module.wi_0(hidden_states))
+                            hidden_linear = module.wi_1(hidden_states)
+                            hidden_states = hidden_gelu * hidden_linear
+                            hidden_states = module.dropout(hidden_states)
+
+                            hidden_states = module.wo(hidden_states)
+                            return hidden_states
+
+                        return forward
+
+                    for module in text_encoder.modules():
+                        if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
+                            # print("set", module.__class__.__name__, "to", target_dtype)
+                            module.to(target_dtype)
+                        if module.__class__.__name__ in ["T5DenseGatedActDense"]:
+                            # print("set", module.__class__.__name__, "hooks")
+                            module.forward = forward_hook(module)
+
+                    text_encoder.fp8_prepared = True
+
+                t5xxl.to(t5xxl_dtype)
+                prepare_fp8(t5xxl.encoder, torch.bfloat16)
+
             with accelerator.autocast():
                 _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
                     tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
@@ -315,10 +356,10 @@ if __name__ == "__main__":
     t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
     t5xxl.eval()
 
-    if is_fp8(clip_l_dtype):
-        clip_l = accelerator.prepare(clip_l)
-    if is_fp8(t5xxl_dtype):
-        t5xxl = accelerator.prepare(t5xxl)
+    # if is_fp8(clip_l_dtype):
+    #     clip_l = accelerator.prepare(clip_l)
+    # if is_fp8(t5xxl_dtype):
+    #     t5xxl = accelerator.prepare(t5xxl)
 
     t5xxl_max_length = 256 if is_schnell else 512
     tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
@@ -329,14 +370,16 @@ if __name__ == "__main__":
     model.eval()
     logger.info(f"Casting model to {flux_dtype}")
     model.to(flux_dtype)  # make sure model is dtype
-    if is_fp8(flux_dtype):
-        model = accelerator.prepare(model)
+    # if is_fp8(flux_dtype):
+    #     model = accelerator.prepare(model)
+    # if args.offload:
+    #     model = model.to("cpu")
 
     # AE
     ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
     ae.eval()
-    if is_fp8(ae_dtype):
-        ae = accelerator.prepare(ae)
+    # if is_fp8(ae_dtype):
+    #     ae = accelerator.prepare(ae)
 
     # LoRA
     lora_models: List[lora_flux.LoRANetwork] = []
@@ -360,7 +403,7 @@ if __name__ == "__main__":
             lora_model.to(device)
 
         lora_models.append(lora_model)
-    
+
     if not args.interactive:
         generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
     else:
diff --git a/networks/lora_flux.py b/networks/lora_flux.py
index 295267be..ab9ccc4d 100644
--- a/networks/lora_flux.py
+++ b/networks/lora_flux.py
@@ -392,7 +392,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
             modules_dim[lora_name] = dim
             # logger.info(lora_name, value.size(), dim)
 
-            if train_t5xxl is None:
+            if train_t5xxl is None or train_t5xxl is False:
                 train_t5xxl = "lora_te3" in lora_name
 
     if train_t5xxl is None: