This commit is contained in:
Dave Lage
2026-04-01 13:10:21 +00:00
committed by GitHub
9 changed files with 77 additions and 57 deletions

View File

@@ -232,21 +232,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
vae = vae.to("cpu")
unet = unet.to("cpu")
clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
text_encoders[1].to(accelerator.device)
text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True) # always not fp8
text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)
if text_encoders[1].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[1].to(weight_dtype)
text_encoders[1] = text_encoders[1].to(weight_dtype, non_blocking=True)
with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
@@ -276,19 +276,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move CLIP-L back to cpu")
text_encoders[0].to("cpu")
text_encoders[0] = text_encoders[0].to("cpu", non_blocking=True)
logger.info("move t5XXL back to cpu")
text_encoders[1].to("cpu")
text_encoders[1] = text_encoders[1].to("cpu", non_blocking=True)
clean_memory_on_device(accelerator.device)
if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
vae = vae.to(org_vae_device, non_blocking=True)
unet = unet.to(org_unet_device, non_blocking=True)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)
text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True)
text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
@@ -429,7 +429,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
noisy_model_input[diff_output_pr_indices],
sigmas[diff_output_pr_indices] if sigmas is not None else None,
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype, non_blocking=True)
return model_pred, target, timesteps, weighting
@@ -468,8 +468,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
if index == 0: # CLIP-L
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
text_encoder.to(te_weight_dtype) # fp8
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8
text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)
else: # T5XXL
def prepare_fp8(text_encoder, target_dtype):
@@ -488,7 +488,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
for module in text_encoder.modules():
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
# print("set", module.__class__.__name__, "to", target_dtype)
module.to(target_dtype)
module = module.to(target_dtype, non_blocking=True)
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
# print("set", module.__class__.__name__, "hooks")
module.forward = forward_hook(module)
@@ -497,7 +497,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
logger.info(f"T5XXL already prepared for fp8")
else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8
text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8
prepare_fp8(text_encoder, weight_dtype)
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):

View File

@@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True)
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

View File

@@ -307,7 +307,7 @@ class DiagonalGaussian(nn.Module):
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
return mean + std * torch.randn_like(mean, pin_memory=True)
else:
return mean
@@ -532,7 +532,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, pin_memory=True) / half).to(t.device, non_blocking=True)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -600,7 +600,7 @@ class QKNorm(torch.nn.Module):
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
return q.to(v, non_blocking=True), k.to(v, non_blocking=True)
class SelfAttention(nn.Module):
@@ -997,7 +997,7 @@ class Flux(nn.Module):
self.double_blocks = None
self.single_blocks = None
self.to(device)
self = self.to(device, non_blocking=True)
if self.blocks_to_swap:
self.double_blocks = save_double_blocks
@@ -1081,8 +1081,8 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
if self.training and self.cpu_offload_checkpointing:
img = img.to(self.device)
vec = vec.to(self.device)
img = img.to(self.device, non_blocking=True)
vec = vec.to(self.device, non_blocking=True)
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
@@ -1243,7 +1243,7 @@ class ControlNetFlux(nn.Module):
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()
self.to(device)
self = self.to(device, non_blocking=True)
if self.blocks_to_swap:
self.double_blocks = save_double_blocks

View File

@@ -41,7 +41,7 @@ class SdTokenizeStrategy(TokenizeStrategy):
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []

View File

@@ -4,6 +4,7 @@ import argparse
import ast
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import nullcontext
import datetime
import importlib
import json
@@ -26,6 +27,7 @@ import toml
# from concurrent.futures import ThreadPoolExecutor, as_completed
from torch.cuda import Stream
from tqdm import tqdm
from packaging.version import Version
@@ -1425,10 +1427,11 @@ class BaseDataset(torch.utils.data.Dataset):
return
# prepare tokenizers and text encoders
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
text_encoder.to(device)
for i, (text_encoder, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)):
te_kwargs = {}
if te_dtype is not None:
text_encoder.to(dtype=te_dtype)
te_kwargs['dtype'] = te_dtype
text_encoders[i] = text_encoder.to(device, non_blocking=True, **te_dtype)
# create batch
is_sd3 = len(tokenizers) == 1
@@ -1450,6 +1453,8 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) > 0:
batches.append(batch)
torch.cuda.synchronize()
# iterate batches: call text encoder and cache outputs for memory or disk
logger.info("caching text encoder outputs...")
if not is_sd3:
@@ -3136,7 +3141,10 @@ def cache_batch_latents(
images.append(image)
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
s = Stream()
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype, non_blocking=True)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
@@ -3172,12 +3180,13 @@ def cache_batch_latents(
if not HIGH_VRAM:
clean_memory_on_device(vae.device)
torch.cuda.synchronize()
def cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
input_ids1 = input_ids1.to(text_encoders[0].device)
input_ids2 = input_ids2.to(text_encoders[1].device)
input_ids1 = input_ids1.to(text_encoders[0].device, non_blocking=True)
input_ids2 = input_ids2.to(text_encoders[1].device, non_blocking=True)
with torch.no_grad():
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
@@ -5652,9 +5661,9 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
text_encoder = text_encoder.to(accelerator.device, non_blocking=True)
unet = unet.to(accelerator.device, non_blocking=True)
vae = vae.to(accelerator.device, non_blocking=True)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -6474,7 +6483,7 @@ def sample_images_common(
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
org_vae_device = vae.device # CPUにいるはず
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
vae = vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet_wrapped)
@@ -6509,7 +6518,7 @@ def sample_images_common(
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.to(distributed_state.device)
pipeline = pipeline.to(distributed_state.device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
@@ -6560,7 +6569,7 @@ def sample_images_common(
torch.set_rng_state(rng_state)
if torch.cuda.is_available() and cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
vae = vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)

View File

@@ -110,7 +110,7 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu")
stream.synchronize()

View File

@@ -49,11 +49,11 @@ class OFTModule(torch.nn.Module):
if type(alpha) == torch.Tensor:
alpha = alpha.detach().numpy()
# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
# original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha
self.constraint = alpha * out_dim
self.constraint = alpha * out_dim
self.register_buffer("alpha", torch.tensor(alpha))
self.block_size = out_dim // self.num_blocks

View File

@@ -239,8 +239,8 @@ def train(args):
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
unet.to(weight_dtype)
text_encoder.to(weight_dtype)
unet = unet.to(weight_dtype)
text_encoder = text_encoder.to(weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい
if args.deepspeed:
@@ -335,6 +335,7 @@ def train(args):
text_encoder.train()
for step, batch in enumerate(train_dataloader):
optimizer.train()
current_step.value = global_step
# 指定したステップ数でText Encoderの学習を止める
if global_step == args.stop_text_encoder_training:

View File

@@ -205,8 +205,8 @@ class NetworkTrainer:
return not args.network_train_unet_only
def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
for i, t_enc in enumerate(text_encoders):
text_encoders[i] = t_enc.to(accelerator.device, dtype=weight_dtype)
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
@@ -306,7 +306,7 @@ class NetworkTrainer:
indices=diff_output_pr_indices,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype, non_blocking=True)
return noise_pred, target, timesteps, None
@@ -335,7 +335,7 @@ class NetworkTrainer:
text_encoder.text_model.embeddings.requires_grad_(True)
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)
def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
@@ -373,11 +373,11 @@ class NetworkTrainer:
"""
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device, non_blocking=True))
else:
# latentに変換
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype, non_blocking=True))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
@@ -385,7 +385,7 @@ class NetworkTrainer:
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype, non_blocking=True))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)
@@ -414,14 +414,14 @@ class NetworkTrainer:
weights_list,
)
else:
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
input_ids = [ids.to(accelerator.device, non_blocking=True) for ids in batch["input_ids_list"]]
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
input_ids,
)
if args.full_fp16:
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
encoded_text_encoder_conds = [c.to(weight_dtype, non_blocking=True) for c in encoded_text_encoder_conds]
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
@@ -432,6 +432,8 @@ class NetworkTrainer:
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
torch.cuda.synchronize()
# sample noise, call unet, get target
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
args,
@@ -799,13 +801,13 @@ class NetworkTrainer:
args.mixed_precision == "fp16"
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
accelerator.print("enable full fp16 training.")
network.to(weight_dtype)
network = network.to(weight_dtype)
elif args.full_bf16:
assert (
args.mixed_precision == "bf16"
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)
network = network.to(weight_dtype)
unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
@@ -827,7 +829,7 @@ class NetworkTrainer:
# logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
# unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
unet = unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
unet.requires_grad_(False)
if self.cast_unet(args):
@@ -841,7 +843,7 @@ class NetworkTrainer:
# nn.Embedding not support FP8
if te_weight_dtype != weight_dtype:
self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
self.prepare_text_encoder_fp8(i, text_encoders[i], te_weight_dtype, weight_dtype)
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if args.deepspeed:
@@ -903,7 +905,7 @@ class NetworkTrainer:
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
vae.eval()
vae.to(accelerator.device, dtype=vae_dtype)
vae = vae.to(accelerator.device, dtype=vae_dtype)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
@@ -1383,6 +1385,8 @@ class NetworkTrainer:
torch.cuda.set_rng_state(gpu_rng_state)
random.setstate(python_rng_state)
torch.cuda.empty_cache()
for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
current_epoch.value = epoch + 1
@@ -1439,6 +1443,12 @@ class NetworkTrainer:
if hasattr(network, "update_norms"):
network.update_norms()
torch.cuda.synchronize() # Ensure GPU ops complete before next batch
# Periodic cleanup
if step % 50 == 0:
torch.cuda.empty_cache()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)