mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Merge 46f9e24b24 into 1dae34b0af
This commit is contained in:
@@ -53,7 +53,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye
|
||||
# print(
|
||||
# f"Module {module_to_cuda_name} not found in CPU model or shape mismatch, so not swapping and moving to device"
|
||||
# )
|
||||
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device)
|
||||
module_to_cuda.weight.data = module_to_cuda.weight.data.to(device, non_blocking=True)
|
||||
|
||||
torch.cuda.current_stream().synchronize() # this prevents the illegal loss value
|
||||
|
||||
|
||||
@@ -307,7 +307,7 @@ class DiagonalGaussian(nn.Module):
|
||||
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
||||
if self.sample:
|
||||
std = torch.exp(0.5 * logvar)
|
||||
return mean + std * torch.randn_like(mean)
|
||||
return mean + std * torch.randn_like(mean, pin_memory=True)
|
||||
else:
|
||||
return mean
|
||||
|
||||
@@ -532,7 +532,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
"""
|
||||
t = time_factor * t
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, pin_memory=True) / half).to(t.device, non_blocking=True)
|
||||
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
@@ -600,7 +600,7 @@ class QKNorm(torch.nn.Module):
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
||||
q = self.query_norm(q)
|
||||
k = self.key_norm(k)
|
||||
return q.to(v), k.to(v)
|
||||
return q.to(v, non_blocking=True), k.to(v, non_blocking=True)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
@@ -997,7 +997,7 @@ class Flux(nn.Module):
|
||||
self.double_blocks = None
|
||||
self.single_blocks = None
|
||||
|
||||
self.to(device)
|
||||
self = self.to(device, non_blocking=True)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.double_blocks = save_double_blocks
|
||||
@@ -1081,8 +1081,8 @@ class Flux(nn.Module):
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
if self.training and self.cpu_offload_checkpointing:
|
||||
img = img.to(self.device)
|
||||
vec = vec.to(self.device)
|
||||
img = img.to(self.device, non_blocking=True)
|
||||
vec = vec.to(self.device, non_blocking=True)
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
@@ -1243,7 +1243,7 @@ class ControlNetFlux(nn.Module):
|
||||
self.double_blocks = nn.ModuleList()
|
||||
self.single_blocks = nn.ModuleList()
|
||||
|
||||
self.to(device)
|
||||
self = self.to(device, non_blocking=True)
|
||||
|
||||
if self.blocks_to_swap:
|
||||
self.double_blocks = save_double_blocks
|
||||
|
||||
@@ -41,7 +41,7 @@ class SdTokenizeStrategy(TokenizeStrategy):
|
||||
text = [text] if isinstance(text, str) else text
|
||||
return [torch.stack([self._get_input_ids(self.tokenizer, t, self.max_length) for t in text], dim=0)]
|
||||
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
def tokenize_with_weights(self, text: str | List[str]) -> Tuple[List[torch.Tensor]]:
|
||||
text = [text] if isinstance(text, str) else text
|
||||
tokens_list = []
|
||||
weights_list = []
|
||||
|
||||
@@ -4,6 +4,7 @@ import argparse
|
||||
import ast
|
||||
import asyncio
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from contextlib import nullcontext
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
@@ -26,6 +27,7 @@ import toml
|
||||
|
||||
# from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from torch.cuda import Stream
|
||||
from tqdm import tqdm
|
||||
from packaging.version import Version
|
||||
|
||||
@@ -1425,10 +1427,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return
|
||||
|
||||
# prepare tokenizers and text encoders
|
||||
for text_encoder, device, te_dtype in zip(text_encoders, devices, te_dtypes):
|
||||
text_encoder.to(device)
|
||||
for i, (text_encoder, device, te_dtype) in enumerate(zip(text_encoders, devices, te_dtypes)):
|
||||
te_kwargs = {}
|
||||
if te_dtype is not None:
|
||||
text_encoder.to(dtype=te_dtype)
|
||||
te_kwargs['dtype'] = te_dtype
|
||||
text_encoders[i] = text_encoder.to(device, non_blocking=True, **te_dtype)
|
||||
|
||||
# create batch
|
||||
is_sd3 = len(tokenizers) == 1
|
||||
@@ -1450,6 +1453,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# iterate batches: call text encoder and cache outputs for memory or disk
|
||||
logger.info("caching text encoder outputs...")
|
||||
if not is_sd3:
|
||||
@@ -3136,7 +3141,10 @@ def cache_batch_latents(
|
||||
images.append(image)
|
||||
|
||||
img_tensors = torch.stack(images, dim=0)
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
|
||||
|
||||
s = Stream()
|
||||
|
||||
img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype, non_blocking=True)
|
||||
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
@@ -3172,12 +3180,13 @@ def cache_batch_latents(
|
||||
if not HIGH_VRAM:
|
||||
clean_memory_on_device(vae.device)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def cache_batch_text_encoder_outputs(
|
||||
image_infos, tokenizers, text_encoders, max_token_length, cache_to_disk, input_ids1, input_ids2, dtype
|
||||
):
|
||||
input_ids1 = input_ids1.to(text_encoders[0].device)
|
||||
input_ids2 = input_ids2.to(text_encoders[1].device)
|
||||
input_ids1 = input_ids1.to(text_encoders[0].device, non_blocking=True)
|
||||
input_ids2 = input_ids2.to(text_encoders[1].device, non_blocking=True)
|
||||
|
||||
with torch.no_grad():
|
||||
b_hidden_state1, b_hidden_state2, b_pool2 = get_hidden_states_sdxl(
|
||||
@@ -5652,9 +5661,9 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
)
|
||||
# work on low-ram device
|
||||
if args.lowram:
|
||||
text_encoder.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
text_encoder = text_encoder.to(accelerator.device, non_blocking=True)
|
||||
unet = unet.to(accelerator.device, non_blocking=True)
|
||||
vae = vae.to(accelerator.device, non_blocking=True)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
@@ -6474,7 +6483,7 @@ def sample_images_common(
|
||||
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
|
||||
|
||||
org_vae_device = vae.device # CPUにいるはず
|
||||
vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||
vae = vae.to(distributed_state.device) # distributed_state.device is same as accelerator.device
|
||||
|
||||
# unwrap unet and text_encoder(s)
|
||||
unet = accelerator.unwrap_model(unet_wrapped)
|
||||
@@ -6509,7 +6518,7 @@ def sample_images_common(
|
||||
requires_safety_checker=False,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
pipeline.to(distributed_state.device)
|
||||
pipeline = pipeline.to(distributed_state.device)
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
@@ -6560,7 +6569,7 @@ def sample_images_common(
|
||||
torch.set_rng_state(rng_state)
|
||||
if torch.cuda.is_available() and cuda_rng_state is not None:
|
||||
torch.cuda.set_rng_state(cuda_rng_state)
|
||||
vae.to(org_vae_device)
|
||||
vae = vae.to(org_vae_device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
|
||||
# cuda to cpu
|
||||
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
|
||||
cuda_data_view.record_stream(stream)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)
|
||||
module_to_cpu.weight.data = cuda_data_view.data.to("cpu")
|
||||
|
||||
stream.synchronize()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user