support T5XXL LoRA, reduce peak memory usage #1560

Kohya S
2024-09-04 23:15:27 +09:00
parent b7cff0a754
commit 56cb2fc885
2 changed files with 59 additions and 16 deletions

View File

@@ -5,7 +5,7 @@ import datetime
 import math
 import os
 import random
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional

 import einops
 import numpy as np
@@ -13,6 +13,7 @@ import torch
 from tqdm import tqdm
 from PIL import Image
 import accelerate
+from transformers import CLIPTextModel

 from library import device_utils
 from library.device_utils import init_ipex, get_preferred_device
@@ -125,7 +126,7 @@ def do_sample(
 def generate_image(
     model,
-    clip_l,
+    clip_l: CLIPTextModel,
     t5xxl,
     ae,
     prompt: str,
@@ -141,12 +142,13 @@ def generate_image(
     # make first noise with packed shape
     # original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
     packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
+    noise_dtype = torch.float32 if is_fp8(dtype) else dtype
     noise = torch.randn(
         1,
         packed_latent_height * packed_latent_width,
         16 * 2 * 2,
         device=device,
-        dtype=dtype,
+        dtype=noise_dtype,
         generator=torch.Generator(device=device).manual_seed(seed),
     )
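The added noise_dtype guard keeps the initial latent noise in float32 whenever the working dtype is an fp8 format, since torch.randn cannot sample directly into float8 tensors in current PyTorch builds. A minimal sketch of the same guard, with is_fp8 written out here as an assumed stand-in for the script's own helper:

import math
import torch

def is_fp8(dt: torch.dtype) -> bool:
    # assumed stand-in for the script's is_fp8() helper (needs a PyTorch build with float8 dtypes)
    return dt in (torch.float8_e4m3fn, torch.float8_e5m2)

dtype = torch.float8_e4m3fn  # hypothetical working dtype chosen by the user
device = "cuda" if torch.cuda.is_available() else "cpu"
seed = 1
image_height, image_width = 1024, 1024

# packed latent shape: b, h//16 * w//16, 16*2*2 (as in the comment in the hunk above)
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
noise_dtype = torch.float32 if is_fp8(dtype) else dtype  # fp8 -> sample in fp32
noise = torch.randn(
    1,
    packed_latent_height * packed_latent_width,
    16 * 2 * 2,
    device=device,
    dtype=noise_dtype,
    generator=torch.Generator(device=device).manual_seed(seed),
)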
@@ -166,9 +168,48 @@ def generate_image(
     clip_l = clip_l.to(device)
     t5xxl = t5xxl.to(device)

     with torch.no_grad():
-        if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype):
-            clip_l.to(clip_l_dtype)
-            t5xxl.to(t5xxl_dtype)
+        if is_fp8(clip_l_dtype):
+            param_itr = clip_l.parameters()
+            param_itr.__next__()  # skip first
+            param_2nd = param_itr.__next__()
+            if param_2nd.dtype != clip_l_dtype:
+                logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
+                clip_l.to(clip_l_dtype)  # fp8
+                clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
+
+        with accelerator.autocast():
+            l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
+
+        if is_fp8(t5xxl_dtype):
+            if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
+                logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
+
+                def prepare_fp8(text_encoder, target_dtype):
+                    def forward_hook(module):
+                        def forward(hidden_states):
+                            hidden_gelu = module.act(module.wi_0(hidden_states))
+                            hidden_linear = module.wi_1(hidden_states)
+                            hidden_states = hidden_gelu * hidden_linear
+                            hidden_states = module.dropout(hidden_states)
+                            hidden_states = module.wo(hidden_states)
+                            return hidden_states
+
+                        return forward
+
+                    for module in text_encoder.modules():
+                        if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
+                            # print("set", module.__class__.__name__, "to", target_dtype)
+                            module.to(target_dtype)
+                        if module.__class__.__name__ in ["T5DenseGatedActDense"]:
+                            # print("set", module.__class__.__name__, "hooks")
+                            module.forward = forward_hook(module)
+
+                    text_encoder.fp8_prepared = True
+
+                t5xxl.to(t5xxl_dtype)
+                prepare_fp8(t5xxl.encoder, torch.bfloat16)
+
         with accelerator.autocast():
             _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
                 tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
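The hunk above prepares both text encoders for fp8 inference: CLIP-L is cast to fp8 except for its embeddings, which stay in bfloat16, and for T5XXL the T5LayerNorm and Embedding modules are kept in bfloat16 while T5DenseGatedActDense.forward is replaced with a version that does not cast the intermediate activations down to the fp8 weight dtype before the final wo projection. The encoding itself then runs under accelerator.autocast. Below is a self-contained sketch of the same pattern applied to a plain Hugging Face T5 encoder; the model name, the fp8 dtype, and the CPU autocast setup are illustrative assumptions, not part of the script.

import torch
from transformers import T5EncoderModel

def prepare_fp8(encoder: torch.nn.Module, target_dtype: torch.dtype) -> None:
    # Keep numerically sensitive modules in target_dtype (e.g. bfloat16) and patch the
    # gated FFN so activations are not cast to the fp8 weight dtype before the output
    # projection. Module class names follow the Hugging Face T5 implementation.
    def forward_hook(module):
        def forward(hidden_states):
            hidden_gelu = module.act(module.wi_0(hidden_states))
            hidden_linear = module.wi_1(hidden_states)
            hidden_states = hidden_gelu * hidden_linear
            hidden_states = module.dropout(hidden_states)
            hidden_states = module.wo(hidden_states)  # no cast of activations to fp8 here
            return hidden_states

        return forward

    for module in encoder.modules():
        if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
            module.to(target_dtype)  # norms and embeddings stay in bfloat16
        if module.__class__.__name__ in ["T5DenseGatedActDense"]:
            module.forward = forward_hook(module)  # patch the gated feed-forward

# "google/flan-t5-small" is a small gated-act T5 used here only as a stand-in for T5XXL.
t5 = T5EncoderModel.from_pretrained("google/flan-t5-small")
t5.to(torch.float8_e4m3fn)  # assumes a PyTorch build with float8 dtypes (2.1+)
prepare_fp8(t5.encoder, torch.bfloat16)

input_ids = torch.randint(0, t5.config.vocab_size, (1, 16))
with torch.autocast("cpu", dtype=torch.bfloat16):
    out = t5(input_ids=input_ids).last_hidden_state
print(out.shape, out.dtype)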
@@ -315,10 +356,10 @@ if __name__ == "__main__":
     t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
     t5xxl.eval()

-    if is_fp8(clip_l_dtype):
-        clip_l = accelerator.prepare(clip_l)
-    if is_fp8(t5xxl_dtype):
-        t5xxl = accelerator.prepare(t5xxl)
+    # if is_fp8(clip_l_dtype):
+    #     clip_l = accelerator.prepare(clip_l)
+    # if is_fp8(t5xxl_dtype):
+    #     t5xxl = accelerator.prepare(t5xxl)

     t5xxl_max_length = 256 if is_schnell else 512
     tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
@@ -329,14 +370,16 @@ if __name__ == "__main__":
     model.eval()
     logger.info(f"Casting model to {flux_dtype}")
     model.to(flux_dtype)  # make sure model is dtype
-    if is_fp8(flux_dtype):
-        model = accelerator.prepare(model)
+    # if is_fp8(flux_dtype):
+    #     model = accelerator.prepare(model)
+    # if args.offload:
+    #     model = model.to("cpu")

     # AE
     ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
     ae.eval()
-    if is_fp8(ae_dtype):
-        ae = accelerator.prepare(ae)
+    # if is_fp8(ae_dtype):
+    #     ae = accelerator.prepare(ae)

     # LoRA
     lora_models: List[lora_flux.LoRANetwork] = []
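With the fp8 handling now done explicitly inside generate_image, the accelerator.prepare wrapping of the text encoders, the FLUX model, and the AE is commented out. One way to check the "reduce peak memory usage" part of this change is a small helper along the following lines; the helper and its call sites are illustrative only and are not part of the script:

import torch

def report_peak_memory(tag: str) -> None:
    # Hypothetical helper: print and reset the CUDA peak-allocation counter so that
    # before/after comparisons of a change like this one are easy to read.
    if torch.cuda.is_available():
        peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"{tag}: peak allocated {peak_gib:.2f} GiB")
        torch.cuda.reset_peak_memory_stats()

# e.g. call report_peak_memory("after text encoding") and
# report_peak_memory("after denoising") at the points of interest.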
@@ -360,7 +403,7 @@ if __name__ == "__main__":
         lora_model.to(device)

         lora_models.append(lora_model)

     if not args.interactive:
         generate_image(model, clip_l, t5xxl, ae, args.prompt, args.seed, args.width, args.height, args.steps, args.guidance)
     else:

View File

@@ -392,7 +392,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
             modules_dim[lora_name] = dim
             # logger.info(lora_name, value.size(), dim)

-        if train_t5xxl is None:
+        if train_t5xxl is None or train_t5xxl is False:
             train_t5xxl = "lora_te3" in lora_name

     if train_t5xxl is None:
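Previously train_t5xxl was only assigned while it was still None, so the first key in the LoRA file fixed the flag: if that key did not target T5XXL, it stayed False even when later keys did. Allowing False to be upgraded means T5XXL LoRA modules are detected whenever any key uses the lora_te3 prefix. A small stand-alone sketch of the same key-prefix scan follows; the helper name and the lora_te1 prefix for CLIP-L are assumptions based on the sd-scripts naming, and only lora_te3 appears in the diff above:

def detect_trained_text_encoders(weights_sd: dict) -> dict:
    # Hypothetical helper: scan LoRA state-dict keys and report which text encoders
    # the file carries weights for ("lora_te1" assumed for CLIP-L, "lora_te3" for T5XXL).
    flags = {"clip_l": False, "t5xxl": False}
    for key in weights_sd.keys():
        if "." not in key:
            continue
        lora_name = key.split(".")[0]
        if lora_name.startswith("lora_te1"):
            flags["clip_l"] = True
        if lora_name.startswith("lora_te3"):
            flags["t5xxl"] = True
    return flags

# usage sketch:
# from safetensors.torch import load_file
# weights_sd = load_file("my_flux_lora.safetensors")  # hypothetical path
# print(detect_trained_text_encoders(weights_sd))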