mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support T5XXL LoRA, reduce peak memory usage #1560
This commit is contained in:
@@ -5,7 +5,7 @@ import datetime
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
from typing import Callable, List, Optional, Tuple
|
from typing import Callable, List, Optional
|
||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -13,6 +13,7 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import accelerate
|
import accelerate
|
||||||
|
from transformers import CLIPTextModel
|
||||||
|
|
||||||
from library import device_utils
|
from library import device_utils
|
||||||
from library.device_utils import init_ipex, get_preferred_device
|
from library.device_utils import init_ipex, get_preferred_device
|
||||||
@@ -125,7 +126,7 @@ def do_sample(
|
|||||||
|
|
||||||
def generate_image(
|
def generate_image(
|
||||||
model,
|
model,
|
||||||
clip_l,
|
clip_l: CLIPTextModel,
|
||||||
t5xxl,
|
t5xxl,
|
||||||
ae,
|
ae,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@@ -141,12 +142,13 @@ def generate_image(
|
|||||||
# make first noise with packed shape
|
# make first noise with packed shape
|
||||||
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
# original: b,16,2*h//16,2*w//16, packed: b,h//16*w//16,16*2*2
|
||||||
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
packed_latent_height, packed_latent_width = math.ceil(image_height / 16), math.ceil(image_width / 16)
|
||||||
|
noise_dtype = torch.float32 if is_fp8(dtype) else dtype
|
||||||
noise = torch.randn(
|
noise = torch.randn(
|
||||||
1,
|
1,
|
||||||
packed_latent_height * packed_latent_width,
|
packed_latent_height * packed_latent_width,
|
||||||
16 * 2 * 2,
|
16 * 2 * 2,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=noise_dtype,
|
||||||
generator=torch.Generator(device=device).manual_seed(seed),
|
generator=torch.Generator(device=device).manual_seed(seed),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -166,9 +168,48 @@ def generate_image(
|
|||||||
clip_l = clip_l.to(device)
|
clip_l = clip_l.to(device)
|
||||||
t5xxl = t5xxl.to(device)
|
t5xxl = t5xxl.to(device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if is_fp8(clip_l_dtype) or is_fp8(t5xxl_dtype):
|
if is_fp8(clip_l_dtype):
|
||||||
clip_l.to(clip_l_dtype)
|
param_itr = clip_l.parameters()
|
||||||
t5xxl.to(t5xxl_dtype)
|
param_itr.__next__() # skip first
|
||||||
|
param_2nd = param_itr.__next__()
|
||||||
|
if param_2nd.dtype != clip_l_dtype:
|
||||||
|
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
|
||||||
|
clip_l.to(clip_l_dtype) # fp8
|
||||||
|
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
with accelerator.autocast():
|
||||||
|
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
|
||||||
|
|
||||||
|
if is_fp8(t5xxl_dtype):
|
||||||
|
if flux_utils.get_t5xxl_actual_dtype(t5xxl) != t5xxl_dtype or not hasattr(t5xxl, "fp8_prepared"):
|
||||||
|
logger.info(f"prepare T5xxl for fp8: set to {t5xxl_dtype}")
|
||||||
|
|
||||||
|
def prepare_fp8(text_encoder, target_dtype):
|
||||||
|
def forward_hook(module):
|
||||||
|
def forward(hidden_states):
|
||||||
|
hidden_gelu = module.act(module.wi_0(hidden_states))
|
||||||
|
hidden_linear = module.wi_1(hidden_states)
|
||||||
|
hidden_states = hidden_gelu * hidden_linear
|
||||||
|
hidden_states = module.dropout(hidden_states)
|
||||||
|
|
||||||
|
hidden_states = module.wo(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
return forward
|
||||||
|
|
||||||
|
for module in text_encoder.modules():
|
||||||
|
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
||||||
|
# print("set", module.__class__.__name__, "to", target_dtype)
|
||||||
|
module.to(target_dtype)
|
||||||
|
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
||||||
|
# print("set", module.__class__.__name__, "hooks")
|
||||||
|
module.forward = forward_hook(module)
|
||||||
|
|
||||||
|
text_encoder.fp8_prepared = True
|
||||||
|
|
||||||
|
t5xxl.to(t5xxl_dtype)
|
||||||
|
prepare_fp8(t5xxl.encoder, torch.bfloat16)
|
||||||
|
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
|
||||||
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
|
||||||
@@ -315,10 +356,10 @@ if __name__ == "__main__":
|
|||||||
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
|
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
|
||||||
t5xxl.eval()
|
t5xxl.eval()
|
||||||
|
|
||||||
if is_fp8(clip_l_dtype):
|
# if is_fp8(clip_l_dtype):
|
||||||
clip_l = accelerator.prepare(clip_l)
|
# clip_l = accelerator.prepare(clip_l)
|
||||||
if is_fp8(t5xxl_dtype):
|
# if is_fp8(t5xxl_dtype):
|
||||||
t5xxl = accelerator.prepare(t5xxl)
|
# t5xxl = accelerator.prepare(t5xxl)
|
||||||
|
|
||||||
t5xxl_max_length = 256 if is_schnell else 512
|
t5xxl_max_length = 256 if is_schnell else 512
|
||||||
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
|
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
|
||||||
@@ -329,14 +370,16 @@ if __name__ == "__main__":
|
|||||||
model.eval()
|
model.eval()
|
||||||
logger.info(f"Casting model to {flux_dtype}")
|
logger.info(f"Casting model to {flux_dtype}")
|
||||||
model.to(flux_dtype) # make sure model is dtype
|
model.to(flux_dtype) # make sure model is dtype
|
||||||
if is_fp8(flux_dtype):
|
# if is_fp8(flux_dtype):
|
||||||
model = accelerator.prepare(model)
|
# model = accelerator.prepare(model)
|
||||||
|
# if args.offload:
|
||||||
|
# model = model.to("cpu")
|
||||||
|
|
||||||
# AE
|
# AE
|
||||||
ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
|
ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
|
||||||
ae.eval()
|
ae.eval()
|
||||||
if is_fp8(ae_dtype):
|
# if is_fp8(ae_dtype):
|
||||||
ae = accelerator.prepare(ae)
|
# ae = accelerator.prepare(ae)
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
lora_models: List[lora_flux.LoRANetwork] = []
|
lora_models: List[lora_flux.LoRANetwork] = []
|
||||||
|
|||||||
@@ -392,7 +392,7 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
|
|||||||
modules_dim[lora_name] = dim
|
modules_dim[lora_name] = dim
|
||||||
# logger.info(lora_name, value.size(), dim)
|
# logger.info(lora_name, value.size(), dim)
|
||||||
|
|
||||||
if train_t5xxl is None:
|
if train_t5xxl is None or train_t5xxl is False:
|
||||||
train_t5xxl = "lora_te3" in lora_name
|
train_t5xxl = "lora_te3" in lora_name
|
||||||
|
|
||||||
if train_t5xxl is None:
|
if train_t5xxl is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user