Merge branch 'dev' into gradual_latent_hires_fix

This commit is contained in:
Kohya S
2024-02-12 11:26:36 +09:00
6 changed files with 207 additions and 123 deletions

View File

@@ -125,9 +125,13 @@ def ipex_init(): # pylint: disable=too-many-statements
# AMP: # AMP:
torch.cuda.amp = torch.xpu.amp torch.cuda.amp = torch.xpu.amp
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
if not hasattr(torch.cuda.amp, "common"): if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext() torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False torch.cuda.amp.common.amp_definitely_not_available = lambda: False
try: try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
@@ -151,15 +155,16 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.has_half = True torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.version.cuda = "11.7" torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.cuda.get_device_capability = lambda *args, **kwargs: [11,7] torch.version.cuda = "12.1"
torch.cuda.get_device_properties.major = 11 torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
torch.cuda.get_device_properties.minor = 7 torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0 torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks() ipex_hijacks()
if not torch.xpu.has_fp64_dtype(): if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try: try:
from .diffusers import ipex_diffusers from .diffusers import ipex_diffusers
ipex_diffusers() ipex_diffusers()

View File

@@ -124,6 +124,7 @@ def torch_bmm_32_bit(input, mat2, *, out=None):
) )
else: else:
return original_torch_bmm(input, mat2, out=out) return original_torch_bmm(input, mat2, out=out)
torch.xpu.synchronize(input.device)
return hidden_states return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
@@ -172,4 +173,5 @@ def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropo
) )
else: else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
torch.xpu.synchronize(query.device)
return hidden_states return hidden_states

View File

@@ -149,6 +149,7 @@ class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice del attn_slice
torch.xpu.synchronize(query.device)
else: else:
query_slice = query[start_idx:end_idx] query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx] key_slice = key[start_idx:end_idx]
@@ -283,6 +284,7 @@ class AttnProcessor:
hidden_states[start_idx:end_idx] = attn_slice hidden_states[start_idx:end_idx] = attn_slice
del attn_slice del attn_slice
torch.xpu.synchronize(query.device)
else: else:
attention_probs = attn.get_attention_scores(query, key, attention_mask) attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value) hidden_states = torch.bmm(attention_probs, value)

View File

@@ -1,6 +1,11 @@
import contextlib import os
from functools import wraps
from contextlib import nullcontext
import torch import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype()
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
@@ -11,7 +16,7 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
return module.to("xpu") return module.to("xpu")
def return_null_context(*args, **kwargs): # pylint: disable=unused-argument def return_null_context(*args, **kwargs): # pylint: disable=unused-argument
return contextlib.nullcontext() return nullcontext()
@property @property
def is_cuda(self): def is_cuda(self):
@@ -25,15 +30,17 @@ def return_xpu(device):
# Autocast # Autocast
original_autocast = torch.autocast original_autocast_init = torch.amp.autocast_mode.autocast.__init__
def ipex_autocast(*args, **kwargs): @wraps(torch.amp.autocast_mode.autocast.__init__)
if len(args) > 0 and args[0] == "cuda": def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
return original_autocast("xpu", *args[1:], **kwargs) if device_type == "cuda":
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
else: else:
return original_autocast(*args, **kwargs) return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
# Latent Antialias CPU Offload: # Latent Antialias CPU Offload:
original_interpolate = torch.nn.functional.interpolate original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None: if antialias or align_corners is not None:
return_device = tensor.device return_device = tensor.device
@@ -44,15 +51,29 @@ def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corn
return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode, return original_interpolate(tensor, size=size, scale_factor=scale_factor, mode=mode,
align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias) align_corners=align_corners, recompute_scale_factor=recompute_scale_factor, antialias=antialias)
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit): # Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
original_from_numpy = torch.from_numpy original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray): def from_numpy(ndarray):
if ndarray.dtype == float: if ndarray.dtype == float:
return original_from_numpy(ndarray.astype('float32')) return original_from_numpy(ndarray.astype('float32'))
else: else:
return original_from_numpy(ndarray) return original_from_numpy(ndarray)
if torch.xpu.has_fp64_dtype(): original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
if check_device(device):
device = return_xpu(device)
if isinstance(data, np.ndarray) and data.dtype == float and not (
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
return original_as_tensor(data, dtype=torch.float32, device=device)
else:
return original_as_tensor(data, dtype=dtype, device=device)
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
original_torch_bmm = torch.bmm original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else: else:
@@ -66,20 +87,25 @@ else:
# Data Type Errors: # Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None): def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype: if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype) mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out) return original_torch_bmm(input, mat2, out=out)
@wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
if query.dtype != key.dtype: if query.dtype != key.dtype:
key = key.to(dtype=query.dtype) key = key.to(dtype=query.dtype)
if query.dtype != value.dtype: if query.dtype != value.dtype:
value = value.to(dtype=query.dtype) value = value.to(dtype=query.dtype)
if attn_mask is not None and query.dtype != attn_mask.dtype:
attn_mask = attn_mask.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
# A1111 FP16 # A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm original_functional_group_norm = torch.nn.functional.group_norm
@wraps(torch.nn.functional.group_norm)
def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05): def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype: if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype) input = input.to(dtype=weight.data.dtype)
@@ -89,6 +115,7 @@ def functional_group_norm(input, num_groups, weight=None, bias=None, eps=1e-05):
# A1111 BF16 # A1111 BF16
original_functional_layer_norm = torch.nn.functional.layer_norm original_functional_layer_norm = torch.nn.functional.layer_norm
@wraps(torch.nn.functional.layer_norm)
def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05): def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if weight is not None and input.dtype != weight.data.dtype: if weight is not None and input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype) input = input.to(dtype=weight.data.dtype)
@@ -98,6 +125,7 @@ def functional_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1
# Training # Training
original_functional_linear = torch.nn.functional.linear original_functional_linear = torch.nn.functional.linear
@wraps(torch.nn.functional.linear)
def functional_linear(input, weight, bias=None): def functional_linear(input, weight, bias=None):
if input.dtype != weight.data.dtype: if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype) input = input.to(dtype=weight.data.dtype)
@@ -106,6 +134,7 @@ def functional_linear(input, weight, bias=None):
return original_functional_linear(input, weight, bias=bias) return original_functional_linear(input, weight, bias=bias)
original_functional_conv2d = torch.nn.functional.conv2d original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype: if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype) input = input.to(dtype=weight.data.dtype)
@@ -115,6 +144,7 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
# A1111 Embedding BF16 # A1111 Embedding BF16
original_torch_cat = torch.cat original_torch_cat = torch.cat
@wraps(torch.cat)
def torch_cat(tensor, *args, **kwargs): def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
@@ -123,6 +153,7 @@ def torch_cat(tensor, *args, **kwargs):
# SwinIR BF16: # SwinIR BF16:
original_functional_pad = torch.nn.functional.pad original_functional_pad = torch.nn.functional.pad
@wraps(torch.nn.functional.pad)
def functional_pad(input, pad, mode='constant', value=None): def functional_pad(input, pad, mode='constant', value=None):
if mode == 'reflect' and input.dtype == torch.bfloat16: if mode == 'reflect' and input.dtype == torch.bfloat16:
return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16) return original_functional_pad(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16)
@@ -131,13 +162,20 @@ def functional_pad(input, pad, mode='constant', value=None):
original_torch_tensor = torch.tensor original_torch_tensor = torch.tensor
def torch_tensor(*args, device=None, **kwargs): @wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
if check_device(device): if check_device(device):
return original_torch_tensor(*args, device=return_xpu(device), **kwargs) device = return_xpu(device)
else: if not device_supports_fp64:
return original_torch_tensor(*args, device=device, **kwargs) if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
if dtype == torch.float64:
dtype = torch.float32
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
dtype = torch.float32
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)
original_Tensor_to = torch.Tensor.to original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs): def Tensor_to(self, device=None, *args, **kwargs):
if check_device(device): if check_device(device):
return original_Tensor_to(self, return_xpu(device), *args, **kwargs) return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
@@ -145,6 +183,7 @@ def Tensor_to(self, device=None, *args, **kwargs):
return original_Tensor_to(self, device, *args, **kwargs) return original_Tensor_to(self, device, *args, **kwargs)
original_Tensor_cuda = torch.Tensor.cuda original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs): def Tensor_cuda(self, device=None, *args, **kwargs):
if check_device(device): if check_device(device):
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs) return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
@@ -152,6 +191,7 @@ def Tensor_cuda(self, device=None, *args, **kwargs):
return original_Tensor_cuda(self, device, *args, **kwargs) return original_Tensor_cuda(self, device, *args, **kwargs)
original_UntypedStorage_init = torch.UntypedStorage.__init__ original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs): def UntypedStorage_init(*args, device=None, **kwargs):
if check_device(device): if check_device(device):
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs) return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
@@ -159,6 +199,7 @@ def UntypedStorage_init(*args, device=None, **kwargs):
return original_UntypedStorage_init(*args, device=device, **kwargs) return original_UntypedStorage_init(*args, device=device, **kwargs)
original_UntypedStorage_cuda = torch.UntypedStorage.cuda original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs): def UntypedStorage_cuda(self, device=None, *args, **kwargs):
if check_device(device): if check_device(device):
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs) return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
@@ -166,6 +207,7 @@ def UntypedStorage_cuda(self, device=None, *args, **kwargs):
return original_UntypedStorage_cuda(self, device, *args, **kwargs) return original_UntypedStorage_cuda(self, device, *args, **kwargs)
original_torch_empty = torch.empty original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs): def torch_empty(*args, device=None, **kwargs):
if check_device(device): if check_device(device):
return original_torch_empty(*args, device=return_xpu(device), **kwargs) return original_torch_empty(*args, device=return_xpu(device), **kwargs)
@@ -173,6 +215,7 @@ def torch_empty(*args, device=None, **kwargs):
return original_torch_empty(*args, device=device, **kwargs) return original_torch_empty(*args, device=device, **kwargs)
original_torch_randn = torch.randn original_torch_randn = torch.randn
@wraps(torch.randn)
def torch_randn(*args, device=None, **kwargs): def torch_randn(*args, device=None, **kwargs):
if check_device(device): if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs) return original_torch_randn(*args, device=return_xpu(device), **kwargs)
@@ -180,6 +223,7 @@ def torch_randn(*args, device=None, **kwargs):
return original_torch_randn(*args, device=device, **kwargs) return original_torch_randn(*args, device=device, **kwargs)
original_torch_ones = torch.ones original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs): def torch_ones(*args, device=None, **kwargs):
if check_device(device): if check_device(device):
return original_torch_ones(*args, device=return_xpu(device), **kwargs) return original_torch_ones(*args, device=return_xpu(device), **kwargs)
@@ -187,6 +231,7 @@ def torch_ones(*args, device=None, **kwargs):
return original_torch_ones(*args, device=device, **kwargs) return original_torch_ones(*args, device=device, **kwargs)
original_torch_zeros = torch.zeros original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs): def torch_zeros(*args, device=None, **kwargs):
if check_device(device): if check_device(device):
return original_torch_zeros(*args, device=return_xpu(device), **kwargs) return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
@@ -194,6 +239,7 @@ def torch_zeros(*args, device=None, **kwargs):
return original_torch_zeros(*args, device=device, **kwargs) return original_torch_zeros(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs): def torch_linspace(*args, device=None, **kwargs):
if check_device(device): if check_device(device):
return original_torch_linspace(*args, device=return_xpu(device), **kwargs) return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
@@ -201,6 +247,7 @@ def torch_linspace(*args, device=None, **kwargs):
return original_torch_linspace(*args, device=device, **kwargs) return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None): def torch_Generator(device=None):
if check_device(device): if check_device(device):
return original_torch_Generator(return_xpu(device)) return original_torch_Generator(return_xpu(device))
@@ -208,12 +255,14 @@ def torch_Generator(device=None):
return original_torch_Generator(device) return original_torch_Generator(device)
original_torch_load = torch.load original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs): def torch_load(f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs):
if check_device(map_location): if check_device(map_location):
return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) return original_torch_load(f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
else: else:
return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs) return original_torch_load(f, map_location=map_location, pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs)
# Hijack Functions: # Hijack Functions:
def ipex_hijacks(): def ipex_hijacks():
torch.tensor = torch_tensor torch.tensor = torch_tensor
@@ -232,7 +281,7 @@ def ipex_hijacks():
torch.backends.cuda.sdp_kernel = return_null_context torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel torch.nn.DataParallel = DummyDataParallel
torch.UntypedStorage.is_cuda = is_cuda torch.UntypedStorage.is_cuda = is_cuda
torch.autocast = ipex_autocast torch.amp.autocast_mode.autocast.__init__ = autocast_init
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
torch.nn.functional.group_norm = functional_group_norm torch.nn.functional.group_norm = functional_group_norm
@@ -244,5 +293,6 @@ def ipex_hijacks():
torch.bmm = torch_bmm torch.bmm = torch_bmm
torch.cat = torch_cat torch.cat = torch_cat
if not torch.xpu.has_fp64_dtype(): if not device_supports_fp64:
torch.from_numpy = from_numpy torch.from_numpy = from_numpy
torch.as_tensor = as_tensor

View File

@@ -19,7 +19,7 @@ from typing import (
Tuple, Tuple,
Union, Union,
) )
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import gc import gc
import glob import glob
import math import math
@@ -4636,7 +4636,6 @@ def line_to_prompt_dict(line: str) -> dict:
return prompt_dict return prompt_dict
def sample_images_common( def sample_images_common(
pipe_class, pipe_class,
accelerator: Accelerator, accelerator: Accelerator,
@@ -4654,6 +4653,7 @@ def sample_images_common(
""" """
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
""" """
if steps == 0: if steps == 0:
if not args.sample_at_first: if not args.sample_at_first:
return return
@@ -4673,8 +4673,10 @@ def sample_images_common(
print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}") print(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
return return
distributed_state = PartialState() # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
org_vae_device = vae.device # CPUにいるはず org_vae_device = vae.device # CPUにいるはず
vae.to(device) vae.to(distributed_state.device)
# unwrap unet and text_encoder(s) # unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet) unet = accelerator.unwrap_model(unet)
@@ -4684,10 +4686,6 @@ def sample_images_common(
text_encoder = accelerator.unwrap_model(text_encoder) text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts # read prompts
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
# prompts = f.readlines()
if args.sample_prompts.endswith(".txt"): if args.sample_prompts.endswith(".txt"):
with open(args.sample_prompts, "r", encoding="utf-8") as f: with open(args.sample_prompts, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
@@ -4700,12 +4698,11 @@ def sample_images_common(
with open(args.sample_prompts, "r", encoding="utf-8") as f: with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f) prompts = json.load(f)
schedulers: dict = {} # schedulers: dict = {} cannot find where this is used
default_scheduler = get_my_scheduler( default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler, sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization, v_parameterization=args.v_parameterization,
) )
schedulers[args.sample_sampler] = default_scheduler
pipeline = pipe_class( pipeline = pipe_class(
text_encoder=text_encoder, text_encoder=text_encoder,
@@ -4718,114 +4715,145 @@ def sample_images_common(
requires_safety_checker=False, requires_safety_checker=False,
clip_skip=args.clip_skip, clip_skip=args.clip_skip,
) )
pipeline.to(device) pipeline.to(distributed_state.device)
save_dir = args.output_dir + "/sample" save_dir = args.output_dir + "/sample"
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
# preprocess prompts
for i in range(len(prompts)):
prompt_dict = prompts[i]
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
prompts[i] = prompt_dict
assert isinstance(prompt_dict, dict)
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
prompt_dict["enum"] = i
prompt_dict.pop("subset", None)
# save random state to restore later
rng_state = torch.get_rng_state() rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None # TODO mps etc. support
with torch.no_grad(): if distributed_state.num_processes <= 1:
# with accelerator.autocast(): # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
for i, prompt_dict in enumerate(prompts): with torch.no_grad():
if not accelerator.is_main_process: for prompt_dict in prompts:
continue sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
else:
if isinstance(prompt_dict, str): # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
prompt_dict = line_to_prompt_dict(prompt_dict) # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
per_process_prompts = [] # list of lists
assert isinstance(prompt_dict, dict) for i in range(distributed_state.num_processes):
negative_prompt = prompt_dict.get("negative_prompt") per_process_prompts.append(prompts[i::distributed_state.num_processes])
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512) with torch.no_grad():
height = prompt_dict.get("height", 512) with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
scale = prompt_dict.get("scale", 7.5) for prompt_dict in prompt_dict_lists[0]:
seed = prompt_dict.get("seed") sample_image_inference(accelerator, args, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=controlnet)
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
scheduler = schedulers.get(sampler_name)
if scheduler is None:
scheduler = get_my_scheduler(
sample_sampler=sampler_name,
v_parameterization=args.v_parameterization,
)
schedulers[sampler_name] = scheduler
pipeline.scheduler = scheduler
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
if seed is not None:
print(f"seed: {seed}")
with accelerator.autocast():
latents = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
)
image = pipeline.latents_to_image(latents)[0]
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png"
)
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# clear pipeline and cache to reduce vram usage # clear pipeline and cache to reduce vram usage
del pipeline del pipeline
torch.cuda.empty_cache()
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
torch.set_rng_state(rng_state) torch.set_rng_state(rng_state)
if cuda_rng_state is not None: if cuda_rng_state is not None:
torch.cuda.set_rng_state(cuda_rng_state) torch.cuda.set_rng_state(cuda_rng_state)
vae.to(org_vae_device) vae.to(org_vae_device)
def sample_image_inference(accelerator: Accelerator, args: argparse.Namespace, pipeline, save_dir, prompt_dict, epoch, steps, prompt_replacement, controlnet=None):
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
scheduler = get_my_scheduler(
sample_sampler=sampler_name,
v_parameterization=args.v_parameterization,
)
pipeline.scheduler = scheduler
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"\nprompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
if seed is not None:
print(f"seed: {seed}")
with accelerator.autocast():
latents = pipeline(
prompt=prompt,
height=height,
width=width,
num_inference_steps=sample_steps,
guidance_scale=scale,
negative_prompt=negative_prompt,
controlnet=controlnet,
controlnet_image=controlnet_image,
)
with torch.cuda.device(torch.cuda.current_device()):
torch.cuda.empty_cache()
image = pipeline.latents_to_image(latents)[0]
# adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
# but adding 'enum' to the filename should be enough
ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}"
seed_suffix = "" if seed is None else f"_{seed}"
i: int = prompt_dict["enum"]
img_filename = (
f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
)
image.save(os.path.join(save_dir, img_filename))
# wandb有効時のみログを送信
try:
wandb_tracker = accelerator.get_tracker("wandb")
try:
import wandb
except ImportError: # 事前に一度確認するのでここはエラー出ないはず
raise ImportError("No wandb / wandb がインストールされていないようです")
wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
except: # wandb 無効時
pass
# endregion # endregion
# region 前処理用 # region 前処理用

View File

@@ -1,4 +1,3 @@
import math
import argparse import argparse
import os import os
import time import time
@@ -6,8 +5,6 @@ import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from tqdm import tqdm from tqdm import tqdm
from library import sai_model_spec, train_util from library import sai_model_spec, train_util
import library.model_util as model_util
import lora
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 0.99