This commit is contained in:
Dave Lage
2026-04-01 13:10:21 +00:00
committed by GitHub
9 changed files with 77 additions and 57 deletions

View File

@@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
# print(
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
# )
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True)
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

View File

@@ -307,7 +307,7 @@ class DiagonalGaussian(nn.Module):
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
return mean + std * torch.randn_like(mean)
else:
return mean
@@ -532,7 +532,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device, non_blocking=True)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -600,7 +600,7 @@ class QKNorm(torch.nn.Module):
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
return q.to(v, non_blocking=True), k.to(v, non_blocking=True)
class SelfAttention(nn.Module):
@@ -997,7 +997,7 @@ class Flux(nn.Module):
self.double_blocks = None
self.single_blocks = None
self.to(device)
self = self.to(device, non_blocking=True)
if self.blocks_to_swap:
self.double_blocks = save_double_blocks
@@ -1081,8 +1081,8 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...]
if self.training and self.cpu_offload_checkpointing:
img = img.to(self.device)
vec = vec.to(self.device)
img = img.to(self.device, non_blocking=True)
vec = vec.to(self.device, non_blocking=True)
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
@@ -1243,7 +1243,7 @@ class ControlNetFlux(nn.Module):
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()
self.to(device)
self = self.to(device, non_blocking=True)
if self.blocks_to_swap:
self.double_blocks = save_double_blocks

View File

@@ -41,7 +41,7 @@ class SdTokenizeStrategy(TokenizeStrategy):
text = [text] if isinstance(text, str) else text
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
text = [text] if isinstance(text, str) else text
tokens_list = []
weights_list = []

View File

@@ -4,6 +4,7 @@ import argparse
import ast
import asyncio
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import nullcontext
import datetime
import importlib
import json
@@ -26,6 +27,7 @@ import toml
# from concurrent.futures import ThreadPoolExecutor, as_completed
from torch.cuda import Stream
from tqdm import tqdm
from packaging.version import Version
@@ -1425,10 +1427,11 @@ class BaseDataset(torch.utils.data.Dataset):
return
# prepare tokenizers and text encoders
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
text_encoder.to(device)
for i, (text_encoder, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)):
te_kwargs = {}
if te_dtype is not None:
text_encoder.to(dtype=te_dtype)
te_kwargs['dtype'] = te_dtype
text_encoders[i] = text_encoder.to(device, non_blocking=True, **te_kwargs)
# create batch
is_sd3 = len(tokenizers) == 1
@@ -1450,6 +1453,8 @@ class BaseDataset(torch.utils.data.Dataset):
if len(batch) > 0:
batches.append(batch)
torch.cuda.synchronize()
# iterate batches: call text encoder and cache outputs for memory or disk
logger.info("caching text encoder outputs...")
if not is_sd3:
@@ -3136,7 +3141,10 @@ def cache_batch_latents(
images.append(image)
img_tensors = torch.stack(images, dim=0)
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
s = Stream()
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype, non_blocking=True)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
@@ -3172,12 +3180,13 @@ def cache_batch_latents(
if not HIGH_VRAM:
clean_memory_on_device(vae.device)
torch.cuda.synchronize()
def cache_batch_text_encoder_outputs(
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
):
input_ids1 = input_ids1.to(text_encoders[0].device)
input_ids2 = input_ids2.to(text_encoders[1].device)
input_ids1 = input_ids1.to(text_encoders[0].device, non_blocking=True)
input_ids2 = input_ids2.to(text_encoders[1].device, non_blocking=True)
with torch.no_grad():
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
@@ -5652,9 +5661,9 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
)
# work on low-ram device
if args.lowram:
text_encoder.to(accelerator.device)
unet.to(accelerator.device)
vae.to(accelerator.device)
text_encoder = text_encoder.to(accelerator.device, non_blocking=True)
unet = unet.to(accelerator.device, non_blocking=True)
vae = vae.to(accelerator.device, non_blocking=True)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -6474,7 +6483,7 @@ def sample_images_common(
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
org_vae_device = vae.device # CPUにいるはず
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
vae = vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet_wrapped)
@@ -6509,7 +6518,7 @@ def sample_images_common(
requires_safety_checker=False,
clip_skip=args.clip_skip,
)
pipeline.to(distributed_state.device)
pipeline = pipeline.to(distributed_state.device)
save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True)
@@ -6560,7 +6569,7 @@ def sample_images_common(
torch.set_rng_state(rng_state)
if torch.cuda.is_available() and cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device)
vae = vae.to(org_vae_device)
clean_memory_on_device(accelerator.device)

View File

@@ -110,7 +110,7 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.record_stream(stream)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
module_to_cpu.weight.data = cuda_data_view.data.to("cpu")
stream.synchronize()