mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 16:39:42 +00:00
Add non_blocking to loading and moving tensors
This commit is contained in:
@@ -232,21 +232,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
logger.info("move vae and unet to cpu to save memory")
|
||||
org_vae_device = vae.device
|
||||
org_unet_device = unet.device
|
||||
vae.to("cpu")
|
||||
unet.to("cpu")
|
||||
vae = vae.to("cpu")
|
||||
unet = unet.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
# When TE is not be trained, it will not be prepared so we need to use explicit autocast
|
||||
logger.info("move text encoders to gpu")
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8
|
||||
text_encoders[1].to(accelerator.device)
|
||||
text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True) # always not fp8
|
||||
text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)
|
||||
|
||||
if text_encoders[1].dtype == torch.float8_e4m3fn:
|
||||
# if we load fp8 weights, the model is already fp8, so we use it as is
|
||||
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
|
||||
else:
|
||||
# otherwise, we need to convert it to target dtype
|
||||
text_encoders[1].to(weight_dtype)
|
||||
text_encoders[1] = text_encoders[1].to(weight_dtype, non_blocking=True)
|
||||
|
||||
with accelerator.autocast():
|
||||
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)
|
||||
@@ -276,19 +276,19 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# move back to cpu
|
||||
if not self.is_train_text_encoder(args):
|
||||
logger.info("move CLIP-L back to cpu")
|
||||
text_encoders[0].to("cpu")
|
||||
text_encoders[0] = text_encoders[0].to("cpu", non_blocking=True)
|
||||
logger.info("move t5XXL back to cpu")
|
||||
text_encoders[1].to("cpu")
|
||||
text_encoders[1] = text_encoders[1].to("cpu", non_blocking=True)
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if not args.lowram:
|
||||
logger.info("move vae and unet back to original device")
|
||||
vae.to(org_vae_device)
|
||||
unet.to(org_unet_device)
|
||||
vae = vae.to(org_vae_device, non_blocking=True)
|
||||
unet = unet.to(org_unet_device, non_blocking=True)
|
||||
else:
|
||||
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
|
||||
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoders[1].to(accelerator.device)
|
||||
text_encoders[0] = text_encoders[0].to(accelerator.device, dtype=weight_dtype, non_blocking=True)
|
||||
text_encoders[1] = text_encoders[1].to(accelerator.device, non_blocking=True)
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
|
||||
text_encoders = text_encoder # for compatibility
|
||||
@@ -429,7 +429,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
noisy_model_input[diff_output_pr_indices],
|
||||
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype, non_blocking=True)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
@@ -468,8 +468,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
||||
if index == 0: # CLIP-L
|
||||
logger.info(f"prepare CLIP-L for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
||||
text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8
|
||||
text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
||||
else: # T5XXL
|
||||
|
||||
def prepare_fp8(text_encoder, target_dtype):
|
||||
@@ -488,7 +488,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
for module in text_encoder.modules():
|
||||
if module.__class__.__name__ in ["T5LayerNorm", "Embedding"]:
|
||||
# print("set", module.__class__.__name__, "to", target_dtype)
|
||||
module.to(target_dtype)
|
||||
module = module.to(target_dtype, non_blocking=True)
|
||||
if module.__class__.__name__ in ["T5DenseGatedActDense"]:
|
||||
# print("set", module.__class__.__name__, "hooks")
|
||||
module.forward = forward_hook(module)
|
||||
@@ -497,7 +497,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
logger.info(f"T5XXL already prepared for fp8")
|
||||
else:
|
||||
logger.info(f"prepare T5XXL for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}, add hooks")
|
||||
text_encoder.to(te_weight_dtype) # fp8
|
||||
text_encoder = text_encoder.to(te_weight_dtype, non_blocking=True) # fp8
|
||||
prepare_fp8(text_encoder, weight_dtype)
|
||||
|
||||
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||
|
||||
@@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
|
||||
# print(
|
||||
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
|
||||
# )
|
||||
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
|
||||
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True)
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
|
||||
@@ -307,7 +307,7 @@ class DiagonalGaussian(nn.Module):
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if self.sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
return mean + std * torch.randn_like(mean, pin_memory=True)
|
||||
else:
|
||||
return mean
|
||||
|
||||
@@ -532,7 +532,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, pin_memory=True) / half).to(t.device, non_blocking=True)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
@@ -600,7 +600,7 @@ class QKNorm(torch.nn.Module):
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
return q.to(v, non_blocking=True), k.to(v, non_blocking=True)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
@@ -997,7 +997,7 @@ class Flux(nn.Module):
|
||||
self.double_blocks = None
|
||||
self.single_blocks = None
|
||||
|
||||
self.to(device)
|
||||
self = self.to(device, non_blocking=True)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.double_blocks = save_double_blocks
|
||||
@@ -1081,8 +1081,8 @@ class Flux(nn.Module):
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
if self.training and self.cpu_offload_checkpointing:
|
||||
img = img.to(self.device)
|
||||
vec = vec.to(self.device)
|
||||
img = img.to(self.device, non_blocking=True)
|
||||
vec = vec.to(self.device, non_blocking=True)
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
@@ -1243,7 +1243,7 @@ class ControlNetFlux(nn.Module):
|
||||
self.double_blocks = nn.ModuleList()
|
||||
self.single_blocks = nn.ModuleList()
|
||||
|
||||
self.to(device)
|
||||
self = self.to(device, non_blocking=True)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.double_blocks = save_double_blocks
|
||||
|
||||
@@ -30,81 +30,171 @@ class SdTokenizeStrategy(TokenizeStrategy):
|
||||
)
|
||||
else:
|
||||
self.tokenizer = self._load_tokenizer(CLIPTokenizer, TOKENIZER_ID, tokenizer_cache_dir=tokenizer_cache_dir)
|
||||
|
||||
if max_length is None:
|
||||
self.max_length = self.tokenizer.model_max_length
|
||||
else:
|
||||
self.max_length = max_length + 2
|
||||
|
||||
|
||||
self.break_separator = "BREAK"
|
||||
|
||||
def _split_on_break(self, text: str) -> List[str]:
|
||||
"""Split text on BREAK separator (case-sensitive), filtering empty segments."""
|
||||
segments = text.split(self.break_separator)
|
||||
# Filter out empty or whitespace-only segments
|
||||
filtered = [seg.strip() for seg in segments if seg.strip()]
|
||||
# Return at least one segment to maintain consistency
|
||||
return filtered if filtered else [""]
|
||||
|
||||
def _tokenize_segments(self, segments: List[str], weighted: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Tokenize multiple segments and concatenate them."""
|
||||
if len(segments) == 1:
|
||||
# No BREAK present, use existing logic
|
||||
if weighted:
|
||||
return self._get_input_ids(self.tokenizer, segments[0], self.max_length, weighted=True)
|
||||
else:
|
||||
tokens = self._get_input_ids(self.tokenizer, segments[0], self.max_length)
|
||||
return tokens, None
|
||||
|
||||
# Multiple segments - tokenize each separately
|
||||
all_tokens = []
|
||||
all_weights = [] if weighted else None
|
||||
|
||||
for segment in segments:
|
||||
if weighted:
|
||||
seg_tokens, seg_weights = self._get_input_ids(self.tokenizer, segment, self.max_length, weighted=True)
|
||||
all_tokens.append(seg_tokens)
|
||||
all_weights.append(seg_weights)
|
||||
else:
|
||||
seg_tokens = self._get_input_ids(self.tokenizer, segment, self.max_length)
|
||||
all_tokens.append(seg_tokens)
|
||||
|
||||
# Concatenate along the sequence dimension (dim=1 for tokens that are [batch, seq_len] or [n_chunks, seq_len])
|
||||
combined_tokens = torch.cat(all_tokens, dim=1) if all_tokens[0].dim() == 2 else torch.cat(all_tokens, dim=0)
|
||||
combined_weights = None
|
||||
if weighted:
|
||||
combined_weights = torch.cat(all_weights, dim=1) if all_weights[0].dim() == 2 else torch.cat(all_weights, dim=0)
|
||||
|
||||
return combined_tokens, combined_weights
|
||||
|
||||
def tokenize(self, text: Union[str, List[str]]) -> List[torch.Tensor]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||
|
||||
|
||||
tokens_list = []
|
||||
for t in text:
|
||||
segments = self._split_on_break(t)
|
||||
tokens, _ = self._tokenize_segments(segments, weighted=False)
|
||||
tokens_list.append(tokens)
|
||||
|
||||
# Pad tokens to same length for stacking
|
||||
max_length = max(t.shape[-1] for t in tokens_list)
|
||||
padded_tokens = []
|
||||
for tokens in tokens_list:
|
||||
if tokens.shape[-1] < max_length:
|
||||
# Pad with pad_token_id
|
||||
pad_size = max_length - tokens.shape[-1]
|
||||
if tokens.dim() == 2:
|
||||
padding = torch.full((tokens.shape[0], pad_size), self.tokenizer.pad_token_id, dtype=tokens.dtype)
|
||||
tokens = torch.cat([tokens, padding], dim=1)
|
||||
else:
|
||||
padding = torch.full((pad_size,), self.tokenizer.pad_token_id, dtype=tokens.dtype)
|
||||
tokens = torch.cat([tokens, padding], dim=0)
|
||||
padded_tokens.append(tokens)
|
||||
|
||||
return [torch.stack(padded_tokens, dim=0)]
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
|
||||
tokens_list = []
|
||||
weights_list = []
|
||||
for t in text:
|
||||
tokens, weights = self._get_input_ids(self.tokenizer, t, self.max_length, weighted=True)
|
||||
segments = self._split_on_break(t)
|
||||
tokens, weights = self._tokenize_segments(segments, weighted=True)
|
||||
tokens_list.append(tokens)
|
||||
weights_list.append(weights)
|
||||
|
||||
return [torch.stack(tokens_list, dim=0)], [torch.stack(weights_list, dim=0)]
|
||||
|
||||
|
||||
class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||
def __init__(self, clip_skip: Optional[int] = None) -> None:
|
||||
self.clip_skip = clip_skip
|
||||
|
||||
|
||||
def _encode_with_clip_skip(self, text_encoder: Any, tokens: torch.Tensor) -> torch.Tensor:
|
||||
"""Encode tokens with optional CLIP skip."""
|
||||
if self.clip_skip is None:
|
||||
return text_encoder(tokens)[0]
|
||||
|
||||
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
|
||||
hidden_states = enc_out["hidden_states"][-self.clip_skip]
|
||||
return text_encoder.text_model.final_layer_norm(hidden_states)
|
||||
|
||||
def _reconstruct_embeddings(self, encoder_hidden_states: torch.Tensor, tokens: torch.Tensor,
|
||||
max_token_length: int, model_max_length: int,
|
||||
tokenizer: Any) -> torch.Tensor:
|
||||
"""Reconstruct embeddings from chunked encoding."""
|
||||
v1 = tokenizer.pad_token_id == tokenizer.eos_token_id
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
|
||||
if not v1:
|
||||
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す
|
||||
for i in range(1, max_token_length, model_max_length):
|
||||
chunk = encoder_hidden_states[:, i : i + model_max_length - 2]
|
||||
if i > 0:
|
||||
for j in range(len(chunk)):
|
||||
if tokens[j, 1] == tokenizer.eos_token:
|
||||
chunk[j, 0] = chunk[j, 1]
|
||||
states_list.append(chunk)
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
|
||||
else:
|
||||
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||
for i in range(1, max_token_length, model_max_length):
|
||||
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2])
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1))
|
||||
|
||||
return torch.cat(states_list, dim=1)
|
||||
|
||||
def _apply_weights_single_chunk(self, encoder_hidden_states: torch.Tensor,
|
||||
weights: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply weights for single chunk case (no max_token_length)."""
|
||||
return encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
|
||||
|
||||
def _apply_weights_multi_chunk(self, encoder_hidden_states: torch.Tensor,
|
||||
weights: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply weights for multi-chunk case (with max_token_length)."""
|
||||
for i in range(weights.shape[1]):
|
||||
start_idx = i * 75 + 1
|
||||
end_idx = i * 75 + 76
|
||||
encoder_hidden_states[:, start_idx:end_idx] = (
|
||||
encoder_hidden_states[:, start_idx:end_idx] * weights[:, i, 1:-1].unsqueeze(-1)
|
||||
)
|
||||
return encoder_hidden_states
|
||||
|
||||
def encode_tokens(
|
||||
self, tokenize_strategy: TokenizeStrategy, models: List[Any], tokens: List[torch.Tensor]
|
||||
) -> List[torch.Tensor]:
|
||||
text_encoder = models[0]
|
||||
tokens = tokens[0]
|
||||
sd_tokenize_strategy = tokenize_strategy # type: SdTokenizeStrategy
|
||||
|
||||
# tokens: b,n,77
|
||||
|
||||
b_size = tokens.size()[0]
|
||||
max_token_length = tokens.size()[1] * tokens.size()[2]
|
||||
model_max_length = sd_tokenize_strategy.tokenizer.model_max_length
|
||||
tokens = tokens.reshape((-1, model_max_length)) # batch_size*3, 77
|
||||
|
||||
|
||||
tokens = tokens.reshape((-1, model_max_length))
|
||||
tokens = tokens.to(text_encoder.device)
|
||||
|
||||
if self.clip_skip is None:
|
||||
encoder_hidden_states = text_encoder(tokens)[0]
|
||||
else:
|
||||
enc_out = text_encoder(tokens, output_hidden_states=True, return_dict=True)
|
||||
encoder_hidden_states = enc_out["hidden_states"][-self.clip_skip]
|
||||
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
|
||||
|
||||
# bs*3, 77, 768 or 1024
|
||||
|
||||
encoder_hidden_states = self._encode_with_clip_skip(text_encoder, tokens)
|
||||
encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
|
||||
|
||||
|
||||
if max_token_length != model_max_length:
|
||||
v1 = sd_tokenize_strategy.tokenizer.pad_token_id == sd_tokenize_strategy.tokenizer.eos_token_id
|
||||
if not v1:
|
||||
# v2: <BOS>...<EOS> <PAD> ... の三連を <BOS>...<EOS> <PAD> ... へ戻す 正直この実装でいいのかわからん
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, model_max_length):
|
||||
chunk = encoder_hidden_states[:, i : i + model_max_length - 2] # <BOS> の後から 最後の前まで
|
||||
if i > 0:
|
||||
for j in range(len(chunk)):
|
||||
if tokens[j, 1] == sd_tokenize_strategy.tokenizer.eos_token:
|
||||
# 空、つまり <BOS> <EOS> <PAD> ...のパターン
|
||||
chunk[j, 0] = chunk[j, 1] # 次の <PAD> の値をコピーする
|
||||
states_list.append(chunk) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS> か <PAD> のどちらか
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
else:
|
||||
# v1: <BOS>...<EOS> の三連を <BOS>...<EOS> へ戻す
|
||||
states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
|
||||
for i in range(1, max_token_length, model_max_length):
|
||||
states_list.append(encoder_hidden_states[:, i : i + model_max_length - 2]) # <BOS> の後から <EOS> の前まで
|
||||
states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
|
||||
encoder_hidden_states = torch.cat(states_list, dim=1)
|
||||
|
||||
encoder_hidden_states = self._reconstruct_embeddings(
|
||||
encoder_hidden_states, tokens, max_token_length,
|
||||
model_max_length, sd_tokenize_strategy.tokenizer
|
||||
)
|
||||
|
||||
return [encoder_hidden_states]
|
||||
|
||||
|
||||
def encode_tokens_with_weights(
|
||||
self,
|
||||
tokenize_strategy: TokenizeStrategy,
|
||||
@@ -113,23 +203,15 @@ class SdTextEncodingStrategy(TextEncodingStrategy):
|
||||
weights_list: List[torch.Tensor],
|
||||
) -> List[torch.Tensor]:
|
||||
encoder_hidden_states = self.encode_tokens(tokenize_strategy, models, tokens_list)[0]
|
||||
|
||||
weights = weights_list[0].to(encoder_hidden_states.device)
|
||||
|
||||
# apply weights
|
||||
if weights.shape[1] == 1: # no max_token_length
|
||||
# weights: ((b, 1, 77), (b, 1, 77)), hidden_states: (b, 77, 768), (b, 77, 768)
|
||||
encoder_hidden_states = encoder_hidden_states * weights.squeeze(1).unsqueeze(2)
|
||||
|
||||
if weights.shape[1] == 1:
|
||||
encoder_hidden_states = self._apply_weights_single_chunk(encoder_hidden_states, weights)
|
||||
else:
|
||||
# weights: ((b, n, 77), (b, n, 77)), hidden_states: (b, n*75+2, 768), (b, n*75+2, 768)
|
||||
for i in range(weights.shape[1]):
|
||||
encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] = encoder_hidden_states[:, i * 75 + 1 : i * 75 + 76] * weights[
|
||||
:, i, 1:-1
|
||||
].unsqueeze(-1)
|
||||
|
||||
encoder_hidden_states = self._apply_weights_multi_chunk(encoder_hidden_states, weights)
|
||||
|
||||
return [encoder_hidden_states]
|
||||
|
||||
|
||||
class SdSdxlLatentsCachingStrategy(LatentsCachingStrategy):
|
||||
# sd and sdxl share the same strategy. we can make them separate, but the difference is only the suffix.
|
||||
# and we keep the old npz for the backward compatibility.
|
||||
|
||||
@@ -4,6 +4,7 @@ import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import nullcontext
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
@@ -26,6 +27,7 @@ import toml
|
||||
|
||||
# from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from torch.cuda import Stream
|
||||
from tqdm import tqdm
|
||||
from packaging.version import Version
|
||||
|
||||
@@ -1415,10 +1417,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return
|
||||
|
||||
# prepare tokenizers and text encoders
|
||||
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
|
||||
text_encoder.to(device)
|
||||
for i, (text_encoder, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)):
|
||||
te_kwargs = {}
|
||||
if te_dtype is not None:
|
||||
text_encoder.to(dtype=te_dtype)
|
||||
te_kwargs['dtype'] = te_dtype
|
||||
text_encoders[i] = text_encoder.to(device, non_blocking=True, **te_dtype)
|
||||
|
||||
# create batch
|
||||
is_sd3 = len(tokenizers) == 1
|
||||
@@ -1440,6 +1443,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# iterate batches: call text encoder and cache outputs for memory or disk
|
||||
logger.info("caching text encoder outputs...")
|
||||
if not is_sd3:
|
||||
@@ -3120,7 +3125,10 @@ def cache_batch_latents(
|
||||
images.append(image)
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
s = Stream()
|
||||
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype, non_blocking=True)
|
||||
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
@@ -3156,12 +3164,13 @@ def cache_batch_latents(
|
||||
if not HIGH_VRAM:
|
||||
clean_memory_on_device(vae.device)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def cache_batch_text_encoder_outputs(
|
||||
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
|
||||
):
|
||||
input_ids1 = input_ids1.to(text_encoders[0].device)
|
||||
input_ids2 = input_ids2.to(text_encoders[1].device)
|
||||
input_ids1 = input_ids1.to(text_encoders[0].device, non_blocking=True)
|
||||
input_ids2 = input_ids2.to(text_encoders[1].device, non_blocking=True)
|
||||
|
||||
with torch.no_grad():
|
||||
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
|
||||
@@ -5619,9 +5628,9 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
)
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
text_encoder = text_encoder.to(accelerator.device, non_blocking=True)
|
||||
unet = unet.to(accelerator.device, non_blocking=True)
|
||||
vae = vae.to(accelerator.device, non_blocking=True)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -6435,7 +6444,7 @@ def sample_images_common(
|
||||
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
||||
|
||||
org_vae_device = vae.device # CPUにいるはず
|
||||
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||
vae = vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
unet = accelerator.unwrap_model(unet_wrapped)
|
||||
@@ -6470,7 +6479,7 @@ def sample_images_common(
|
||||
requires_safety_checker=False,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.to(distributed_state.device)
|
||||
pipeline = pipeline.to(distributed_state.device)
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -6521,7 +6530,7 @@ def sample_images_common(
|
||||
torch.set_rng_state(rng_state)
|
||||
if torch.cuda.is_available() and cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
vae = vae.to(org_vae_device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu")
|
||||
|
||||
stream.synchronize()
|
||||
|
||||
|
||||
@@ -49,11 +49,11 @@ class OFTModule(torch.nn.Module):
|
||||
|
||||
if type(alpha) == torch.Tensor:
|
||||
alpha = alpha.detach().numpy()
|
||||
|
||||
|
||||
# constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility
|
||||
# original alpha is 1e-5, so we use 1e-2 or 1e-4 for alpha
|
||||
self.constraint = alpha * out_dim
|
||||
|
||||
self.constraint = alpha * out_dim
|
||||
|
||||
self.register_buffer("alpha", torch.tensor(alpha))
|
||||
|
||||
self.block_size = out_dim // self.num_blocks
|
||||
|
||||
@@ -239,8 +239,8 @@ def train(args):
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
unet.to(weight_dtype)
|
||||
text_encoder.to(weight_dtype)
|
||||
unet = unet.to(weight_dtype)
|
||||
text_encoder = text_encoder.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if args.deepspeed:
|
||||
@@ -335,6 +335,7 @@ def train(args):
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
optimizer.train()
|
||||
current_step.value = global_step
|
||||
# 指定したステップ数でText Encoderの学習を止める
|
||||
if global_step == args.stop_text_encoder_training:
|
||||
@@ -384,7 +385,7 @@ def train(args):
|
||||
else:
|
||||
target = noise
|
||||
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, latents, noise_scheduler)
|
||||
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
|
||||
@@ -222,8 +222,8 @@ class NetworkTrainer:
|
||||
return not args.network_train_unet_only
|
||||
|
||||
def cache_text_encoder_outputs_if_needed(self, args, accelerator, unet, vae, text_encoders, dataset, weight_dtype):
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
for i, t_enc in enumerate(text_encoders):
|
||||
text_encoders[i] = t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype, **kwargs):
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds[0]).sample
|
||||
@@ -323,7 +323,7 @@ class NetworkTrainer:
|
||||
indices=diff_output_pr_indices,
|
||||
)
|
||||
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
|
||||
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype, non_blocking=True)
|
||||
|
||||
return noise_pred, target, timesteps, None
|
||||
|
||||
@@ -352,7 +352,7 @@ class NetworkTrainer:
|
||||
text_encoder.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
|
||||
text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
||||
text_encoder.text_model.embeddings = text_encoder.text_model.embeddings.to(dtype=weight_dtype)
|
||||
|
||||
def prepare_unet_with_accelerator(
|
||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||
@@ -390,11 +390,11 @@ class NetworkTrainer:
|
||||
"""
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device, non_blocking=True))
|
||||
else:
|
||||
# latentに変換
|
||||
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype, non_blocking=True))
|
||||
else:
|
||||
chunks = [
|
||||
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
|
||||
@@ -402,7 +402,7 @@ class NetworkTrainer:
|
||||
list_latents = []
|
||||
for chunk in chunks:
|
||||
with torch.no_grad():
|
||||
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
|
||||
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype, non_blocking=True))
|
||||
list_latents.append(chunk)
|
||||
latents = torch.cat(list_latents, dim=0)
|
||||
|
||||
@@ -431,14 +431,14 @@ class NetworkTrainer:
|
||||
weights_list,
|
||||
)
|
||||
else:
|
||||
input_ids = [ids.to(accelerator.device) for ids in batch["input_ids_list"]]
|
||||
input_ids = [ids.to(accelerator.device, non_blocking=True) for ids in batch["input_ids_list"]]
|
||||
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
tokenize_strategy,
|
||||
self.get_models_for_text_encoding(args, accelerator, text_encoders),
|
||||
input_ids,
|
||||
)
|
||||
if args.full_fp16:
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds]
|
||||
encoded_text_encoder_conds = [c.to(weight_dtype, non_blocking=True) for c in encoded_text_encoder_conds]
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
@@ -449,6 +449,8 @@ class NetworkTrainer:
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# sample noise, call unet, get target
|
||||
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
|
||||
args,
|
||||
@@ -816,13 +818,13 @@ class NetworkTrainer:
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
accelerator.print("enable full fp16 training.")
|
||||
network.to(weight_dtype)
|
||||
network = network.to(weight_dtype)
|
||||
elif args.full_bf16:
|
||||
assert (
|
||||
args.mixed_precision == "bf16"
|
||||
), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
|
||||
accelerator.print("enable full bf16 training.")
|
||||
network.to(weight_dtype)
|
||||
network = network.to(weight_dtype)
|
||||
|
||||
unet_weight_dtype = te_weight_dtype = weight_dtype
|
||||
# Experimental Feature: Put base model into fp8 to save vram
|
||||
@@ -844,7 +846,7 @@ class NetworkTrainer:
|
||||
# logger.info(f"set U-Net weight dtype to {unet_weight_dtype}, device to {accelerator.device}")
|
||||
# unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
|
||||
logger.info(f"set U-Net weight dtype to {unet_weight_dtype}")
|
||||
unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
|
||||
unet = unet.to(dtype=unet_weight_dtype) # do not move to device because unet is not prepared by accelerator
|
||||
|
||||
unet.requires_grad_(False)
|
||||
if self.cast_unet(args):
|
||||
@@ -858,7 +860,7 @@ class NetworkTrainer:
|
||||
|
||||
# nn.Embedding not support FP8
|
||||
if te_weight_dtype != weight_dtype:
|
||||
self.prepare_text_encoder_fp8(i, t_enc, te_weight_dtype, weight_dtype)
|
||||
self.prepare_text_encoder_fp8(i, text_encoders[i], te_weight_dtype, weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
|
||||
if args.deepspeed:
|
||||
@@ -920,7 +922,7 @@ class NetworkTrainer:
|
||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae = vae.to(accelerator.device, dtype=vae_dtype)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
@@ -1398,6 +1400,8 @@ class NetworkTrainer:
|
||||
torch.cuda.set_rng_state(gpu_rng_state)
|
||||
random.setstate(python_rng_state)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
for epoch in range(epoch_to_start, num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||
current_epoch.value = epoch + 1
|
||||
@@ -1454,6 +1458,12 @@ class NetworkTrainer:
|
||||
if hasattr(network, "update_norms"):
|
||||
network.update_norms()
|
||||
|
||||
torch.cuda.synchronize() # Ensure GPU ops complete before next batch
|
||||
|
||||
# Periodic cleanup
|
||||
if step % 50 == 0:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
Reference in New Issue
Block a user