support T5XXL LoRA, reduce peak memory usage #1560

This commit is contained in:
Kohya S
2024-09-04 23:15:27 +09:00
parent b7cff0a754
commit 56cb2fc885
2 changed files with 59 additions and 16 deletions

View File

@@ -5,7 +5,7 @@ import datetime
import math
import os
import random
from typing import Callable, List, Optional, Tuple
from typing import Callable, List, Optional
import einops
import numpy as np
@@ -13,6 +13,7 @@ import torch
from tqdm import tqdm
from PIL import Image
import accelerate
from transformers import CLIPTextModel
from library import device_utils
from library.device_utils import init_ipex, get_preferred_device
@@ -125,7 +126,7 @@ def do_sample(
def generate_image(
model,
clip_l,
clip_l: CLIPTextModel,
t5xxl,
ae,
prompt: str,
@@ -141,12 +142,13 @@ def generate_image(
# make first noise with packed shape
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
noise = torch.randn(
1,
packed_latent_height * packed_latent_width,
16 * 2 * 2,
device=device,
dtype=dtype,
dtype=noise_dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
@@ -166,9 +168,48 @@ def generate_image(
clip_l = clip_l.to(device)
t5xxl = t5xxl.to(device)
with torch.no_grad():
if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype):
clip_l.to(clip_l_dtype)
if is_fp8(clip_l_dtype):
param_itr = clip_l.parameters()
param_itr.__next__() # skip first
param_2nd = param_itr.__next__()
if param_2nd.dtype != clip_l_dtype:
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
clip_l.to(clip_l_dtype) # fp8
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
if is_fp8(t5xxl_dtype):
if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
def prepare_fp8(text_encoder, target_dtype):
def forward_hook(module):
def forward(hidden_states):
hidden_gelu = module.act(module.wi_0(hidden_states))
hidden_linear = module.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = module.dropout(hidden_states)
hidden_states = module.wo(hidden_states)
return hidden_states
return forward
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
text_encoder.fp8_prepared = True
t5xxl.to(t5xxl_dtype)
prepare_fp8(t5xxl.encoder, torch.bfloat16)
with accelerator.autocast():
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
@@ -315,10 +356,10 @@ if __name__ == "__main__":
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
t5xxl.eval()
if is_fp8(clip_l_dtype):
clip_l = accelerator.prepare(clip_l)
if is_fp8(t5xxl_dtype):
t5xxl = accelerator.prepare(t5xxl)
# if is_fp8(clip_l_dtype):
# clip_l = accelerator.prepare(clip_l)
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)
t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
@@ -329,14 +370,16 @@ if __name__ == "__main__":
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype
if is_fp8(flux_dtype):
model = accelerator.prepare(model)
# if is_fp8(flux_dtype):
# model = accelerator.prepare(model)
# if args.offload:
# model = model.to("cpu")
# AE
ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
ae.eval()
if is_fp8(ae_dtype):
ae = accelerator.prepare(ae)
# if is_fp8(ae_dtype):
# ae = accelerator.prepare(ae)
# LoRA
lora_models: List[lora_flux.LoRANetwork] = []

View File

@@ -392,7 +392,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
modules_dim[lora_name] = dim
# logger.info(lora_name, value.size(), dim)
if train_t5xxl is None:
if train_t5xxl is None or train_t5xxl is False:
train_t5xxl = "lora_te3" in lora_name
if train_t5xxl is None: