Commit by Dave Lage, 2026-04-01 13:10:21 +00:00, committed by GitHub
9 changed files with 77 additions and 57 deletions

View File

@@ -232,21 +232,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
logger.info("move vae and unet to cpu to save memory") logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device org_vae_device = vae.device
org_unet_device = unet.device org_unet_device = unet.device
vae.to("cpu") vae = vae.to("cpu")
unet.to("cpu") unet = unet.to("cpu")
clean_memory_on_device(accelerator.device) clean_memory_on_device(accelerator.device)
# When TE is not be trained, it will not be prepared so we need to use explicit autocast # When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu") logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8 text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True) # always not fp8
text_encoders[1].to(accelerator.device) text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)
if text_encoders[1].dtype == torch.float8_e4m3fn: if text_encoders[1].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is # if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype) self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else: else:
# otherwise, we need to convert it to target dtype # otherwise, we need to convert it to target dtype
text_encoders[1].to(weight_dtype) text_encoders[1] = text_encoders[1].to(weight_dtype, non_blocking=True)
with accelerator.autocast(): with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator) dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
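A note on the pattern above: Tensor.to() always returns a new tensor, which is why the results are assigned back, and non_blocking=True only overlaps the copy with host work when the source is in pinned (page-locked) memory; from pageable memory the copy silently degrades to a synchronous one. A minimal sketch of the host-to-device case, assuming a CUDA device is available:

import torch

x = torch.randn(1024, 1024).pin_memory()   # pinned host memory enables async H2D copies
y = x.to("cuda", non_blocking=True)        # returns a NEW tensor; x is unchanged
torch.cuda.current_stream().synchronize()  # wait before the host relies on the copy being done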
@@ -276,19 +276,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 # move back to cpu
 if not self.is_train_text_encoder(args):
     logger.info("move CLIP-L back to cpu")
-    text_encoders[0].to("cpu")
+    text_encoders[0] = text_encoders[0].to("cpu", non_blocking=True)
     logger.info("move t5XXL back to cpu")
-    text_encoders[1].to("cpu")
+    text_encoders[1] = text_encoders[1].to("cpu", non_blocking=True)
     clean_memory_on_device(accelerator.device)

     if not args.lowram:
         logger.info("move vae and unet back to original device")
-        vae.to(org_vae_device)
-        unet.to(org_unet_device)
+        vae = vae.to(org_vae_device, non_blocking=True)
+        unet = unet.to(org_unet_device, non_blocking=True)
 else:
     # outputs are fetched from the Text Encoders every step, so keep them on the GPU
-    text_encoders[0].to(accelerator.device, dtype=weight_dtype)
-    text_encoders[1].to(accelerator.device)
+    text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True)
+    text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)

 def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
     text_encoders = text_encoder  # for compatibility
@@ -429,7 +429,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     noisy_model_input[diff_output_pr_indices],
     sigmas[diff_output_pr_indices] if sigmas is not None else None,
 )
-target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
+target[diff_output_pr_indices] = model_pred_prior.to(target.dtype, non_blocking=True)

 return model_pred, target, timesteps, weighting
@@ -468,8 +468,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
     if index == 0:  # CLIP-L
         logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
-        text_encoder.to(te_weight_dtype)  # fp8
-        text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+        text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True)  # fp8
+        text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)
     else:  # T5XXL

         def prepare_fp8(text_encoder, target_dtype):
@@ -488,7 +488,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 for module in text_encoder.modules():
     if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
         # print("set", module.__class__.__name__, "to", target_dtype)
-        module.to(target_dtype)
+        module = module.to(target_dtype, non_blocking=True)
     if module.__class__.__name__ in ["T5DenseGatedActDense"]:
         # print("set", module.__class__.__name__, "hooks")
         module.forward = forward_hook(module)
@@ -497,7 +497,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
logger.info(f"T5XXL already prepared for fp8") logger.info(f"T5XXL already prepared for fp8")
else: else:
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks") logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
text_encoder.to(te_weight_dtype) # fp8 text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8
prepare_fp8(text_encoder, weight_dtype) prepare_fp8(text_encoder, weight_dtype)
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
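The reassignments in this file are only strictly required for tensors: nn.Module.to() converts parameters in place and returns self, so text_encoder = text_encoder.to(...) is harmless but redundant for modules, while Tensor.to() is out of place and discarding its result is a real bug. A small illustration:

import torch
import torch.nn as nn

m = nn.Linear(4, 4)
assert m.to(torch.float16) is m      # modules: converted in place, self is returned

t = torch.zeros(4)
assert t.to(torch.float16) is not t  # tensors: a new tensor is returned
assert t.dtype == torch.float32      # the original is untouched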

View File

@@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
 # print(
 #     f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
 # )
-module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
+module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True)

 torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
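The synchronize above is what keeps the preceding non_blocking copies safe: async copies are ordered against kernels on the same stream, but not against the host, so the CPU could otherwise reuse or overwrite a source buffer while its copy is still in flight; that is the "illegal loss value" the comment refers to. A sketch of the hazard and the fix:

import torch

buf = torch.randn(1024).pin_memory()       # pinned host buffer
dst = torch.empty(1024, device="cuda")
dst.copy_(buf, non_blocking=True)          # async while buf is pinned
# unsafe here: buf.fill_(0) could corrupt the in-flight copy
torch.cuda.current_stream().synchronize()  # host waits for the copy to land
buf.fill_(0)                               # now safe to reuse the pinned buffer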

View File

@@ -307,7 +307,7 @@ class DiagonalGaussian(nn.Module):
 mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
 if self.sample:
     std = torch.exp(0.5 * logvar)
-    return mean + std * torch.randn_like(mean)
+    return mean + std * torch.randn_like(mean, pin_memory=True)
 else:
     return mean
@@ -532,7 +532,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
""" """
t = time_factor * t t = time_factor * t
half = dim // 2 half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device) freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, pin_memory=True) / half).to(t.device, non_blocking=True)
args = t[:, None].float() * freqs[None] args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
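One caveat about the two pin_memory=True changes in this file: pinning applies to host (CPU) allocations, so it has no meaning for a tensor that already lives on the GPU (as mean typically does in DiagonalGaussian), and whether torch.randn_like and torch.arange accept the argument at all depends on the PyTorch version; the factory functions that definitely expose it are torch.empty/zeros/ones/randn. The conventional route is a host tensor, an explicit pin, then an async copy; a sketch for the freqs computation:

import math
import torch

half, max_period = 64, 10000
freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half)
freqs = freqs.pin_memory()                   # page-locked copy on the host
freqs = freqs.to("cuda", non_blocking=True)  # async H2D copy from pinned memory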
@@ -600,7 +600,7 @@ class QKNorm(torch.nn.Module):
 def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
     q = self.query_norm(q)
     k = self.key_norm(k)
-    return q.to(v), k.to(v)
+    return q.to(v, non_blocking=True), k.to(v, non_blocking=True)

 class SelfAttention(nn.Module):
@@ -997,7 +997,7 @@ class Flux(nn.Module):
 self.double_blocks = None
 self.single_blocks = None

-self.to(device)
+self = self.to(device, non_blocking=True)

 if self.blocks_to_swap:
     self.double_blocks = save_double_blocks
@@ -1081,8 +1081,8 @@ class Flux(nn.Module):
 img = img[:, txt.shape[1] :, ...]

 if self.training and self.cpu_offload_checkpointing:
-    img = img.to(self.device)
-    vec = vec.to(self.device)
+    img = img.to(self.device, non_blocking=True)
+    vec = vec.to(self.device, non_blocking=True)

 img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
@@ -1243,7 +1243,7 @@ class ControlNetFlux(nn.Module):
 self.double_blocks = nn.ModuleList()
 self.single_blocks = nn.ModuleList()

-self.to(device)
+self = self.to(device, non_blocking=True)

 if self.blocks_to_swap:
     self.double_blocks = save_double_blocks

View File

@@ -41,7 +41,7 @@ class SdTokenizeStrategy(TokenizeStrategy):
     text = [text] if isinstance(text, str) else text
     return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]

-def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
     text = [text] if isinstance(text, str) else text
     tokens_list = []
     weights_list = []
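Worth double-checking in this hunk: in typing terms, Tuple[List[torch.Tensor]] is a one-element tuple, so the new annotation only describes a function returning (tokens_list,); a function that returns both tokens_list and weights_list is a two-element tuple, which is what the old annotation said:

from typing import List, Tuple
import torch

Pair = Tuple[List[torch.Tensor], List[torch.Tensor]]  # (tokens_list, weights_list)
Single = Tuple[List[torch.Tensor]]                    # a 1-tuple: (tokens_list,)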

View File

@@ -4,6 +4,7 @@ import argparse
 import ast
 import asyncio
 from concurrent.futures import Future, ThreadPoolExecutor
+from contextlib import nullcontext
 import datetime
 import importlib
 import json
@@ -26,6 +27,7 @@ import toml
 # from concurrent.futures import ThreadPoolExecutor, as_completed
+from torch.cuda import Stream
 from tqdm import tqdm
 from packaging.version import Version
@@ -1425,10 +1427,11 @@ class BaseDataset(torch.utils.data.Dataset):
 return

 # prepare tokenizers and text encoders
-for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
-    text_encoder.to(device)
+for i, (text_encoder, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)):
+    te_kwargs = {}
     if te_dtype is not None:
-        text_encoder.to(dtype=te_dtype)
+        te_kwargs['dtype'] = te_dtype
+    text_encoders[i] = text_encoder.to(device, non_blocking=True, **te_kwargs)

 # create batch
 is_sd3 = len(tokenizers) == 1
@@ -1450,6 +1453,8 @@ class BaseDataset(torch.utils.data.Dataset):
 if len(batch) > 0:
     batches.append(batch)

+torch.cuda.synchronize()

 # iterate batches: call text encoder and cache outputs for memory or disk
 logger.info("caching text encoder outputs...")
 if not is_sd3:
@@ -3136,7 +3141,10 @@ def cache_batch_latents(
     images.append(image)

 img_tensors = torch.stack(images, dim=0)
-img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
+s = Stream()
+img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype, non_blocking=True)

 with torch.no_grad():
     latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
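As written, s = Stream() is created but never entered, so the copy still runs on the default stream. If the intent is to overlap the transfer with other work, the usual pattern (a sketch, assuming the source tensor is on the CPU) enqueues the copy on the side stream and orders later kernels after it:

import torch

def copy_on_side_stream(img_tensors: torch.Tensor, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    s = torch.cuda.Stream()
    with torch.cuda.stream(s):  # enqueue the copy on the side stream
        out = img_tensors.pin_memory().to(device=device, dtype=dtype, non_blocking=True)
    torch.cuda.current_stream().wait_stream(s)  # order later kernels after the copy
    return out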
@@ -3172,12 +3180,13 @@ def cache_batch_latents(
 if not HIGH_VRAM:
     clean_memory_on_device(vae.device)

+torch.cuda.synchronize()

 def cache_batch_text_encoder_outputs(
     image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
 ):
-    input_ids1 = input_ids1.to(text_encoders[0].device)
-    input_ids2 = input_ids2.to(text_encoders[1].device)
+    input_ids1 = input_ids1.to(text_encoders[0].device, non_blocking=True)
+    input_ids2 = input_ids2.to(text_encoders[1].device, non_blocking=True)

     with torch.no_grad():
         b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
@@ -5652,9 +5661,9 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
 )

 # work on low-ram device
 if args.lowram:
-    text_encoder.to(accelerator.device)
-    unet.to(accelerator.device)
-    vae.to(accelerator.device)
+    text_encoder = text_encoder.to(accelerator.device, non_blocking=True)
+    unet = unet.to(accelerator.device, non_blocking=True)
+    vae = vae.to(accelerator.device, non_blocking=True)
     clean_memory_on_device(accelerator.device)

 accelerator.wait_for_everyone()
@@ -6474,7 +6483,7 @@ def sample_images_common(
 distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here

 org_vae_device = vae.device  # should be on the CPU
-vae.to(distributed_state.device)  # distributed_state.device is same as accelerator.device
+vae = vae.to(distributed_state.device)  # distributed_state.device is same as accelerator.device

 # unwrap unet and text_encoder(s)
 unet = accelerator.unwrap_model(unet_wrapped)
@@ -6509,7 +6518,7 @@ def sample_images_common(
     requires_safety_checker=False,
     clip_skip=args.clip_skip,
 )
-pipeline.to(distributed_state.device)
+pipeline = pipeline.to(distributed_state.device)

 save_dir = args.output_dir + "/sample"
 os.makedirs(save_dir, exist_ok=True)
@@ -6560,7 +6569,7 @@ def sample_images_common(
 torch.set_rng_state(rng_state)
 if torch.cuda.is_available() and cuda_rng_state is not None:
     torch.cuda.set_rng_state(cuda_rng_state)
-vae.to(org_vae_device)
+vae = vae.to(org_vae_device)

 clean_memory_on_device(accelerator.device)

View File

@@ -110,7 +110,7 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
 # cuda to cpu
 for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
     cuda_data_view.record_stream(stream)
-    module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
+    module_to_cpu.weight.data = cuda_data_view.data.to("cpu")

 stream.synchronize()
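Dropping non_blocking=True on this device-to-CPU copy is the conservative choice: a D2H copy is only asynchronous when the destination is pinned host memory, and its result must not be read by the host before synchronizing; record_stream covers the other half of the hazard by telling the caching allocator that the GPU buffer is still in use on the stream. A sketch of the async variant under those assumptions:

import torch

def offload_weight(w_gpu: torch.Tensor, stream: torch.cuda.Stream) -> torch.Tensor:
    w_cpu = torch.empty(w_gpu.shape, dtype=w_gpu.dtype, pin_memory=True)
    with torch.cuda.stream(stream):
        w_cpu.copy_(w_gpu, non_blocking=True)  # async D2H into pinned host memory
        w_gpu.record_stream(stream)            # keep the GPU buffer alive for this stream
    stream.synchronize()                       # host must wait before reading w_cpu
    return w_cpu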

View File

@@ -49,11 +49,11 @@ class OFTModule(torch.nn.Module):
 if type(alpha) == torch.Tensor:
     alpha = alpha.detach().numpy()

-# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
 # original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha
 self.constraint = alpha * out_dim
 self.register_buffer("alpha", torch.tensor(alpha))

 self.block_size = out_dim // self.num_blocks

View File

@@ -239,8 +239,8 @@ def train(args):
     args.mixed_precision == "fp16"
 ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
 accelerator.print("enable full fp16 training.")
-unet.to(weight_dtype)
-text_encoder.to(weight_dtype)
+unet = unet.to(weight_dtype)
+text_encoder = text_encoder.to(weight_dtype)

 # apparently accelerator takes care of this for us
 if args.deepspeed:
@@ -335,6 +335,7 @@ def train(args):
 text_encoder.train()

 for step, batch in enumerate(train_dataloader):
+    optimizer.train()
     current_step.value = global_step

     # stop training the Text Encoder at the specified number of steps
     if global_step == args.stop_text_encoder_training:
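Note that optimizer.train() is not part of the torch.optim.Optimizer interface; it comes from schedule-free optimizers (e.g. the schedulefree package), which need train()/eval() toggles around training and sampling. If the optimizer is not guaranteed to be schedule-free, a guarded call is the safer sketch:

# only schedule-free optimizers expose train()/eval()
if hasattr(optimizer, "train"):
    optimizer.train()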

View File

@@ -205,8 +205,8 @@ class NetworkTrainer:
 return not args.network_train_unet_only

 def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
-    for t_enc in text_encoders:
-        t_enc.to(accelerator.device, dtype=weight_dtype)
+    for i, t_enc in enumerate(text_encoders):
+        text_encoders[i] = t_enc.to(accelerator.device, dtype=weight_dtype)

 def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
     noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
@@ -306,7 +306,7 @@ class NetworkTrainer:
     indices=diff_output_pr_indices,
 )
 network.set_multiplier(1.0)  # may be overwritten by "network_multipliers" in the next step

-target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
+target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype, non_blocking=True)

 return noise_pred, target, timesteps, None
@@ -335,7 +335,7 @@ class NetworkTrainer:
 text_encoder.text_model.embeddings.requires_grad_(True)

 def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
-    text_encoder.text_model.embeddings.to(dtype=weight_dtype)
+    text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)

 def prepare_unet_with_accelerator(
     self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
@@ -373,11 +373,11 @@ class NetworkTrainer:
""" """
with torch.no_grad(): with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: if "latents" in batch and batch["latents"] is not None:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device, non_blocking=True))
else: else:
# latentに変換 # latentに変換
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype, non_blocking=True))
else: else:
chunks = [ chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size) batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
@@ -385,7 +385,7 @@ class NetworkTrainer:
 list_latents = []
 for chunk in chunks:
     with torch.no_grad():
-        chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
+        chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype, non_blocking=True))
     list_latents.append(chunk)

 latents = torch.cat(list_latents, dim=0)
@@ -414,14 +414,14 @@ class NetworkTrainer:
         weights_list,
     )
 else:
-    input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
+    input_ids = [ids.to(accelerator.device, non_blocking=True) for ids in batch["input_ids_list"]]
     encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
         tokenize_strategy,
         self.get_models_for_text_encoding(args, accelerator, text_encoders),
         input_ids,
     )
     if args.full_fp16:
-        encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
+        encoded_text_encoder_conds = [c.to(weight_dtype, non_blocking=True) for c in encoded_text_encoder_conds]

 # if text_encoder_conds is not cached, use encoded_text_encoder_conds
 if len(text_encoder_conds) == 0:
@@ -432,6 +432,8 @@ class NetworkTrainer:
     if encoded_text_encoder_conds[i] is not None:
         text_encoder_conds[i] = encoded_text_encoder_conds[i]

+torch.cuda.synchronize()

 # sample noise, call unet, get target
 noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
     args,
@@ -799,13 +801,13 @@ class NetworkTrainer:
         args.mixed_precision == "fp16"
     ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
     accelerator.print("enable full fp16 training.")
-    network.to(weight_dtype)
+    network = network.to(weight_dtype)
 elif args.full_bf16:
     assert (
         args.mixed_precision == "bf16"
     ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
     accelerator.print("enable full bf16 training.")
-    network.to(weight_dtype)
+    network = network.to(weight_dtype)

 unet_weight_dtype = te_weight_dtype = weight_dtype
 # Experimental Feature: Put base model into fp8 to save vram
@@ -827,7 +829,7 @@ class NetworkTrainer:
 # logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
 # unet.to(accelerator.device, dtype=unet_weight_dtype)  # this seems to be safer than above
 logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
-unet.to(dtype=unet_weight_dtype)  # do not move to device because unet is not prepared by accelerator
+unet = unet.to(dtype=unet_weight_dtype)  # do not move to device because unet is not prepared by accelerator

 unet.requires_grad_(False)
 if self.cast_unet(args):
@@ -841,7 +843,7 @@ class NetworkTrainer:
 # nn.Embedding does not support FP8
 if te_weight_dtype != weight_dtype:
-    self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
+    self.prepare_text_encoder_fp8(i, text_encoders[i], te_weight_dtype, weight_dtype)

 # accelerator will do something good
 if args.deepspeed:
@@ -903,7 +905,7 @@ class NetworkTrainer:
 if not cache_latents:  # when not caching latents, the VAE is used, so prepare it
     vae.requires_grad_(False)
     vae.eval()
-    vae.to(accelerator.device, dtype=vae_dtype)
+    vae = vae.to(accelerator.device, dtype=vae_dtype)

 # experimental feature: train in fp16 including gradients; patch PyTorch to enable grad scaling in fp16
 if args.full_fp16:
@@ -1383,6 +1385,8 @@ class NetworkTrainer:
 torch.cuda.set_rng_state(gpu_rng_state)
 random.setstate(python_rng_state)

+torch.cuda.empty_cache()

 for epoch in range(epoch_to_start, num_train_epochs):
     accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
     current_epoch.value = epoch + 1
@@ -1439,6 +1443,12 @@ class NetworkTrainer:
 if hasattr(network, "update_norms"):
     network.update_norms()

+torch.cuda.synchronize()  # Ensure GPU ops complete before next batch
+
+# Periodic cleanup
+if step % 50 == 0:
+    torch.cuda.empty_cache()

 optimizer.step()
 lr_scheduler.step()
 optimizer.zero_grad(set_to_none=True)
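When profiling this change, keep in mind that torch.cuda.synchronize() stalls the host every step and torch.cuda.empty_cache() hands cached blocks back to the driver, so subsequent allocations pay the cudaMalloc cost again; both trade throughput for predictability. A sketch that keeps the cleanup but makes it opt-in (the environment variable is hypothetical):

import os
import torch

PERIODIC_CUDA_CLEANUP = os.environ.get("PERIODIC_CUDA_CLEANUP", "0") == "1"  # hypothetical flag

if PERIODIC_CUDA_CLEANUP and step % 50 == 0:
    torch.cuda.synchronize()  # let pending kernels finish first
    torch.cuda.empty_cache()  # release cached blocks back to the driver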