mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'sd3' into sd3_safetensors_merge
This commit is contained in:
@@ -14,6 +14,11 @@ The command to install PyTorch is as follows:
|
|||||||
|
|
||||||
### Recent Updates
|
### Recent Updates
|
||||||
|
|
||||||
|
Feb 26, 2025:
|
||||||
|
|
||||||
|
- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
|
||||||
|
- The validation loss uses the fixed timestep sampling and the fixed random seed. This is to ensure that the validation loss is not fluctuated by the random values.
|
||||||
|
|
||||||
Jan 25, 2025:
|
Jan 25, 2025:
|
||||||
|
|
||||||
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
|
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!
|
||||||
|
|||||||
@@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
self.is_schnell: Optional[bool] = None
|
self.is_schnell: Optional[bool] = None
|
||||||
self.is_swapping_blocks: bool = False
|
self.is_swapping_blocks: bool = False
|
||||||
|
|
||||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
def assert_extra_args(
|
||||||
|
self,
|
||||||
|
args,
|
||||||
|
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||||
|
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||||
|
):
|
||||||
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
|
||||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||||
|
|
||||||
@@ -323,7 +328,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||||
return noise_scheduler
|
return noise_scheduler
|
||||||
|
|
||||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
def encode_images_to_latents(self, args, vae, images):
|
||||||
return vae.encode(images)
|
return vae.encode(images)
|
||||||
|
|
||||||
def shift_scale_latents(self, args, latents):
|
def shift_scale_latents(self, args, latents):
|
||||||
@@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
network,
|
network,
|
||||||
weight_dtype,
|
weight_dtype,
|
||||||
train_unet,
|
train_unet,
|
||||||
is_train=True
|
is_train=True,
|
||||||
):
|
):
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
@@ -376,8 +381,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
t5_attn_mask = None
|
t5_attn_mask = None
|
||||||
|
|
||||||
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
|
||||||
# if not args.split_mode:
|
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
|
||||||
# normal forward
|
|
||||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
model_pred = unet(
|
model_pred = unet(
|
||||||
@@ -390,44 +394,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
txt_attention_mask=t5_attn_mask,
|
txt_attention_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
else:
|
|
||||||
# split forward to reduce memory usage
|
|
||||||
assert network.train_blocks == "single", "train_blocks must be single for split mode"
|
|
||||||
with accelerator.autocast():
|
|
||||||
# move flux lower to cpu, and then move flux upper to gpu
|
|
||||||
unet.to("cpu")
|
|
||||||
clean_memory_on_device(accelerator.device)
|
|
||||||
self.flux_upper.to(accelerator.device)
|
|
||||||
|
|
||||||
# upper model does not require grad
|
|
||||||
with torch.no_grad():
|
|
||||||
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
|
|
||||||
img=packed_noisy_model_input,
|
|
||||||
img_ids=img_ids,
|
|
||||||
txt=t5_out,
|
|
||||||
txt_ids=txt_ids,
|
|
||||||
y=l_pooled,
|
|
||||||
timesteps=timesteps / 1000,
|
|
||||||
guidance=guidance_vec,
|
|
||||||
txt_attention_mask=t5_attn_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
# move flux upper back to cpu, and then move flux lower to gpu
|
|
||||||
self.flux_upper.to("cpu")
|
|
||||||
clean_memory_on_device(accelerator.device)
|
|
||||||
unet.to(accelerator.device)
|
|
||||||
|
|
||||||
# lower model requires grad
|
|
||||||
intermediate_img.requires_grad_(True)
|
|
||||||
intermediate_txt.requires_grad_(True)
|
|
||||||
vec.requires_grad_(True)
|
|
||||||
pe.requires_grad_(True)
|
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_train and train_unet):
|
|
||||||
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
|
|
||||||
"""
|
|
||||||
|
|
||||||
return model_pred
|
return model_pred
|
||||||
|
|
||||||
model_pred = call_dit(
|
model_pred = call_dit(
|
||||||
@@ -546,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
text_encoder.to(te_weight_dtype) # fp8
|
text_encoder.to(te_weight_dtype) # fp8
|
||||||
prepare_fp8(text_encoder, weight_dtype)
|
prepare_fp8(text_encoder, weight_dtype)
|
||||||
|
|
||||||
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||||
|
if self.is_swapping_blocks:
|
||||||
|
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||||
|
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
def prepare_unet_with_accelerator(
|
def prepare_unet_with_accelerator(
|
||||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
|||||||
@@ -2,6 +2,13 @@ import functools
|
|||||||
import gc
|
import gc
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
try:
|
||||||
|
# intel gpu support for pytorch older than 2.5
|
||||||
|
# ipex is not needed after pytorch 2.5
|
||||||
|
import intel_extension_for_pytorch as ipex # noqa
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
HAS_CUDA = torch.cuda.is_available()
|
HAS_CUDA = torch.cuda.is_available()
|
||||||
@@ -14,8 +21,6 @@ except Exception:
|
|||||||
HAS_MPS = False
|
HAS_MPS = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex # noqa
|
|
||||||
|
|
||||||
HAS_XPU = torch.xpu.is_available()
|
HAS_XPU = torch.xpu.is_available()
|
||||||
except Exception:
|
except Exception:
|
||||||
HAS_XPU = False
|
HAS_XPU = False
|
||||||
@@ -69,7 +74,7 @@ def init_ipex():
|
|||||||
|
|
||||||
This function should run right after importing torch and before doing anything else.
|
This function should run right after importing torch and before doing anything else.
|
||||||
|
|
||||||
If IPEX is not available, this function does nothing.
|
If xpu is not available, this function does nothing.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if HAS_XPU:
|
if HAS_XPU:
|
||||||
|
|||||||
@@ -2,7 +2,11 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import contextlib
|
import contextlib
|
||||||
import torch
|
import torch
|
||||||
|
try:
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||||
|
legacy = True
|
||||||
|
except Exception:
|
||||||
|
legacy = False
|
||||||
from .hijacks import ipex_hijacks
|
from .hijacks import ipex_hijacks
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
@@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
|
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
|
||||||
return True, "Skipping IPEX hijack"
|
return True, "Skipping IPEX hijack"
|
||||||
else:
|
else:
|
||||||
|
try: # force xpu device on torch compile and triton
|
||||||
|
torch._inductor.utils.GPU_TYPES = ["xpu"]
|
||||||
|
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
|
||||||
|
from triton import backends as triton_backends # pylint: disable=import-error
|
||||||
|
triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
# Replace cuda with xpu:
|
# Replace cuda with xpu:
|
||||||
torch.cuda.current_device = torch.xpu.current_device
|
torch.cuda.current_device = torch.xpu.current_device
|
||||||
torch.cuda.current_stream = torch.xpu.current_stream
|
torch.cuda.current_stream = torch.xpu.current_stream
|
||||||
@@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.is_current_stream_capturing = lambda: False
|
torch.cuda.is_current_stream_capturing = lambda: False
|
||||||
torch.cuda.set_device = torch.xpu.set_device
|
torch.cuda.set_device = torch.xpu.set_device
|
||||||
torch.cuda.stream = torch.xpu.stream
|
torch.cuda.stream = torch.xpu.stream
|
||||||
torch.cuda.synchronize = torch.xpu.synchronize
|
|
||||||
torch.cuda.Event = torch.xpu.Event
|
torch.cuda.Event = torch.xpu.Event
|
||||||
torch.cuda.Stream = torch.xpu.Stream
|
torch.cuda.Stream = torch.xpu.Stream
|
||||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
|
||||||
torch.Tensor.cuda = torch.Tensor.xpu
|
torch.Tensor.cuda = torch.Tensor.xpu
|
||||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||||
torch.nn.Module.cuda = torch.nn.Module.xpu
|
torch.nn.Module.cuda = torch.nn.Module.xpu
|
||||||
|
torch.cuda.Optional = torch.xpu.Optional
|
||||||
|
torch.cuda.__cached__ = torch.xpu.__cached__
|
||||||
|
torch.cuda.__loader__ = torch.xpu.__loader__
|
||||||
|
torch.cuda.Tuple = torch.xpu.Tuple
|
||||||
|
torch.cuda.streams = torch.xpu.streams
|
||||||
|
torch.cuda.Any = torch.xpu.Any
|
||||||
|
torch.cuda.__doc__ = torch.xpu.__doc__
|
||||||
|
torch.cuda.default_generators = torch.xpu.default_generators
|
||||||
|
torch.cuda._get_device_index = torch.xpu._get_device_index
|
||||||
|
torch.cuda.__path__ = torch.xpu.__path__
|
||||||
|
torch.cuda.set_stream = torch.xpu.set_stream
|
||||||
|
torch.cuda.torch = torch.xpu.torch
|
||||||
|
torch.cuda.Union = torch.xpu.Union
|
||||||
|
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
||||||
|
torch.cuda.__package__ = torch.xpu.__package__
|
||||||
|
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
||||||
|
torch.cuda.List = torch.xpu.List
|
||||||
|
torch.cuda._lazy_init = torch.xpu._lazy_init
|
||||||
|
torch.cuda.StreamContext = torch.xpu.StreamContext
|
||||||
|
torch.cuda._lazy_call = torch.xpu._lazy_call
|
||||||
|
torch.cuda.random = torch.xpu.random
|
||||||
|
torch.cuda._device = torch.xpu._device
|
||||||
|
torch.cuda.__name__ = torch.xpu.__name__
|
||||||
|
torch.cuda._device_t = torch.xpu._device_t
|
||||||
|
torch.cuda.__spec__ = torch.xpu.__spec__
|
||||||
|
torch.cuda.__file__ = torch.xpu.__file__
|
||||||
|
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||||
|
|
||||||
|
if legacy:
|
||||||
|
torch.cuda.os = torch.xpu.os
|
||||||
|
torch.cuda.Device = torch.xpu.Device
|
||||||
|
torch.cuda.warnings = torch.xpu.warnings
|
||||||
|
torch.cuda.classproperty = torch.xpu.classproperty
|
||||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||||
|
if float(ipex.__version__[:3]) < 2.3:
|
||||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||||
|
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||||
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
|
||||||
torch.cuda._tls = torch.xpu.lazy_init._tls
|
torch.cuda._tls = torch.xpu.lazy_init._tls
|
||||||
torch.cuda.threading = torch.xpu.lazy_init.threading
|
torch.cuda.threading = torch.xpu.lazy_init.threading
|
||||||
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
torch.cuda.traceback = torch.xpu.lazy_init.traceback
|
||||||
torch.cuda.Optional = torch.xpu.Optional
|
|
||||||
torch.cuda.__cached__ = torch.xpu.__cached__
|
|
||||||
torch.cuda.__loader__ = torch.xpu.__loader__
|
|
||||||
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
|
||||||
torch.cuda.Tuple = torch.xpu.Tuple
|
|
||||||
torch.cuda.streams = torch.xpu.streams
|
|
||||||
torch.cuda._lazy_new = torch.xpu._lazy_new
|
torch.cuda._lazy_new = torch.xpu._lazy_new
|
||||||
|
|
||||||
|
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||||
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
torch.cuda.FloatStorage = torch.xpu.FloatStorage
|
||||||
torch.cuda.Any = torch.xpu.Any
|
|
||||||
torch.cuda.__doc__ = torch.xpu.__doc__
|
|
||||||
torch.cuda.default_generators = torch.xpu.default_generators
|
|
||||||
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
|
||||||
torch.cuda._get_device_index = torch.xpu._get_device_index
|
|
||||||
torch.cuda.__path__ = torch.xpu.__path__
|
|
||||||
torch.cuda.Device = torch.xpu.Device
|
|
||||||
torch.cuda.IntTensor = torch.xpu.IntTensor
|
|
||||||
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
|
||||||
torch.cuda.set_stream = torch.xpu.set_stream
|
|
||||||
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
|
||||||
torch.cuda.os = torch.xpu.os
|
|
||||||
torch.cuda.torch = torch.xpu.torch
|
|
||||||
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
|
||||||
torch.cuda.Union = torch.xpu.Union
|
|
||||||
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
|
||||||
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
|
||||||
torch.cuda.LongTensor = torch.xpu.LongTensor
|
|
||||||
torch.cuda.IntStorage = torch.xpu.IntStorage
|
|
||||||
torch.cuda.LongStorage = torch.xpu.LongStorage
|
|
||||||
torch.cuda.__annotations__ = torch.xpu.__annotations__
|
|
||||||
torch.cuda.__package__ = torch.xpu.__package__
|
|
||||||
torch.cuda.__builtins__ = torch.xpu.__builtins__
|
|
||||||
torch.cuda.CharTensor = torch.xpu.CharTensor
|
|
||||||
torch.cuda.List = torch.xpu.List
|
|
||||||
torch.cuda._lazy_init = torch.xpu._lazy_init
|
|
||||||
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
|
||||||
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
|
||||||
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
torch.cuda.HalfTensor = torch.xpu.HalfTensor
|
||||||
torch.cuda.StreamContext = torch.xpu.StreamContext
|
|
||||||
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
|
||||||
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
|
||||||
torch.cuda._lazy_call = torch.xpu._lazy_call
|
|
||||||
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
torch.cuda.HalfStorage = torch.xpu.HalfStorage
|
||||||
torch.cuda.random = torch.xpu.random
|
torch.cuda.ByteTensor = torch.xpu.ByteTensor
|
||||||
torch.cuda._device = torch.xpu._device
|
torch.cuda.ByteStorage = torch.xpu.ByteStorage
|
||||||
torch.cuda.classproperty = torch.xpu.classproperty
|
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
|
||||||
torch.cuda.__name__ = torch.xpu.__name__
|
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
|
||||||
torch.cuda._device_t = torch.xpu._device_t
|
torch.cuda.ShortTensor = torch.xpu.ShortTensor
|
||||||
torch.cuda.warnings = torch.xpu.warnings
|
torch.cuda.ShortStorage = torch.xpu.ShortStorage
|
||||||
torch.cuda.__spec__ = torch.xpu.__spec__
|
torch.cuda.LongTensor = torch.xpu.LongTensor
|
||||||
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
torch.cuda.LongStorage = torch.xpu.LongStorage
|
||||||
|
torch.cuda.IntTensor = torch.xpu.IntTensor
|
||||||
|
torch.cuda.IntStorage = torch.xpu.IntStorage
|
||||||
|
torch.cuda.CharTensor = torch.xpu.CharTensor
|
||||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||||
torch.cuda.__file__ = torch.xpu.__file__
|
torch.cuda.BoolTensor = torch.xpu.BoolTensor
|
||||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
torch.cuda.BoolStorage = torch.xpu.BoolStorage
|
||||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
|
||||||
|
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
|
||||||
|
|
||||||
|
if not legacy or float(ipex.__version__[:3]) >= 2.3:
|
||||||
|
torch.cuda._initialization_lock = torch.xpu._initialization_lock
|
||||||
|
torch.cuda._initialized = torch.xpu._initialized
|
||||||
|
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
|
||||||
|
torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker
|
||||||
|
torch.cuda._queued_calls = torch.xpu._queued_calls
|
||||||
|
torch.cuda._tls = torch.xpu._tls
|
||||||
|
torch.cuda.threading = torch.xpu.threading
|
||||||
|
torch.cuda.traceback = torch.xpu.traceback
|
||||||
|
|
||||||
# Memory:
|
# Memory:
|
||||||
torch.cuda.memory = torch.xpu.memory
|
|
||||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||||
torch.xpu.empty_cache = lambda: None
|
torch.xpu.empty_cache = lambda: None
|
||||||
torch.cuda.empty_cache = torch.xpu.empty_cache
|
torch.cuda.empty_cache = torch.xpu.empty_cache
|
||||||
torch.cuda.memory_stats = torch.xpu.memory_stats
|
|
||||||
|
if legacy:
|
||||||
torch.cuda.memory_summary = torch.xpu.memory_summary
|
torch.cuda.memory_summary = torch.xpu.memory_summary
|
||||||
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
|
||||||
|
torch.cuda.memory = torch.xpu.memory
|
||||||
|
torch.cuda.memory_stats = torch.xpu.memory_stats
|
||||||
torch.cuda.memory_allocated = torch.xpu.memory_allocated
|
torch.cuda.memory_allocated = torch.xpu.memory_allocated
|
||||||
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
|
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
|
||||||
torch.cuda.memory_reserved = torch.xpu.memory_reserved
|
torch.cuda.memory_reserved = torch.xpu.memory_reserved
|
||||||
@@ -128,7 +154,11 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||||
|
|
||||||
# AMP:
|
# AMP:
|
||||||
|
if legacy:
|
||||||
|
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
|
||||||
|
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
|
||||||
torch.cuda.amp = torch.xpu.amp
|
torch.cuda.amp = torch.xpu.amp
|
||||||
|
if float(ipex.__version__[:3]) < 2.3:
|
||||||
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
|
||||||
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
|
||||||
|
|
||||||
@@ -147,13 +177,21 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||||
|
|
||||||
# C
|
# C
|
||||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
if legacy and float(ipex.__version__[:3]) < 2.3:
|
||||||
|
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
|
||||||
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
|
||||||
ipex._C._DeviceProperties.major = 2024
|
ipex._C._DeviceProperties.major = 12
|
||||||
ipex._C._DeviceProperties.minor = 0
|
ipex._C._DeviceProperties.minor = 1
|
||||||
|
else:
|
||||||
|
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
|
||||||
|
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
|
||||||
|
torch._C._XpuDeviceProperties.major = 12
|
||||||
|
torch._C._XpuDeviceProperties.minor = 1
|
||||||
|
|
||||||
# Fix functions with ipex:
|
# Fix functions with ipex:
|
||||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
# torch.xpu.mem_get_info always returns the total memory as free memory
|
||||||
|
torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||||
|
torch.cuda.mem_get_info = torch.xpu.mem_get_info
|
||||||
torch._utils._get_available_device_type = lambda: "xpu"
|
torch._utils._get_available_device_type = lambda: "xpu"
|
||||||
torch.has_cuda = True
|
torch.has_cuda = True
|
||||||
torch.cuda.has_half = True
|
torch.cuda.has_half = True
|
||||||
@@ -161,17 +199,17 @@ def ipex_init(): # pylint: disable=too-many-statements
|
|||||||
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
|
||||||
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
torch.backends.cuda.is_built = lambda *args, **kwargs: True
|
||||||
torch.version.cuda = "12.1"
|
torch.version.cuda = "12.1"
|
||||||
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1]
|
torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
|
||||||
|
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
|
||||||
torch.cuda.get_device_properties.major = 12
|
torch.cuda.get_device_properties.major = 12
|
||||||
torch.cuda.get_device_properties.minor = 1
|
torch.cuda.get_device_properties.minor = 1
|
||||||
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
torch.cuda.ipc_collect = lambda *args, **kwargs: None
|
||||||
torch.cuda.utilization = lambda *args, **kwargs: 0
|
torch.cuda.utilization = lambda *args, **kwargs: 0
|
||||||
|
|
||||||
ipex_hijacks()
|
device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
|
||||||
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
|
|
||||||
try:
|
try:
|
||||||
from .diffusers import ipex_diffusers
|
from .diffusers import ipex_diffusers
|
||||||
ipex_diffusers()
|
ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
pass
|
pass
|
||||||
torch.cuda.is_xpu_hijacked = True
|
torch.cuda.is_xpu_hijacked = True
|
||||||
|
|||||||
@@ -1,177 +1,119 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
from functools import cache, wraps
|
||||||
from functools import cache
|
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
|
||||||
|
|
||||||
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4))
|
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
|
||||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))
|
||||||
|
|
||||||
# Find something divisible with the input_tokens
|
# Find something divisible with the input_tokens
|
||||||
@cache
|
@cache
|
||||||
def find_slice_size(slice_size, slice_block_size):
|
def find_split_size(original_size, slice_block_size, slice_rate=2):
|
||||||
while (slice_size * slice_block_size) > attention_slice_rate:
|
split_size = original_size
|
||||||
slice_size = slice_size // 2
|
while True:
|
||||||
if slice_size <= 1:
|
if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
|
||||||
slice_size = 1
|
return split_size
|
||||||
break
|
split_size = split_size - 1
|
||||||
return slice_size
|
if split_size <= 1:
|
||||||
|
return 1
|
||||||
|
return split_size
|
||||||
|
|
||||||
|
|
||||||
# Find slice sizes for SDPA
|
# Find slice sizes for SDPA
|
||||||
@cache
|
@cache
|
||||||
def find_sdpa_slice_sizes(query_shape, query_element_size):
|
def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
|
||||||
if len(query_shape) == 3:
|
batch_size, attn_heads, query_len, _ = query_shape
|
||||||
batch_size_attention, query_tokens, shape_three = query_shape
|
_, _, key_len, _ = key_shape
|
||||||
shape_four = 1
|
|
||||||
else:
|
|
||||||
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
|
||||||
|
|
||||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
|
||||||
block_size = batch_size_attention * slice_block_size
|
|
||||||
|
|
||||||
split_slice_size = batch_size_attention
|
split_batch_size = batch_size
|
||||||
split_2_slice_size = query_tokens
|
split_head_size = attn_heads
|
||||||
split_3_slice_size = shape_three
|
split_query_size = query_len
|
||||||
|
|
||||||
do_split = False
|
do_batch_split = False
|
||||||
do_split_2 = False
|
do_head_split = False
|
||||||
do_split_3 = False
|
do_query_split = False
|
||||||
|
|
||||||
if block_size > sdpa_slice_trigger_rate:
|
if batch_size * slice_batch_size >= trigger_rate:
|
||||||
do_split = True
|
do_batch_split = True
|
||||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)
|
||||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
|
||||||
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
|
||||||
do_split_2 = True
|
|
||||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
|
||||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
|
||||||
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
|
||||||
do_split_3 = True
|
|
||||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
|
||||||
|
|
||||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
if split_batch_size * slice_batch_size > slice_rate:
|
||||||
|
slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
|
||||||
|
do_head_split = True
|
||||||
|
split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)
|
||||||
|
|
||||||
# Find slice sizes for BMM
|
if split_head_size * slice_head_size > slice_rate:
|
||||||
@cache
|
slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
|
||||||
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape):
|
do_query_split = True
|
||||||
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2]
|
split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)
|
||||||
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
|
|
||||||
block_size = batch_size_attention * slice_block_size
|
|
||||||
|
|
||||||
split_slice_size = batch_size_attention
|
return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size
|
||||||
split_2_slice_size = input_tokens
|
|
||||||
split_3_slice_size = mat2_atten_shape
|
|
||||||
|
|
||||||
do_split = False
|
|
||||||
do_split_2 = False
|
|
||||||
do_split_3 = False
|
|
||||||
|
|
||||||
if block_size > attention_slice_rate:
|
|
||||||
do_split = True
|
|
||||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
|
||||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
|
||||||
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
|
|
||||||
do_split_2 = True
|
|
||||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
|
||||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
|
||||||
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
|
|
||||||
do_split_3 = True
|
|
||||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
|
||||||
|
|
||||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
|
||||||
|
|
||||||
|
|
||||||
original_torch_bmm = torch.bmm
|
|
||||||
def torch_bmm_32_bit(input, mat2, *, out=None):
|
|
||||||
if input.device.type != "xpu":
|
|
||||||
return original_torch_bmm(input, mat2, out=out)
|
|
||||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
|
|
||||||
|
|
||||||
# Slice BMM
|
|
||||||
if do_split:
|
|
||||||
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
|
|
||||||
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
|
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
|
||||||
start_idx = i * split_slice_size
|
|
||||||
end_idx = (i + 1) * split_slice_size
|
|
||||||
if do_split_2:
|
|
||||||
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
|
||||||
if do_split_3:
|
|
||||||
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
|
|
||||||
start_idx_3 = i3 * split_3_slice_size
|
|
||||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
|
|
||||||
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
|
||||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
|
||||||
out=out
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
|
|
||||||
input[start_idx:end_idx, start_idx_2:end_idx_2],
|
|
||||||
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
|
|
||||||
out=out
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
hidden_states[start_idx:end_idx] = original_torch_bmm(
|
|
||||||
input[start_idx:end_idx],
|
|
||||||
mat2[start_idx:end_idx],
|
|
||||||
out=out
|
|
||||||
)
|
|
||||||
torch.xpu.synchronize(input.device)
|
|
||||||
else:
|
|
||||||
return original_torch_bmm(input, mat2, out=out)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||||
|
def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
||||||
if query.device.type != "xpu":
|
if query.device.type != "xpu":
|
||||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size())
|
is_unsqueezed = False
|
||||||
|
if len(query.shape) == 3:
|
||||||
|
query = query.unsqueeze(0)
|
||||||
|
is_unsqueezed = True
|
||||||
|
if len(key.shape) == 3:
|
||||||
|
key = key.unsqueeze(0)
|
||||||
|
if len(value.shape) == 3:
|
||||||
|
value = value.unsqueeze(0)
|
||||||
|
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
|
||||||
|
|
||||||
# Slice SDPA
|
# Slice SDPA
|
||||||
if do_split:
|
if do_batch_split:
|
||||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
batch_size, attn_heads, query_len, _ = query.shape
|
||||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
_, _, _, head_dim = value.shape
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
|
||||||
start_idx = i * split_slice_size
|
if attn_mask is not None:
|
||||||
end_idx = (i + 1) * split_slice_size
|
attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
|
||||||
if do_split_2:
|
for ib in range(batch_size // split_batch_size):
|
||||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
start_idx = ib * split_batch_size
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
end_idx = (ib + 1) * split_batch_size
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
if do_head_split:
|
||||||
if do_split_3:
|
for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name
|
||||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
start_idx_h = ih * split_head_size
|
||||||
start_idx_3 = i3 * split_3_slice_size
|
end_idx_h = (ih + 1) * split_head_size
|
||||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
if do_query_split:
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
|
for iq in range(query_len // split_query_size): # pylint: disable=invalid-name
|
||||||
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
start_idx_q = iq * split_query_size
|
||||||
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
end_idx_q = (iq + 1) * split_query_size
|
||||||
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
|
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
|
||||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
|
query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
|
||||||
|
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||||
|
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||||
|
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
|
||||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
|
hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
|
||||||
query[start_idx:end_idx, start_idx_2:end_idx_2],
|
query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||||
key[start_idx:end_idx, start_idx_2:end_idx_2],
|
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||||
value[start_idx:end_idx, start_idx_2:end_idx_2],
|
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
|
||||||
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
|
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
|
||||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
|
hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
|
||||||
query[start_idx:end_idx],
|
query[start_idx:end_idx, :, :, :],
|
||||||
key[start_idx:end_idx],
|
key[start_idx:end_idx, :, :, :],
|
||||||
value[start_idx:end_idx],
|
value[start_idx:end_idx, :, :, :],
|
||||||
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
|
attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
|
||||||
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
dropout_p=dropout_p, is_causal=is_causal, **kwargs
|
||||||
)
|
)
|
||||||
torch.xpu.synchronize(query.device)
|
torch.xpu.synchronize(query.device)
|
||||||
else:
|
else:
|
||||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||||
|
if is_unsqueezed:
|
||||||
|
hidden_states.squeeze(0)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -1,312 +1,47 @@
|
|||||||
import os
|
from functools import wraps
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
import diffusers # pylint: disable=import-error
|
||||||
import diffusers #0.24.0 # pylint: disable=import-error
|
|
||||||
from diffusers.models.attention_processor import Attention
|
|
||||||
from diffusers.utils import USE_PEFT_BACKEND
|
|
||||||
from functools import cache
|
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
|
|
||||||
|
|
||||||
@cache
|
# Diffusers FreeU
|
||||||
def find_slice_size(slice_size, slice_block_size):
|
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
|
||||||
while (slice_size * slice_block_size) > attention_slice_rate:
|
@wraps(diffusers.utils.torch_utils.fourier_filter)
|
||||||
slice_size = slice_size // 2
|
def fourier_filter(x_in, threshold, scale):
|
||||||
if slice_size <= 1:
|
return_dtype = x_in.dtype
|
||||||
slice_size = 1
|
return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)
|
||||||
break
|
|
||||||
return slice_size
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
|
|
||||||
if len(query_shape) == 3:
|
|
||||||
batch_size_attention, query_tokens, shape_three = query_shape
|
|
||||||
shape_four = 1
|
|
||||||
else:
|
|
||||||
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
|
|
||||||
if slice_size is not None:
|
|
||||||
batch_size_attention = slice_size
|
|
||||||
|
|
||||||
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size
|
# fp64 error
|
||||||
block_size = batch_size_attention * slice_block_size
|
class FluxPosEmbed(torch.nn.Module):
|
||||||
|
def __init__(self, theta: int, axes_dim):
|
||||||
|
super().__init__()
|
||||||
|
self.theta = theta
|
||||||
|
self.axes_dim = axes_dim
|
||||||
|
|
||||||
split_slice_size = batch_size_attention
|
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||||
split_2_slice_size = query_tokens
|
n_axes = ids.shape[-1]
|
||||||
split_3_slice_size = shape_three
|
cos_out = []
|
||||||
|
sin_out = []
|
||||||
do_split = False
|
pos = ids.float()
|
||||||
do_split_2 = False
|
for i in range(n_axes):
|
||||||
do_split_3 = False
|
cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
|
||||||
|
self.axes_dim[i],
|
||||||
if query_device_type != "xpu":
|
pos[:, i],
|
||||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
theta=self.theta,
|
||||||
|
repeat_interleave_real=True,
|
||||||
if block_size > attention_slice_rate:
|
use_real=True,
|
||||||
do_split = True
|
freqs_dtype=torch.float32,
|
||||||
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
|
|
||||||
if split_slice_size * slice_block_size > attention_slice_rate:
|
|
||||||
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
|
|
||||||
do_split_2 = True
|
|
||||||
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
|
|
||||||
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
|
|
||||||
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
|
|
||||||
do_split_3 = True
|
|
||||||
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
|
|
||||||
|
|
||||||
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
|
|
||||||
|
|
||||||
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
|
|
||||||
r"""
|
|
||||||
Processor for implementing sliced attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
slice_size (`int`, *optional*):
|
|
||||||
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
|
|
||||||
`attention_head_dim` must be a multiple of the `slice_size`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, slice_size):
|
|
||||||
self.slice_size = slice_size
|
|
||||||
|
|
||||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
|
||||||
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
batch_size, channel, height, width = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
)
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
cos_out.append(cos)
|
||||||
|
sin_out.append(sin)
|
||||||
if attn.group_norm is not None:
|
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||||
|
return freqs_cos, freqs_sin
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
dim = query.shape[-1]
|
|
||||||
query = attn.head_to_batch_dim(query)
|
|
||||||
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
elif attn.norm_cross:
|
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states)
|
|
||||||
value = attn.to_v(encoder_hidden_states)
|
|
||||||
key = attn.head_to_batch_dim(key)
|
|
||||||
value = attn.head_to_batch_dim(value)
|
|
||||||
|
|
||||||
batch_size_attention, query_tokens, shape_three = query.shape
|
|
||||||
hidden_states = torch.zeros(
|
|
||||||
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
####################################################################
|
|
||||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
|
||||||
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
|
|
||||||
|
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
|
||||||
start_idx = i * split_slice_size
|
|
||||||
end_idx = (i + 1) * split_slice_size
|
|
||||||
if do_split_2:
|
|
||||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
|
||||||
if do_split_3:
|
|
||||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
|
||||||
start_idx_3 = i3 * split_3_slice_size
|
|
||||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
|
||||||
|
|
||||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
|
||||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
|
||||||
del query_slice
|
|
||||||
del key_slice
|
|
||||||
del attn_mask_slice
|
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
|
||||||
del attn_slice
|
|
||||||
else:
|
|
||||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
|
||||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
|
||||||
del query_slice
|
|
||||||
del key_slice
|
|
||||||
del attn_mask_slice
|
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
|
||||||
del attn_slice
|
|
||||||
torch.xpu.synchronize(query.device)
|
|
||||||
else:
|
|
||||||
query_slice = query[start_idx:end_idx]
|
|
||||||
key_slice = key[start_idx:end_idx]
|
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
|
||||||
del query_slice
|
|
||||||
del key_slice
|
|
||||||
del attn_mask_slice
|
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx] = attn_slice
|
|
||||||
del attn_slice
|
|
||||||
####################################################################
|
|
||||||
|
|
||||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
||||||
|
|
||||||
if attn.residual_connection:
|
|
||||||
hidden_states = hidden_states + residual
|
|
||||||
|
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class AttnProcessor:
|
def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
|
||||||
r"""
|
diffusers.utils.torch_utils.fourier_filter = fourier_filter
|
||||||
Default processor for performing attention-related computations.
|
if not device_supports_fp64:
|
||||||
"""
|
diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
|
||||||
|
|
||||||
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
|
|
||||||
encoder_hidden_states=None, attention_mask=None,
|
|
||||||
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
|
|
||||||
|
|
||||||
residual = hidden_states
|
|
||||||
|
|
||||||
args = () if USE_PEFT_BACKEND else (scale,)
|
|
||||||
|
|
||||||
if attn.spatial_norm is not None:
|
|
||||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
batch_size, channel, height, width = hidden_states.shape
|
|
||||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
|
||||||
|
|
||||||
batch_size, sequence_length, _ = (
|
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
|
||||||
)
|
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
|
|
||||||
if attn.group_norm is not None:
|
|
||||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states, *args)
|
|
||||||
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
elif attn.norm_cross:
|
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states, *args)
|
|
||||||
value = attn.to_v(encoder_hidden_states, *args)
|
|
||||||
|
|
||||||
query = attn.head_to_batch_dim(query)
|
|
||||||
key = attn.head_to_batch_dim(key)
|
|
||||||
value = attn.head_to_batch_dim(value)
|
|
||||||
|
|
||||||
####################################################################
|
|
||||||
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
|
|
||||||
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
|
|
||||||
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
|
|
||||||
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
|
|
||||||
|
|
||||||
if do_split:
|
|
||||||
for i in range(batch_size_attention // split_slice_size):
|
|
||||||
start_idx = i * split_slice_size
|
|
||||||
end_idx = (i + 1) * split_slice_size
|
|
||||||
if do_split_2:
|
|
||||||
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
|
|
||||||
start_idx_2 = i2 * split_2_slice_size
|
|
||||||
end_idx_2 = (i2 + 1) * split_2_slice_size
|
|
||||||
if do_split_3:
|
|
||||||
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
|
|
||||||
start_idx_3 = i3 * split_3_slice_size
|
|
||||||
end_idx_3 = (i3 + 1) * split_3_slice_size
|
|
||||||
|
|
||||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
|
||||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
|
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
|
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
|
||||||
del query_slice
|
|
||||||
del key_slice
|
|
||||||
del attn_mask_slice
|
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
|
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
|
|
||||||
del attn_slice
|
|
||||||
else:
|
|
||||||
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
|
|
||||||
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
|
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
|
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
|
||||||
del query_slice
|
|
||||||
del key_slice
|
|
||||||
del attn_mask_slice
|
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
|
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
|
|
||||||
del attn_slice
|
|
||||||
else:
|
|
||||||
query_slice = query[start_idx:end_idx]
|
|
||||||
key_slice = key[start_idx:end_idx]
|
|
||||||
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
|
|
||||||
|
|
||||||
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
|
|
||||||
del query_slice
|
|
||||||
del key_slice
|
|
||||||
del attn_mask_slice
|
|
||||||
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
|
|
||||||
|
|
||||||
hidden_states[start_idx:end_idx] = attn_slice
|
|
||||||
del attn_slice
|
|
||||||
torch.xpu.synchronize(query.device)
|
|
||||||
else:
|
|
||||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
|
||||||
hidden_states = torch.bmm(attention_probs, value)
|
|
||||||
####################################################################
|
|
||||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
||||||
|
|
||||||
if attn.residual_connection:
|
|
||||||
hidden_states = hidden_states + residual
|
|
||||||
|
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def ipex_diffusers():
|
|
||||||
#ARC GPUs can't allocate more than 4GB to a single block:
|
|
||||||
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
|
|
||||||
diffusers.models.attention_processor.AttnProcessor = AttnProcessor
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
|
|||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||||
|
|
||||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
|
||||||
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
||||||
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
||||||
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
||||||
|
|||||||
@@ -2,10 +2,19 @@ import os
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
import torch
|
import torch
|
||||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
|
||||||
|
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
|
||||||
|
try:
|
||||||
|
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
|
||||||
|
del x
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
can_allocate_plus_4gb = True
|
||||||
|
except Exception:
|
||||||
|
can_allocate_plus_4gb = False
|
||||||
|
else:
|
||||||
|
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
|
||||||
|
|
||||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
|
||||||
|
|
||||||
@@ -26,7 +35,7 @@ def check_device(device):
|
|||||||
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
|
||||||
|
|
||||||
def return_xpu(device):
|
def return_xpu(device):
|
||||||
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu"
|
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
|
||||||
|
|
||||||
|
|
||||||
# Autocast
|
# Autocast
|
||||||
@@ -42,7 +51,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
|
|||||||
original_interpolate = torch.nn.functional.interpolate
|
original_interpolate = torch.nn.functional.interpolate
|
||||||
@wraps(torch.nn.functional.interpolate)
|
@wraps(torch.nn.functional.interpolate)
|
||||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||||
if antialias or align_corners is not None or mode == 'bicubic':
|
if mode in {'bicubic', 'bilinear'}:
|
||||||
return_device = tensor.device
|
return_device = tensor.device
|
||||||
return_dtype = tensor.dtype
|
return_dtype = tensor.dtype
|
||||||
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
|
||||||
@@ -73,35 +82,46 @@ def as_tensor(data, dtype=None, device=None):
|
|||||||
return original_as_tensor(data, dtype=dtype, device=device)
|
return original_as_tensor(data, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
|
||||||
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None:
|
if can_allocate_plus_4gb:
|
||||||
original_torch_bmm = torch.bmm
|
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
else:
|
else:
|
||||||
# 32 bit attention workarounds for Alchemist:
|
# 32 bit attention workarounds for Alchemist:
|
||||||
try:
|
try:
|
||||||
from .attention import torch_bmm_32_bit as original_torch_bmm
|
from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
|
||||||
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
|
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
original_torch_bmm = torch.bmm
|
|
||||||
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
|
||||||
|
|
||||||
|
|
||||||
# Data Type Errors:
|
|
||||||
@wraps(torch.bmm)
|
|
||||||
def torch_bmm(input, mat2, *, out=None):
|
|
||||||
if input.dtype != mat2.dtype:
|
|
||||||
mat2 = mat2.to(input.dtype)
|
|
||||||
return original_torch_bmm(input, mat2, out=out)
|
|
||||||
|
|
||||||
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
@wraps(torch.nn.functional.scaled_dot_product_attention)
|
||||||
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
|
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
|
||||||
if query.dtype != key.dtype:
|
if query.dtype != key.dtype:
|
||||||
key = key.to(dtype=query.dtype)
|
key = key.to(dtype=query.dtype)
|
||||||
if query.dtype != value.dtype:
|
if query.dtype != value.dtype:
|
||||||
value = value.to(dtype=query.dtype)
|
value = value.to(dtype=query.dtype)
|
||||||
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
if attn_mask is not None and query.dtype != attn_mask.dtype:
|
||||||
attn_mask = attn_mask.to(dtype=query.dtype)
|
attn_mask = attn_mask.to(dtype=query.dtype)
|
||||||
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
|
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
|
||||||
|
|
||||||
|
# Data Type Errors:
|
||||||
|
original_torch_bmm = torch.bmm
|
||||||
|
@wraps(torch.bmm)
|
||||||
|
def torch_bmm(input, mat2, *, out=None):
|
||||||
|
if input.dtype != mat2.dtype:
|
||||||
|
mat2 = mat2.to(input.dtype)
|
||||||
|
return original_torch_bmm(input, mat2, out=out)
|
||||||
|
|
||||||
|
# Diffusers FreeU
|
||||||
|
original_fft_fftn = torch.fft.fftn
|
||||||
|
@wraps(torch.fft.fftn)
|
||||||
|
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
|
||||||
|
return_dtype = input.dtype
|
||||||
|
return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
|
||||||
|
|
||||||
|
# Diffusers FreeU
|
||||||
|
original_fft_ifftn = torch.fft.ifftn
|
||||||
|
@wraps(torch.fft.ifftn)
|
||||||
|
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
|
||||||
|
return_dtype = input.dtype
|
||||||
|
return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
|
||||||
|
|
||||||
# A1111 FP16
|
# A1111 FP16
|
||||||
original_functional_group_norm = torch.nn.functional.group_norm
|
original_functional_group_norm = torch.nn.functional.group_norm
|
||||||
@@ -133,6 +153,15 @@ def functional_linear(input, weight, bias=None):
|
|||||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
return original_functional_linear(input, weight, bias=bias)
|
return original_functional_linear(input, weight, bias=bias)
|
||||||
|
|
||||||
|
original_functional_conv1d = torch.nn.functional.conv1d
|
||||||
|
@wraps(torch.nn.functional.conv1d)
|
||||||
|
def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
|
if input.dtype != weight.data.dtype:
|
||||||
|
input = input.to(dtype=weight.data.dtype)
|
||||||
|
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||||
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
|
return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||||
|
|
||||||
original_functional_conv2d = torch.nn.functional.conv2d
|
original_functional_conv2d = torch.nn.functional.conv2d
|
||||||
@wraps(torch.nn.functional.conv2d)
|
@wraps(torch.nn.functional.conv2d)
|
||||||
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
@@ -142,14 +171,15 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
|
|||||||
bias.data = bias.data.to(dtype=weight.data.dtype)
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||||
|
|
||||||
# A1111 Embedding BF16
|
# LTX Video
|
||||||
original_torch_cat = torch.cat
|
original_functional_conv3d = torch.nn.functional.conv3d
|
||||||
@wraps(torch.cat)
|
@wraps(torch.nn.functional.conv3d)
|
||||||
def torch_cat(tensor, *args, **kwargs):
|
def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
|
||||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
if input.dtype != weight.data.dtype:
|
||||||
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs)
|
input = input.to(dtype=weight.data.dtype)
|
||||||
else:
|
if bias is not None and bias.data.dtype != weight.data.dtype:
|
||||||
return original_torch_cat(tensor, *args, **kwargs)
|
bias.data = bias.data.to(dtype=weight.data.dtype)
|
||||||
|
return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
|
||||||
|
|
||||||
# SwinIR BF16:
|
# SwinIR BF16:
|
||||||
original_functional_pad = torch.nn.functional.pad
|
original_functional_pad = torch.nn.functional.pad
|
||||||
@@ -164,6 +194,7 @@ def functional_pad(input, pad, mode='constant', value=None):
|
|||||||
original_torch_tensor = torch.tensor
|
original_torch_tensor = torch.tensor
|
||||||
@wraps(torch.tensor)
|
@wraps(torch.tensor)
|
||||||
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
|
||||||
|
global device_supports_fp64
|
||||||
if check_device(device):
|
if check_device(device):
|
||||||
device = return_xpu(device)
|
device = return_xpu(device)
|
||||||
if not device_supports_fp64:
|
if not device_supports_fp64:
|
||||||
@@ -227,7 +258,7 @@ def torch_empty(*args, device=None, **kwargs):
|
|||||||
original_torch_randn = torch.randn
|
original_torch_randn = torch.randn
|
||||||
@wraps(torch.randn)
|
@wraps(torch.randn)
|
||||||
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
def torch_randn(*args, device=None, dtype=None, **kwargs):
|
||||||
if dtype == bytes:
|
if dtype is bytes:
|
||||||
dtype = None
|
dtype = None
|
||||||
if check_device(device):
|
if check_device(device):
|
||||||
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
|
||||||
@@ -250,6 +281,14 @@ def torch_zeros(*args, device=None, **kwargs):
|
|||||||
else:
|
else:
|
||||||
return original_torch_zeros(*args, device=device, **kwargs)
|
return original_torch_zeros(*args, device=device, **kwargs)
|
||||||
|
|
||||||
|
original_torch_full = torch.full
|
||||||
|
@wraps(torch.full)
|
||||||
|
def torch_full(*args, device=None, **kwargs):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_full(*args, device=return_xpu(device), **kwargs)
|
||||||
|
else:
|
||||||
|
return original_torch_full(*args, device=device, **kwargs)
|
||||||
|
|
||||||
original_torch_linspace = torch.linspace
|
original_torch_linspace = torch.linspace
|
||||||
@wraps(torch.linspace)
|
@wraps(torch.linspace)
|
||||||
def torch_linspace(*args, device=None, **kwargs):
|
def torch_linspace(*args, device=None, **kwargs):
|
||||||
@@ -258,14 +297,6 @@ def torch_linspace(*args, device=None, **kwargs):
|
|||||||
else:
|
else:
|
||||||
return original_torch_linspace(*args, device=device, **kwargs)
|
return original_torch_linspace(*args, device=device, **kwargs)
|
||||||
|
|
||||||
original_torch_Generator = torch.Generator
|
|
||||||
@wraps(torch.Generator)
|
|
||||||
def torch_Generator(device=None):
|
|
||||||
if check_device(device):
|
|
||||||
return original_torch_Generator(return_xpu(device))
|
|
||||||
else:
|
|
||||||
return original_torch_Generator(device)
|
|
||||||
|
|
||||||
original_torch_load = torch.load
|
original_torch_load = torch.load
|
||||||
@wraps(torch.load)
|
@wraps(torch.load)
|
||||||
def torch_load(f, map_location=None, *args, **kwargs):
|
def torch_load(f, map_location=None, *args, **kwargs):
|
||||||
@@ -276,9 +307,27 @@ def torch_load(f, map_location=None, *args, **kwargs):
|
|||||||
else:
|
else:
|
||||||
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
return original_torch_load(f, *args, map_location=map_location, **kwargs)
|
||||||
|
|
||||||
|
original_torch_Generator = torch.Generator
|
||||||
|
@wraps(torch.Generator)
|
||||||
|
def torch_Generator(device=None):
|
||||||
|
if check_device(device):
|
||||||
|
return original_torch_Generator(return_xpu(device))
|
||||||
|
else:
|
||||||
|
return original_torch_Generator(device)
|
||||||
|
|
||||||
|
@wraps(torch.cuda.synchronize)
|
||||||
|
def torch_cuda_synchronize(device=None):
|
||||||
|
if check_device(device):
|
||||||
|
return torch.xpu.synchronize(return_xpu(device))
|
||||||
|
else:
|
||||||
|
return torch.xpu.synchronize(device)
|
||||||
|
|
||||||
|
|
||||||
# Hijack Functions:
|
# Hijack Functions:
|
||||||
def ipex_hijacks():
|
def ipex_hijacks(legacy=True):
|
||||||
|
global device_supports_fp64, can_allocate_plus_4gb
|
||||||
|
if legacy and float(torch.__version__[:3]) < 2.5:
|
||||||
|
torch.nn.functional.interpolate = interpolate
|
||||||
torch.tensor = torch_tensor
|
torch.tensor = torch_tensor
|
||||||
torch.Tensor.to = Tensor_to
|
torch.Tensor.to = Tensor_to
|
||||||
torch.Tensor.cuda = Tensor_cuda
|
torch.Tensor.cuda = Tensor_cuda
|
||||||
@@ -289,9 +338,11 @@ def ipex_hijacks():
|
|||||||
torch.randn = torch_randn
|
torch.randn = torch_randn
|
||||||
torch.ones = torch_ones
|
torch.ones = torch_ones
|
||||||
torch.zeros = torch_zeros
|
torch.zeros = torch_zeros
|
||||||
|
torch.full = torch_full
|
||||||
torch.linspace = torch_linspace
|
torch.linspace = torch_linspace
|
||||||
torch.Generator = torch_Generator
|
|
||||||
torch.load = torch_load
|
torch.load = torch_load
|
||||||
|
torch.Generator = torch_Generator
|
||||||
|
torch.cuda.synchronize = torch_cuda_synchronize
|
||||||
|
|
||||||
torch.backends.cuda.sdp_kernel = return_null_context
|
torch.backends.cuda.sdp_kernel = return_null_context
|
||||||
torch.nn.DataParallel = DummyDataParallel
|
torch.nn.DataParallel = DummyDataParallel
|
||||||
@@ -302,12 +353,15 @@ def ipex_hijacks():
|
|||||||
torch.nn.functional.group_norm = functional_group_norm
|
torch.nn.functional.group_norm = functional_group_norm
|
||||||
torch.nn.functional.layer_norm = functional_layer_norm
|
torch.nn.functional.layer_norm = functional_layer_norm
|
||||||
torch.nn.functional.linear = functional_linear
|
torch.nn.functional.linear = functional_linear
|
||||||
|
torch.nn.functional.conv1d = functional_conv1d
|
||||||
torch.nn.functional.conv2d = functional_conv2d
|
torch.nn.functional.conv2d = functional_conv2d
|
||||||
torch.nn.functional.interpolate = interpolate
|
torch.nn.functional.conv3d = functional_conv3d
|
||||||
torch.nn.functional.pad = functional_pad
|
torch.nn.functional.pad = functional_pad
|
||||||
|
|
||||||
torch.bmm = torch_bmm
|
torch.bmm = torch_bmm
|
||||||
torch.cat = torch_cat
|
torch.fft.fftn = fft_fftn
|
||||||
|
torch.fft.ifftn = fft_ifftn
|
||||||
if not device_supports_fp64:
|
if not device_supports_fp64:
|
||||||
torch.from_numpy = from_numpy
|
torch.from_numpy = from_numpy
|
||||||
torch.as_tensor = as_tensor
|
torch.as_tensor = as_tensor
|
||||||
|
return device_supports_fp64, can_allocate_plus_4gb
|
||||||
|
|||||||
@@ -13,17 +13,7 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
from typing import (
|
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
List,
|
|
||||||
NamedTuple,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Union
|
|
||||||
)
|
|
||||||
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
|
||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
@@ -146,12 +136,13 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
|||||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
|
||||||
|
|
||||||
|
|
||||||
def split_train_val(
|
def split_train_val(
|
||||||
paths: List[str],
|
paths: List[str],
|
||||||
sizes: List[Optional[Tuple[int, int]]],
|
sizes: List[Optional[Tuple[int, int]]],
|
||||||
is_training_dataset: bool,
|
is_training_dataset: bool,
|
||||||
validation_split: float,
|
validation_split: float,
|
||||||
validation_seed: int | None
|
validation_seed: int | None,
|
||||||
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
|
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
|
||||||
"""
|
"""
|
||||||
Split the dataset into train and validation
|
Split the dataset into train and validation
|
||||||
@@ -1999,11 +1990,7 @@ class DreamBoothDataset(BaseDataset):
|
|||||||
# required for training images dataset of regularization images
|
# required for training images dataset of regularization images
|
||||||
else:
|
else:
|
||||||
img_paths, sizes = split_train_val(
|
img_paths, sizes = split_train_val(
|
||||||
img_paths,
|
img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
|
||||||
sizes,
|
|
||||||
self.is_training_dataset,
|
|
||||||
self.validation_split,
|
|
||||||
self.validation_seed
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||||
@@ -5944,12 +5931,17 @@ def save_sd_model_on_train_end_common(
|
|||||||
|
|
||||||
|
|
||||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
|
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
|
||||||
|
if min_timestep < max_timestep:
|
||||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
|
||||||
|
else:
|
||||||
|
timesteps = torch.full((b_size,), max_timestep, device="cpu")
|
||||||
timesteps = timesteps.long().to(device)
|
timesteps = timesteps.long().to(device)
|
||||||
return timesteps
|
return timesteps
|
||||||
|
|
||||||
|
|
||||||
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
def get_noise_noisy_latents_and_timesteps(
|
||||||
|
args, noise_scheduler, latents: torch.FloatTensor
|
||||||
|
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
noise = torch.randn_like(latents, device=latents.device)
|
||||||
if args.noise_offset:
|
if args.noise_offset:
|
||||||
@@ -6459,12 +6451,16 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
|
|||||||
|
|
||||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||||
import wandb
|
import wandb
|
||||||
|
|
||||||
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
|
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
|
||||||
|
|
||||||
# Define specific metrics to handle validation and epochs "steps"
|
# Define specific metrics to handle validation and epochs "steps"
|
||||||
wandb_tracker.define_metric("epoch", hidden=True)
|
wandb_tracker.define_metric("epoch", hidden=True)
|
||||||
wandb_tracker.define_metric("val_step", hidden=True)
|
wandb_tracker.define_metric("val_step", hidden=True)
|
||||||
|
|
||||||
|
wandb_tracker.define_metric("global_step", hidden=True)
|
||||||
|
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.sample_prompts_te_outputs = None
|
self.sample_prompts_te_outputs = None
|
||||||
|
|
||||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
def assert_extra_args(
|
||||||
|
self,
|
||||||
|
args,
|
||||||
|
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||||
|
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||||
|
):
|
||||||
# super().assert_extra_args(args, train_dataset_group)
|
# super().assert_extra_args(args, train_dataset_group)
|
||||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||||
|
|
||||||
@@ -299,7 +304,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
|
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
|
||||||
return noise_scheduler
|
return noise_scheduler
|
||||||
|
|
||||||
def encode_images_to_latents(self, args, accelerator, vae, images):
|
def encode_images_to_latents(self, args, vae, images):
|
||||||
return vae.encode(images)
|
return vae.encode(images)
|
||||||
|
|
||||||
def shift_scale_latents(self, args, latents):
|
def shift_scale_latents(self, args, latents):
|
||||||
@@ -317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
network,
|
network,
|
||||||
weight_dtype,
|
weight_dtype,
|
||||||
train_unet,
|
train_unet,
|
||||||
is_train=True
|
is_train=True,
|
||||||
):
|
):
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
noise = torch.randn_like(latents)
|
noise = torch.randn_like(latents)
|
||||||
@@ -445,14 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
|||||||
text_encoder.to(te_weight_dtype) # fp8
|
text_encoder.to(te_weight_dtype) # fp8
|
||||||
prepare_fp8(text_encoder, weight_dtype)
|
prepare_fp8(text_encoder, weight_dtype)
|
||||||
|
|
||||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
|
||||||
# drop cached text encoder outputs
|
# drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
|
||||||
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
|
||||||
if text_encoder_outputs_list is not None:
|
if text_encoder_outputs_list is not None:
|
||||||
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||||
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
|
||||||
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
batch["text_encoder_outputs_list"] = text_encoder_outputs_list
|
||||||
|
|
||||||
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||||
|
if self.is_swapping_blocks:
|
||||||
|
# prepare for next forward: because backward pass is not called, we need to prepare it here
|
||||||
|
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
|
||||||
|
|
||||||
def prepare_unet_with_accelerator(
|
def prepare_unet_with_accelerator(
|
||||||
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
|
|||||||
263
train_network.py
263
train_network.py
@@ -9,6 +9,7 @@ import random
|
|||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
from multiprocessing import Value
|
from multiprocessing import Value
|
||||||
|
import numpy as np
|
||||||
import toml
|
import toml
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -100,9 +101,7 @@ class NetworkTrainer:
|
|||||||
if (
|
if (
|
||||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
||||||
): # tracking d*lr value of unet.
|
): # tracking d*lr value of unet.
|
||||||
logs["lr/d*lr"] = (
|
logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
|
||||||
optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
idx = 0
|
idx = 0
|
||||||
if not args.network_train_unet_only:
|
if not args.network_train_unet_only:
|
||||||
@@ -115,16 +114,56 @@ class NetworkTrainer:
|
|||||||
logs[f"lr/d*lr/group{i}"] = (
|
logs[f"lr/d*lr/group{i}"] = (
|
||||||
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
|
||||||
)
|
)
|
||||||
if (
|
if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
|
||||||
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
|
logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
||||||
):
|
|
||||||
logs[f"lr/d*lr/group{i}"] = (
|
|
||||||
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return logs
|
return logs
|
||||||
|
|
||||||
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
|
def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
|
||||||
|
self.accelerator_logging(accelerator, logs, global_step, global_step, epoch)
|
||||||
|
|
||||||
|
def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
|
||||||
|
self.accelerator_logging(accelerator, logs, epoch, global_step, epoch)
|
||||||
|
|
||||||
|
def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
|
||||||
|
self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step)
|
||||||
|
|
||||||
|
def accelerator_logging(
|
||||||
|
self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
step_value is for tensorboard, other values are for wandb
|
||||||
|
"""
|
||||||
|
tensorboard_tracker = None
|
||||||
|
wandb_tracker = None
|
||||||
|
other_trackers = []
|
||||||
|
for tracker in accelerator.trackers:
|
||||||
|
if tracker.name == "tensorboard":
|
||||||
|
tensorboard_tracker = accelerator.get_tracker("tensorboard")
|
||||||
|
elif tracker.name == "wandb":
|
||||||
|
wandb_tracker = accelerator.get_tracker("wandb")
|
||||||
|
else:
|
||||||
|
other_trackers.append(accelerator.get_tracker(tracker.name))
|
||||||
|
|
||||||
|
if tensorboard_tracker is not None:
|
||||||
|
tensorboard_tracker.log(logs, step=step_value)
|
||||||
|
|
||||||
|
if wandb_tracker is not None:
|
||||||
|
logs["global_step"] = global_step
|
||||||
|
logs["epoch"] = epoch
|
||||||
|
if val_step is not None:
|
||||||
|
logs["val_step"] = val_step
|
||||||
|
wandb_tracker.log(logs)
|
||||||
|
|
||||||
|
for tracker in other_trackers:
|
||||||
|
tracker.log(logs, step=step_value)
|
||||||
|
|
||||||
|
def assert_extra_args(
|
||||||
|
self,
|
||||||
|
args,
|
||||||
|
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
|
||||||
|
val_dataset_group: Optional[train_util.DatasetGroup],
|
||||||
|
):
|
||||||
train_dataset_group.verify_bucket_reso_steps(64)
|
train_dataset_group.verify_bucket_reso_steps(64)
|
||||||
if val_dataset_group is not None:
|
if val_dataset_group is not None:
|
||||||
val_dataset_group.verify_bucket_reso_steps(64)
|
val_dataset_group.verify_bucket_reso_steps(64)
|
||||||
@@ -219,7 +258,7 @@ class NetworkTrainer:
|
|||||||
network,
|
network,
|
||||||
weight_dtype,
|
weight_dtype,
|
||||||
train_unet,
|
train_unet,
|
||||||
is_train=True
|
is_train=True,
|
||||||
):
|
):
|
||||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||||
# with noise offset and/or multires noise if specified
|
# with noise offset and/or multires noise if specified
|
||||||
@@ -309,7 +348,10 @@ class NetworkTrainer:
|
|||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
return accelerator.prepare(unet)
|
return accelerator.prepare(unet)
|
||||||
|
|
||||||
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
@@ -330,7 +372,7 @@ class NetworkTrainer:
|
|||||||
tokenize_strategy: strategy_base.TokenizeStrategy,
|
tokenize_strategy: strategy_base.TokenizeStrategy,
|
||||||
is_train=True,
|
is_train=True,
|
||||||
train_text_encoder=True,
|
train_text_encoder=True,
|
||||||
train_unet=True
|
train_unet=True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Process a batch for the network
|
Process a batch for the network
|
||||||
@@ -397,7 +439,7 @@ class NetworkTrainer:
|
|||||||
network,
|
network,
|
||||||
weight_dtype,
|
weight_dtype,
|
||||||
train_unet,
|
train_unet,
|
||||||
is_train=is_train
|
is_train=is_train,
|
||||||
)
|
)
|
||||||
|
|
||||||
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
|
||||||
@@ -900,7 +942,9 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
accelerator.print("running training / 学習開始")
|
accelerator.print("running training / 学習開始")
|
||||||
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
|
||||||
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}")
|
accelerator.print(
|
||||||
|
f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
|
||||||
|
)
|
||||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||||
@@ -1243,12 +1287,6 @@ class NetworkTrainer:
|
|||||||
# log empty object to commit the sample images to wandb
|
# log empty object to commit the sample images to wandb
|
||||||
accelerator.log({}, step=0)
|
accelerator.log({}, step=0)
|
||||||
|
|
||||||
validation_steps = (
|
|
||||||
min(args.max_validation_steps, len(val_dataloader))
|
|
||||||
if args.max_validation_steps is not None
|
|
||||||
else len(val_dataloader)
|
|
||||||
)
|
|
||||||
|
|
||||||
# training loop
|
# training loop
|
||||||
if initial_step > 0: # only if skip_until_initial_step is specified
|
if initial_step > 0: # only if skip_until_initial_step is specified
|
||||||
for skip_epoch in range(epoch_to_start): # skip epochs
|
for skip_epoch in range(epoch_to_start): # skip epochs
|
||||||
@@ -1271,13 +1309,53 @@ class NetworkTrainer:
|
|||||||
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
validation_steps = (
|
||||||
|
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||||
|
)
|
||||||
|
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
|
||||||
|
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
||||||
|
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
|
||||||
|
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
|
||||||
|
validation_total_steps = validation_steps * len(validation_timesteps)
|
||||||
|
original_args_min_timestep = args.min_timestep
|
||||||
|
original_args_max_timestep = args.max_timestep
|
||||||
|
|
||||||
|
def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
|
||||||
|
cpu_rng_state = torch.get_rng_state()
|
||||||
|
if accelerator.device.type == "cuda":
|
||||||
|
gpu_rng_state = torch.cuda.get_rng_state()
|
||||||
|
elif accelerator.device.type == "xpu":
|
||||||
|
gpu_rng_state = torch.xpu.get_rng_state()
|
||||||
|
elif accelerator.device.type == "mps":
|
||||||
|
gpu_rng_state = torch.cuda.get_rng_state()
|
||||||
|
else:
|
||||||
|
gpu_rng_state = None
|
||||||
|
python_rng_state = random.getstate()
|
||||||
|
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
return (cpu_rng_state, gpu_rng_state, python_rng_state)
|
||||||
|
|
||||||
|
def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
|
||||||
|
cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
|
||||||
|
torch.set_rng_state(cpu_rng_state)
|
||||||
|
if gpu_rng_state is not None:
|
||||||
|
if accelerator.device.type == "cuda":
|
||||||
|
torch.cuda.set_rng_state(gpu_rng_state)
|
||||||
|
elif accelerator.device.type == "xpu":
|
||||||
|
torch.xpu.set_rng_state(gpu_rng_state)
|
||||||
|
elif accelerator.device.type == "mps":
|
||||||
|
torch.cuda.set_rng_state(gpu_rng_state)
|
||||||
|
random.setstate(python_rng_state)
|
||||||
|
|
||||||
for epoch in range(epoch_to_start, num_train_epochs):
|
for epoch in range(epoch_to_start, num_train_epochs):
|
||||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
|
||||||
current_epoch.value = epoch + 1
|
current_epoch.value = epoch + 1
|
||||||
|
|
||||||
metadata["ss_epoch"] = str(epoch + 1)
|
metadata["ss_epoch"] = str(epoch + 1)
|
||||||
|
|
||||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here
|
||||||
|
|
||||||
# TRAINING
|
# TRAINING
|
||||||
skipped_dataloader = None
|
skipped_dataloader = None
|
||||||
@@ -1294,8 +1372,8 @@ class NetworkTrainer:
|
|||||||
with accelerator.accumulate(training_model):
|
with accelerator.accumulate(training_model):
|
||||||
on_step_start_for_network(text_encoder, unet)
|
on_step_start_for_network(text_encoder, unet)
|
||||||
|
|
||||||
# temporary, for batch processing
|
# preprocess batch for each model
|
||||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
|
||||||
|
|
||||||
loss = self.process_batch(
|
loss = self.process_batch(
|
||||||
batch,
|
batch,
|
||||||
@@ -1312,7 +1390,7 @@ class NetworkTrainer:
|
|||||||
tokenize_strategy,
|
tokenize_strategy,
|
||||||
is_train=True,
|
is_train=True,
|
||||||
train_text_encoder=train_text_encoder,
|
train_text_encoder=train_text_encoder,
|
||||||
train_unet=train_unet
|
train_unet=train_unet,
|
||||||
)
|
)
|
||||||
|
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
@@ -1369,39 +1447,35 @@ class NetworkTrainer:
|
|||||||
if args.scale_weight_norms:
|
if args.scale_weight_norms:
|
||||||
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
progress_bar.set_postfix(**{**max_mean_logs, **logs})
|
||||||
|
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = self.generate_step_logs(
|
logs = self.generate_step_logs(
|
||||||
args,
|
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
|
||||||
current_loss,
|
|
||||||
avr_loss,
|
|
||||||
lr_scheduler,
|
|
||||||
lr_descriptions,
|
|
||||||
optimizer,
|
|
||||||
keys_scaled,
|
|
||||||
mean_norm,
|
|
||||||
maximum_norm
|
|
||||||
)
|
)
|
||||||
accelerator.log(logs, step=global_step)
|
self.step_logging(accelerator, logs, global_step, epoch + 1)
|
||||||
|
|
||||||
# VALIDATION PER STEP
|
# VALIDATION PER STEP: global_step is already incremented
|
||||||
should_validate_step = (
|
# for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
|
||||||
args.validate_every_n_steps is not None
|
should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
|
||||||
and global_step != 0 # Skip first step
|
|
||||||
and global_step % args.validate_every_n_steps == 0
|
|
||||||
)
|
|
||||||
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
|
||||||
|
optimizer_eval_fn()
|
||||||
|
accelerator.unwrap_model(network).eval()
|
||||||
|
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
range(validation_steps), smoothing=0,
|
range(validation_total_steps),
|
||||||
|
smoothing=0,
|
||||||
disable=not accelerator.is_local_main_process,
|
disable=not accelerator.is_local_main_process,
|
||||||
desc="validation steps"
|
desc="validation steps",
|
||||||
)
|
)
|
||||||
|
val_timesteps_step = 0
|
||||||
for val_step, batch in enumerate(val_dataloader):
|
for val_step, batch in enumerate(val_dataloader):
|
||||||
if val_step >= validation_steps:
|
if val_step >= validation_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
# temporary, for batch processing
|
for timestep in validation_timesteps:
|
||||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||||
|
|
||||||
|
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
|
||||||
|
|
||||||
loss = self.process_batch(
|
loss = self.process_batch(
|
||||||
batch,
|
batch,
|
||||||
@@ -1417,21 +1491,23 @@ class NetworkTrainer:
|
|||||||
text_encoding_strategy,
|
text_encoding_strategy,
|
||||||
tokenize_strategy,
|
tokenize_strategy,
|
||||||
is_train=False,
|
is_train=False,
|
||||||
train_text_encoder=False,
|
train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
|
||||||
train_unet=False
|
train_unet=train_unet,
|
||||||
)
|
)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
|
||||||
val_progress_bar.update(1)
|
val_progress_bar.update(1)
|
||||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
val_progress_bar.set_postfix(
|
||||||
|
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
|
||||||
|
)
|
||||||
|
|
||||||
if is_tracking:
|
# if is_tracking:
|
||||||
logs = {
|
# logs = {f"loss/validation/step_current_{timestep}": current_loss}
|
||||||
"loss/validation/step_current": current_loss,
|
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
|
||||||
"val_step": (epoch * validation_steps) + val_step,
|
|
||||||
}
|
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
accelerator.log(logs, step=global_step)
|
val_timesteps_step += 1
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
|
||||||
@@ -1439,31 +1515,45 @@ class NetworkTrainer:
|
|||||||
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
"loss/validation/step_average": val_step_loss_recorder.moving_average,
|
||||||
"loss/validation/step_divergence": loss_validation_divergence,
|
"loss/validation/step_divergence": loss_validation_divergence,
|
||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
|
||||||
|
|
||||||
|
restore_rng_state(rng_states)
|
||||||
|
args.min_timestep = original_args_min_timestep
|
||||||
|
args.max_timestep = original_args_max_timestep
|
||||||
|
optimizer_train_fn()
|
||||||
|
accelerator.unwrap_model(network).train()
|
||||||
|
progress_bar.unpause()
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
# EPOCH VALIDATION
|
# EPOCH VALIDATION
|
||||||
should_validate_epoch = (
|
should_validate_epoch = (
|
||||||
(epoch + 1) % args.validate_every_n_epochs == 0
|
(epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
|
||||||
if args.validate_every_n_epochs is not None
|
|
||||||
else True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if should_validate_epoch and len(val_dataloader) > 0:
|
if should_validate_epoch and len(val_dataloader) > 0:
|
||||||
|
optimizer_eval_fn()
|
||||||
|
accelerator.unwrap_model(network).eval()
|
||||||
|
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
range(validation_steps), smoothing=0,
|
range(validation_total_steps),
|
||||||
|
smoothing=0,
|
||||||
disable=not accelerator.is_local_main_process,
|
disable=not accelerator.is_local_main_process,
|
||||||
desc="epoch validation steps"
|
desc="epoch validation steps",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
val_timesteps_step = 0
|
||||||
for val_step, batch in enumerate(val_dataloader):
|
for val_step, batch in enumerate(val_dataloader):
|
||||||
if val_step >= validation_steps:
|
if val_step >= validation_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
for timestep in validation_timesteps:
|
||||||
|
args.min_timestep = args.max_timestep = timestep
|
||||||
|
|
||||||
# temporary, for batch processing
|
# temporary, for batch processing
|
||||||
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
|
||||||
|
|
||||||
loss = self.process_batch(
|
loss = self.process_batch(
|
||||||
batch,
|
batch,
|
||||||
@@ -1479,22 +1569,23 @@ class NetworkTrainer:
|
|||||||
text_encoding_strategy,
|
text_encoding_strategy,
|
||||||
tokenize_strategy,
|
tokenize_strategy,
|
||||||
is_train=False,
|
is_train=False,
|
||||||
train_text_encoder=False,
|
train_text_encoder=train_text_encoder,
|
||||||
train_unet=False
|
train_unet=train_unet,
|
||||||
)
|
)
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
|
||||||
val_progress_bar.update(1)
|
val_progress_bar.update(1)
|
||||||
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
val_progress_bar.set_postfix(
|
||||||
|
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
|
||||||
|
)
|
||||||
|
|
||||||
if is_tracking:
|
# if is_tracking:
|
||||||
logs = {
|
# logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
|
||||||
"loss/validation/epoch_current": current_loss,
|
# self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
|
||||||
"epoch": epoch + 1,
|
|
||||||
"val_step": (epoch * validation_steps) + val_step
|
self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
}
|
val_timesteps_step += 1
|
||||||
accelerator.log(logs, step=global_step)
|
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||||
@@ -1502,14 +1593,20 @@ class NetworkTrainer:
|
|||||||
logs = {
|
logs = {
|
||||||
"loss/validation/epoch_average": avr_loss,
|
"loss/validation/epoch_average": avr_loss,
|
||||||
"loss/validation/epoch_divergence": loss_validation_divergence,
|
"loss/validation/epoch_divergence": loss_validation_divergence,
|
||||||
"epoch": epoch + 1
|
|
||||||
}
|
}
|
||||||
accelerator.log(logs, step=global_step)
|
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
|
||||||
|
|
||||||
|
restore_rng_state(rng_states)
|
||||||
|
args.min_timestep = original_args_min_timestep
|
||||||
|
args.max_timestep = original_args_max_timestep
|
||||||
|
optimizer_train_fn()
|
||||||
|
accelerator.unwrap_model(network).train()
|
||||||
|
progress_bar.unpause()
|
||||||
|
|
||||||
# END OF EPOCH
|
# END OF EPOCH
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
|
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||||
accelerator.log(logs, step=global_step)
|
self.epoch_logging(accelerator, logs, global_step, epoch + 1)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
@@ -1696,31 +1793,31 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
"--validation_seed",
|
"--validation_seed",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する"
|
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validation_split",
|
"--validation_split",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.0,
|
default=0.0,
|
||||||
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合"
|
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validate_every_n_steps",
|
"--validate_every_n_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます"
|
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--validate_every_n_epochs",
|
"--validate_every_n_epochs",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます"
|
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max_validation_steps",
|
"--max_validation_steps",
|
||||||
type=int,
|
type=int,
|
||||||
default=None,
|
default=None,
|
||||||
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します"
|
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
|
||||||
)
|
)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user