Merge branch 'sd3' into sd3_safetensors_merge

This commit is contained in:
Kohya S
2025-02-28 23:52:36 +09:00
11 changed files with 653 additions and 800 deletions

View File

@@ -14,6 +14,11 @@ The command to install PyTorch is as follows:
### Recent Updates ### Recent Updates
Feb 26, 2025:
- Improve the validation loss calculation in `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py`. PR [#1903](https://github.com/kohya-ss/sd-scripts/pull/1903)
- The validation loss uses the fixed timestep sampling and the fixed random seed. This is to ensure that the validation loss is not fluctuated by the random values.
Jan 25, 2025: Jan 25, 2025:
- `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO! - `train_network.py`, `sdxl_train_network.py`, `flux_train_network.py`, and `sd3_train_network.py` now support validation loss. PR [#1864](https://github.com/kohya-ss/sd-scripts/pull/1864) Thank you to rockerBOO!

View File

@@ -36,7 +36,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.is_schnell: Optional[bool] = None self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False self.is_swapping_blocks: bool = False
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
super().assert_extra_args(args, train_dataset_group, val_dataset_group) super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args) # sdxl_train_util.verify_sdxl_training_args(args)
@@ -323,7 +328,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.noise_scheduler_copy = copy.deepcopy(noise_scheduler) self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
return noise_scheduler return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images): def encode_images_to_latents(self, args, vae, images):
return vae.encode(images) return vae.encode(images)
def shift_scale_latents(self, args, latents): def shift_scale_latents(self, args, latents):
@@ -341,7 +346,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
is_train=True is_train=True,
): ):
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
@@ -376,8 +381,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t5_attn_mask = None t5_attn_mask = None
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask): def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
# if not args.split_mode: # grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
# normal forward
with torch.set_grad_enabled(is_train), accelerator.autocast(): with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
model_pred = unet( model_pred = unet(
@@ -390,44 +394,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
guidance=guidance_vec, guidance=guidance_vec,
txt_attention_mask=t5_attn_mask, txt_attention_mask=t5_attn_mask,
) )
"""
else:
# split forward to reduce memory usage
assert network.train_blocks == "single", "train_blocks must be single for split mode"
with accelerator.autocast():
# move flux lower to cpu, and then move flux upper to gpu
unet.to("cpu")
clean_memory_on_device(accelerator.device)
self.flux_upper.to(accelerator.device)
# upper model does not require grad
with torch.no_grad():
intermediate_img, intermediate_txt, vec, pe = self.flux_upper(
img=packed_noisy_model_input,
img_ids=img_ids,
txt=t5_out,
txt_ids=txt_ids,
y=l_pooled,
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
# move flux upper back to cpu, and then move flux lower to gpu
self.flux_upper.to("cpu")
clean_memory_on_device(accelerator.device)
unet.to(accelerator.device)
# lower model requires grad
intermediate_img.requires_grad_(True)
intermediate_txt.requires_grad_(True)
vec.requires_grad_(True)
pe.requires_grad_(True)
with torch.set_grad_enabled(is_train and train_unet):
model_pred = unet(img=intermediate_img, txt=intermediate_txt, vec=vec, pe=pe, txt_attention_mask=t5_attn_mask)
"""
return model_pred return model_pred
model_pred = call_dit( model_pred = call_dit(
@@ -546,6 +512,11 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8 text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype) prepare_fp8(text_encoder, weight_dtype)
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator( def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module: ) -> torch.nn.Module:

View File

@@ -2,6 +2,13 @@ import functools
import gc import gc
import torch import torch
try:
# intel gpu support for pytorch older than 2.5
# ipex is not needed after pytorch 2.5
import intel_extension_for_pytorch as ipex # noqa
except Exception:
pass
try: try:
HAS_CUDA = torch.cuda.is_available() HAS_CUDA = torch.cuda.is_available()
@@ -14,8 +21,6 @@ except Exception:
HAS_MPS = False HAS_MPS = False
try: try:
import intel_extension_for_pytorch as ipex # noqa
HAS_XPU = torch.xpu.is_available() HAS_XPU = torch.xpu.is_available()
except Exception: except Exception:
HAS_XPU = False HAS_XPU = False
@@ -69,7 +74,7 @@ def init_ipex():
This function should run right after importing torch and before doing anything else. This function should run right after importing torch and before doing anything else.
If IPEX is not available, this function does nothing. If xpu is not available, this function does nothing.
""" """
try: try:
if HAS_XPU: if HAS_XPU:

View File

@@ -2,7 +2,11 @@ import os
import sys import sys
import contextlib import contextlib
import torch import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
legacy = True
except Exception:
legacy = False
from .hijacks import ipex_hijacks from .hijacks import ipex_hijacks
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
@@ -12,6 +16,13 @@ def ipex_init(): # pylint: disable=too-many-statements
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked: if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack" return True, "Skipping IPEX hijack"
else: else:
try: # force xpu device on torch compile and triton
torch._inductor.utils.GPU_TYPES = ["xpu"]
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
from triton import backends as triton_backends # pylint: disable=import-error
triton_backends.backends["nvidia"].driver.is_active = lambda *args, **kwargs: False
except Exception:
pass
# Replace cuda with xpu: # Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream torch.cuda.current_stream = torch.xpu.current_stream
@@ -26,84 +37,99 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_current_stream_capturing = lambda: False torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream torch.cuda.stream = torch.xpu.stream
torch.cuda.synchronize = torch.xpu.synchronize
torch.cuda.Event = torch.xpu.Event torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream torch.cuda.Stream = torch.xpu.Stream
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.nn.Module.cuda = torch.nn.Module.xpu torch.nn.Module.cuda = torch.nn.Module.xpu
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.torch = torch.xpu.torch
torch.cuda.Union = torch.xpu.Union
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.random = torch.xpu.random
torch.cuda._device = torch.xpu._device
torch.cuda.__name__ = torch.xpu.__name__
torch.cuda._device_t = torch.xpu._device_t
torch.cuda.__spec__ = torch.xpu.__spec__
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
if legacy:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if float(ipex.__version__[:3]) < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda._lazy_new = torch.xpu._lazy_new torch.cuda._lazy_new = torch.xpu._lazy_new
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.cuda.FloatStorage = torch.xpu.FloatStorage torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
torch.cuda.default_generators = torch.xpu.default_generators
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda._get_device_index = torch.xpu._get_device_index
torch.cuda.__path__ = torch.xpu.__path__
torch.cuda.Device = torch.xpu.Device
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.set_stream = torch.xpu.set_stream
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.os = torch.xpu.os
torch.cuda.torch = torch.xpu.torch
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.Union = torch.xpu.Union
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.ByteTensor = torch.xpu.ByteTensor torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda._lazy_call = torch.xpu._lazy_call
torch.cuda.HalfStorage = torch.xpu.HalfStorage torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.random = torch.xpu.random torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda._device = torch.xpu._device torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.classproperty = torch.xpu.classproperty torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.__name__ = torch.xpu.__name__ torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda._device_t = torch.xpu._device_t torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.warnings = torch.xpu.warnings torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda.__spec__ = torch.xpu.__spec__ torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.BoolTensor = torch.xpu.BoolTensor torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.CharStorage = torch.xpu.CharStorage torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__ torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork torch.cuda.BoolStorage = torch.xpu.BoolStorage
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
if not legacy or float(ipex.__version__[:3]) >= 2.3:
torch.cuda._initialization_lock = torch.xpu._initialization_lock
torch.cuda._initialized = torch.xpu._initialized
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu._queued_calls
torch.cuda._tls = torch.xpu._tls
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback
# Memory: # Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read(): if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache torch.cuda.empty_cache = torch.xpu.empty_cache
torch.cuda.memory_stats = torch.xpu.memory_stats
if legacy:
torch.cuda.memory_summary = torch.xpu.memory_summary torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory = torch.xpu.memory
torch.cuda.memory_stats = torch.xpu.memory_stats
torch.cuda.memory_allocated = torch.xpu.memory_allocated torch.cuda.memory_allocated = torch.xpu.memory_allocated
torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated torch.cuda.max_memory_allocated = torch.xpu.max_memory_allocated
torch.cuda.memory_reserved = torch.xpu.memory_reserved torch.cuda.memory_reserved = torch.xpu.memory_reserved
@@ -128,7 +154,11 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.initial_seed = torch.xpu.initial_seed torch.cuda.initial_seed = torch.xpu.initial_seed
# AMP: # AMP:
if legacy:
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
torch.cuda.amp = torch.xpu.amp torch.cuda.amp = torch.xpu.amp
if float(ipex.__version__[:3]) < 2.3:
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype
@@ -147,13 +177,21 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
# C # C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream if legacy and float(ipex.__version__[:3]) < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 2024 ipex._C._DeviceProperties.major = 12
ipex._C._DeviceProperties.minor = 0 ipex._C._DeviceProperties.minor = 1
else:
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
torch._C._XpuDeviceProperties.major = 12
torch._C._XpuDeviceProperties.minor = 1
# Fix functions with ipex: # Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory] # torch.xpu.mem_get_info always returns the total memory as free memory
torch.xpu.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch.cuda.mem_get_info = torch.xpu.mem_get_info
torch._utils._get_available_device_type = lambda: "xpu" torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True torch.has_cuda = True
torch.cuda.has_half = True torch.cuda.has_half = True
@@ -161,17 +199,17 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1" torch.version.cuda = "12.1"
torch.cuda.get_device_capability = lambda *args, **kwargs: [12,1] torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
torch.cuda.get_device_properties.major = 12 torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1 torch.cuda.get_device_properties.minor = 1
torch.cuda.ipc_collect = lambda *args, **kwargs: None torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0 torch.cuda.utilization = lambda *args, **kwargs: 0
ipex_hijacks() device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
if not torch.xpu.has_fp64_dtype() or os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is not None:
try: try:
from .diffusers import ipex_diffusers from .diffusers import ipex_diffusers
ipex_diffusers() ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
pass pass
torch.cuda.is_xpu_hijacked = True torch.cuda.is_xpu_hijacked = True

View File

@@ -1,177 +1,119 @@
import os import os
import torch import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import from functools import cache, wraps
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers # ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers
sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 1))
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 0.5))
# Find something divisible with the input_tokens # Find something divisible with the input_tokens
@cache @cache
def find_slice_size(slice_size, slice_block_size): def find_split_size(original_size, slice_block_size, slice_rate=2):
while (slice_size * slice_block_size) > attention_slice_rate: split_size = original_size
slice_size = slice_size // 2 while True:
if slice_size <= 1: if (split_size * slice_block_size) <= slice_rate and original_size % split_size == 0:
slice_size = 1 return split_size
break split_size = split_size - 1
return slice_size if split_size <= 1:
return 1
return split_size
# Find slice sizes for SDPA # Find slice sizes for SDPA
@cache @cache
def find_sdpa_slice_sizes(query_shape, query_element_size): def find_sdpa_slice_sizes(query_shape, key_shape, query_element_size, slice_rate=2, trigger_rate=3):
if len(query_shape) == 3: batch_size, attn_heads, query_len, _ = query_shape
batch_size_attention, query_tokens, shape_three = query_shape _, _, key_len, _ = key_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size slice_batch_size = attn_heads * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention split_batch_size = batch_size
split_2_slice_size = query_tokens split_head_size = attn_heads
split_3_slice_size = shape_three split_query_size = query_len
do_split = False do_batch_split = False
do_split_2 = False do_head_split = False
do_split_3 = False do_query_split = False
if block_size > sdpa_slice_trigger_rate: if batch_size * slice_batch_size >= trigger_rate:
do_split = True do_batch_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size) split_batch_size = find_split_size(batch_size, slice_batch_size, slice_rate=slice_rate)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size if split_batch_size * slice_batch_size > slice_rate:
slice_head_size = split_batch_size * (query_len * key_len) * query_element_size / 1024 / 1024 / 1024
do_head_split = True
split_head_size = find_split_size(attn_heads, slice_head_size, slice_rate=slice_rate)
# Find slice sizes for BMM if split_head_size * slice_head_size > slice_rate:
@cache slice_query_size = split_batch_size * split_head_size * (key_len) * query_element_size / 1024 / 1024 / 1024
def find_bmm_slice_sizes(input_shape, input_element_size, mat2_shape): do_query_split = True
batch_size_attention, input_tokens, mat2_atten_shape = input_shape[0], input_shape[1], mat2_shape[2] split_query_size = find_split_size(query_len, slice_query_size, slice_rate=slice_rate)
slice_block_size = input_tokens * mat2_atten_shape / 1024 / 1024 * input_element_size
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention return do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size
split_2_slice_size = input_tokens
split_3_slice_size = mat2_atten_shape
do_split = False
do_split_2 = False
do_split_3 = False
if block_size > attention_slice_rate:
do_split = True
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * mat2_atten_shape / 1024 / 1024 * input_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size / 1024 / 1024 * input_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
original_torch_bmm = torch.bmm
def torch_bmm_32_bit(input, mat2, *, out=None):
if input.device.type != "xpu":
return original_torch_bmm(input, mat2, out=out)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_bmm_slice_sizes(input.shape, input.element_size(), mat2.shape)
# Slice BMM
if do_split:
batch_size_attention, input_tokens, mat2_atten_shape = input.shape[0], input.shape[1], mat2.shape[2]
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(input_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(mat2_atten_shape // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
mat2[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
out=out
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_torch_bmm(
input[start_idx:end_idx, start_idx_2:end_idx_2],
mat2[start_idx:end_idx, start_idx_2:end_idx_2],
out=out
)
else:
hidden_states[start_idx:end_idx] = original_torch_bmm(
input[start_idx:end_idx],
mat2[start_idx:end_idx],
out=out
)
torch.xpu.synchronize(input.device)
else:
return original_torch_bmm(input, mat2, out=out)
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs): @wraps(torch.nn.functional.scaled_dot_product_attention)
def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.device.type != "xpu": if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_sdpa_slice_sizes(query.shape, query.element_size()) is_unsqueezed = False
if len(query.shape) == 3:
query = query.unsqueeze(0)
is_unsqueezed = True
if len(key.shape) == 3:
key = key.unsqueeze(0)
if len(value.shape) == 3:
value = value.unsqueeze(0)
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)
# Slice SDPA # Slice SDPA
if do_split: if do_batch_split:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2] batch_size, attn_heads, query_len, _ = query.shape
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype) _, _, _, head_dim = value.shape
for i in range(batch_size_attention // split_slice_size): hidden_states = torch.zeros((batch_size, attn_heads, query_len, head_dim), device=query.device, dtype=query.dtype)
start_idx = i * split_slice_size if attn_mask is not None:
end_idx = (i + 1) * split_slice_size attn_mask = attn_mask.expand((query.shape[0], query.shape[1], query.shape[2], key.shape[-2]))
if do_split_2: for ib in range(batch_size // split_batch_size):
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name start_idx = ib * split_batch_size
start_idx_2 = i2 * split_2_slice_size end_idx = (ib + 1) * split_batch_size
end_idx_2 = (i2 + 1) * split_2_slice_size if do_head_split:
if do_split_3: for ih in range(attn_heads // split_head_size): # pylint: disable=invalid-name
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name start_idx_h = ih * split_head_size
start_idx_3 = i3 * split_3_slice_size end_idx_h = (ih + 1) * split_head_size
end_idx_3 = (i3 + 1) * split_3_slice_size if do_query_split:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention( for iq in range(query_len // split_query_size): # pylint: disable=invalid-name
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], start_idx_q = iq * split_query_size
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], end_idx_q = (iq + 1) * split_query_size
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3], hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] = original_scaled_dot_product_attention(
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask, query[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :],
key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, start_idx_q:end_idx_q, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs dropout_p=dropout_p, is_causal=is_causal, **kwargs
) )
else: else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention( hidden_states[start_idx:end_idx, start_idx_h:end_idx_h, :, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2], query[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
key[start_idx:end_idx, start_idx_2:end_idx_2], key[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
value[start_idx:end_idx, start_idx_2:end_idx_2], value[start_idx:end_idx, start_idx_h:end_idx_h, :, :],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask, attn_mask=attn_mask[start_idx:end_idx, start_idx_h:end_idx_h, :, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs dropout_p=dropout_p, is_causal=is_causal, **kwargs
) )
else: else:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention( hidden_states[start_idx:end_idx, :, :, :] = original_scaled_dot_product_attention(
query[start_idx:end_idx], query[start_idx:end_idx, :, :, :],
key[start_idx:end_idx], key[start_idx:end_idx, :, :, :],
value[start_idx:end_idx], value[start_idx:end_idx, :, :, :],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask, attn_mask=attn_mask[start_idx:end_idx, :, :, :] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal, **kwargs dropout_p=dropout_p, is_causal=is_causal, **kwargs
) )
torch.xpu.synchronize(query.device) torch.xpu.synchronize(query.device)
else: else:
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs) hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
if is_unsqueezed:
hidden_states.squeeze(0)
return hidden_states return hidden_states

View File

@@ -1,312 +1,47 @@
import os from functools import wraps
import torch import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import import diffusers # pylint: disable=import-error
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
from diffusers.utils import USE_PEFT_BACKEND
from functools import cache
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4))
@cache # Diffusers FreeU
def find_slice_size(slice_size, slice_block_size): original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
while (slice_size * slice_block_size) > attention_slice_rate: @wraps(diffusers.utils.torch_utils.fourier_filter)
slice_size = slice_size // 2 def fourier_filter(x_in, threshold, scale):
if slice_size <= 1: return_dtype = x_in.dtype
slice_size = 1 return original_fourier_filter(x_in.to(dtype=torch.float32), threshold, scale).to(dtype=return_dtype)
break
return slice_size
@cache
def find_attention_slice_sizes(query_shape, query_element_size, query_device_type, slice_size=None):
if len(query_shape) == 3:
batch_size_attention, query_tokens, shape_three = query_shape
shape_four = 1
else:
batch_size_attention, query_tokens, shape_three, shape_four = query_shape
if slice_size is not None:
batch_size_attention = slice_size
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * query_element_size # fp64 error
block_size = batch_size_attention * slice_block_size class FluxPosEmbed(torch.nn.Module):
def __init__(self, theta: int, axes_dim):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
split_slice_size = batch_size_attention def forward(self, ids: torch.Tensor) -> torch.Tensor:
split_2_slice_size = query_tokens n_axes = ids.shape[-1]
split_3_slice_size = shape_three cos_out = []
sin_out = []
do_split = False pos = ids.float()
do_split_2 = False for i in range(n_axes):
do_split_3 = False cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
self.axes_dim[i],
if query_device_type != "xpu": pos[:, i],
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size theta=self.theta,
repeat_interleave_real=True,
if block_size > attention_slice_rate: use_real=True,
do_split = True freqs_dtype=torch.float32,
split_slice_size = find_slice_size(split_slice_size, slice_block_size)
if split_slice_size * slice_block_size > attention_slice_rate:
slice_2_block_size = split_slice_size * shape_three * shape_four / 1024 / 1024 * query_element_size
do_split_2 = True
split_2_slice_size = find_slice_size(split_2_slice_size, slice_2_block_size)
if split_2_slice_size * slice_2_block_size > attention_slice_rate:
slice_3_block_size = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * query_element_size
do_split_3 = True
split_3_slice_size = find_slice_size(split_3_slice_size, slice_3_block_size)
return do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size
class SlicedAttnProcessor: # pylint: disable=too-few-public-methods
r"""
Processor for implementing sliced attention.
Args:
slice_size (`int`, *optional*):
The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
`attention_head_dim` must be a multiple of the `slice_size`.
"""
def __init__(self, slice_size):
self.slice_size = slice_size
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None) -> torch.FloatTensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
) )
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) cos_out.append(cos)
sin_out.append(sin)
if attn.group_norm is not None: freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
query = attn.to_q(hidden_states)
dim = query.shape[-1]
query = attn.head_to_batch_dim(query)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
batch_size_attention, query_tokens, shape_three = query.shape
hidden_states = torch.zeros(
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
_, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type, slice_size=self.slice_size)
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor: def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
r""" diffusers.utils.torch_utils.fourier_filter = fourier_filter
Default processor for performing attention-related computations. if not device_supports_fp64:
""" diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
def __call__(self, attn: Attention, hidden_states: torch.FloatTensor,
encoder_hidden_states=None, attention_mask=None,
temb=None, scale: float = 1.0) -> torch.Tensor: # pylint: disable=too-many-statements, too-many-locals, too-many-branches
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
####################################################################
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, query_tokens, shape_three = query.shape[0], query.shape[1], query.shape[2]
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
do_split, do_split_2, do_split_3, split_slice_size, split_2_slice_size, split_3_slice_size = find_attention_slice_sizes(query.shape, query.element_size(), query.device.type)
if do_split:
for i in range(batch_size_attention // split_slice_size):
start_idx = i * split_slice_size
end_idx = (i + 1) * split_slice_size
if do_split_2:
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx, start_idx_2:end_idx_2]
key_slice = key[start_idx:end_idx, start_idx_2:end_idx_2]
attn_mask_slice = attention_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx, start_idx_2:end_idx_2])
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = attn_slice
del attn_slice
else:
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
del query_slice
del key_slice
del attn_mask_slice
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
del attn_slice
torch.xpu.synchronize(query.device)
else:
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
####################################################################
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
def ipex_diffusers():
#ARC GPUs can't allocate more than 4GB to a single block:
diffusers.models.attention_processor.SlicedAttnProcessor = SlicedAttnProcessor
diffusers.models.attention_processor.AttnProcessor = AttnProcessor

View File

@@ -5,7 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
# pylint: disable=protected-access, missing-function-docstring, line-too-long # pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype() device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
OptState = ipex.cpu.autocast._grad_scaler.OptState OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator _MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state _refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state

View File

@@ -2,10 +2,19 @@ import os
from functools import wraps from functools import wraps
from contextlib import nullcontext from contextlib import nullcontext
import torch import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import numpy as np import numpy as np
device_supports_fp64 = torch.xpu.has_fp64_dtype() device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
try:
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
del x
torch.xpu.empty_cache()
can_allocate_plus_4gb = True
except Exception:
can_allocate_plus_4gb = False
else:
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return # pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return
@@ -26,7 +35,7 @@ def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int)) return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def return_xpu(device): def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device("xpu") if isinstance(device, torch.device) else "xpu" return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
# Autocast # Autocast
@@ -42,7 +51,7 @@ def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=Non
original_interpolate = torch.nn.functional.interpolate original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate) @wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None or mode == 'bicubic': if mode in {'bicubic', 'bilinear'}:
return_device = tensor.device return_device = tensor.device
return_dtype = tensor.dtype return_dtype = tensor.dtype
return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode, return original_interpolate(tensor.to("cpu", dtype=torch.float32), size=size, scale_factor=scale_factor, mode=mode,
@@ -73,35 +82,46 @@ def as_tensor(data, dtype=None, device=None):
return original_as_tensor(data, dtype=dtype, device=device) return original_as_tensor(data, dtype=dtype, device=device)
if device_supports_fp64 and os.environ.get('IPEX_FORCE_ATTENTION_SLICE', None) is None: if can_allocate_plus_4gb:
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else: else:
# 32 bit attention workarounds for Alchemist: # 32 bit attention workarounds for Alchemist:
try: try:
from .attention import torch_bmm_32_bit as original_torch_bmm from .attention import dynamic_scaled_dot_product_attention as original_scaled_dot_product_attention
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
# Data Type Errors:
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
@wraps(torch.nn.functional.scaled_dot_product_attention) @wraps(torch.nn.functional.scaled_dot_product_attention)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False): def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
if query.dtype != key.dtype: if query.dtype != key.dtype:
key = key.to(dtype=query.dtype) key = key.to(dtype=query.dtype)
if query.dtype != value.dtype: if query.dtype != value.dtype:
value = value.to(dtype=query.dtype) value = value.to(dtype=query.dtype)
if attn_mask is not None and query.dtype != attn_mask.dtype: if attn_mask is not None and query.dtype != attn_mask.dtype:
attn_mask = attn_mask.to(dtype=query.dtype) attn_mask = attn_mask.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal) return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
# Data Type Errors:
original_torch_bmm = torch.bmm
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
# Diffusers FreeU
original_fft_fftn = torch.fft.fftn
@wraps(torch.fft.fftn)
def fft_fftn(input, s=None, dim=None, norm=None, *, out=None):
return_dtype = input.dtype
return original_fft_fftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
# Diffusers FreeU
original_fft_ifftn = torch.fft.ifftn
@wraps(torch.fft.ifftn)
def fft_ifftn(input, s=None, dim=None, norm=None, *, out=None):
return_dtype = input.dtype
return original_fft_ifftn(input.to(dtype=torch.float32), s=s, dim=dim, norm=norm, out=out).to(dtype=return_dtype)
# A1111 FP16 # A1111 FP16
original_functional_group_norm = torch.nn.functional.group_norm original_functional_group_norm = torch.nn.functional.group_norm
@@ -133,6 +153,15 @@ def functional_linear(input, weight, bias=None):
bias.data = bias.data.to(dtype=weight.data.dtype) bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_linear(input, weight, bias=bias) return original_functional_linear(input, weight, bias=bias)
original_functional_conv1d = torch.nn.functional.conv1d
@wraps(torch.nn.functional.conv1d)
def functional_conv1d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if input.dtype != weight.data.dtype:
input = input.to(dtype=weight.data.dtype)
if bias is not None and bias.data.dtype != weight.data.dtype:
bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv1d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
original_functional_conv2d = torch.nn.functional.conv2d original_functional_conv2d = torch.nn.functional.conv2d
@wraps(torch.nn.functional.conv2d) @wraps(torch.nn.functional.conv2d)
def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
@@ -142,14 +171,15 @@ def functional_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1,
bias.data = bias.data.to(dtype=weight.data.dtype) bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) return original_functional_conv2d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# A1111 Embedding BF16 # LTX Video
original_torch_cat = torch.cat original_functional_conv3d = torch.nn.functional.conv3d
@wraps(torch.cat) @wraps(torch.nn.functional.conv3d)
def torch_cat(tensor, *args, **kwargs): def functional_conv3d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype): if input.dtype != weight.data.dtype:
return original_torch_cat([tensor[0].to(tensor[1].dtype), tensor[1], tensor[2].to(tensor[1].dtype)], *args, **kwargs) input = input.to(dtype=weight.data.dtype)
else: if bias is not None and bias.data.dtype != weight.data.dtype:
return original_torch_cat(tensor, *args, **kwargs) bias.data = bias.data.to(dtype=weight.data.dtype)
return original_functional_conv3d(input, weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
# SwinIR BF16: # SwinIR BF16:
original_functional_pad = torch.nn.functional.pad original_functional_pad = torch.nn.functional.pad
@@ -164,6 +194,7 @@ def functional_pad(input, pad, mode='constant', value=None):
original_torch_tensor = torch.tensor original_torch_tensor = torch.tensor
@wraps(torch.tensor) @wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs): def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
global device_supports_fp64
if check_device(device): if check_device(device):
device = return_xpu(device) device = return_xpu(device)
if not device_supports_fp64: if not device_supports_fp64:
@@ -227,7 +258,7 @@ def torch_empty(*args, device=None, **kwargs):
original_torch_randn = torch.randn original_torch_randn = torch.randn
@wraps(torch.randn) @wraps(torch.randn)
def torch_randn(*args, device=None, dtype=None, **kwargs): def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype == bytes: if dtype is bytes:
dtype = None dtype = None
if check_device(device): if check_device(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs) return original_torch_randn(*args, device=return_xpu(device), **kwargs)
@@ -250,6 +281,14 @@ def torch_zeros(*args, device=None, **kwargs):
else: else:
return original_torch_zeros(*args, device=device, **kwargs) return original_torch_zeros(*args, device=device, **kwargs)
original_torch_full = torch.full
@wraps(torch.full)
def torch_full(*args, device=None, **kwargs):
if check_device(device):
return original_torch_full(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_full(*args, device=device, **kwargs)
original_torch_linspace = torch.linspace original_torch_linspace = torch.linspace
@wraps(torch.linspace) @wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs): def torch_linspace(*args, device=None, **kwargs):
@@ -258,14 +297,6 @@ def torch_linspace(*args, device=None, **kwargs):
else: else:
return original_torch_linspace(*args, device=device, **kwargs) return original_torch_linspace(*args, device=device, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
original_torch_load = torch.load original_torch_load = torch.load
@wraps(torch.load) @wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs): def torch_load(f, map_location=None, *args, **kwargs):
@@ -276,9 +307,27 @@ def torch_load(f, map_location=None, *args, **kwargs):
else: else:
return original_torch_load(f, *args, map_location=map_location, **kwargs) return original_torch_load(f, *args, map_location=map_location, **kwargs)
original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)
@wraps(torch.cuda.synchronize)
def torch_cuda_synchronize(device=None):
if check_device(device):
return torch.xpu.synchronize(return_xpu(device))
else:
return torch.xpu.synchronize(device)
# Hijack Functions: # Hijack Functions:
def ipex_hijacks(): def ipex_hijacks(legacy=True):
global device_supports_fp64, can_allocate_plus_4gb
if legacy and float(torch.__version__[:3]) < 2.5:
torch.nn.functional.interpolate = interpolate
torch.tensor = torch_tensor torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda torch.Tensor.cuda = Tensor_cuda
@@ -289,9 +338,11 @@ def ipex_hijacks():
torch.randn = torch_randn torch.randn = torch_randn
torch.ones = torch_ones torch.ones = torch_ones
torch.zeros = torch_zeros torch.zeros = torch_zeros
torch.full = torch_full
torch.linspace = torch_linspace torch.linspace = torch_linspace
torch.Generator = torch_Generator
torch.load = torch_load torch.load = torch_load
torch.Generator = torch_Generator
torch.cuda.synchronize = torch_cuda_synchronize
torch.backends.cuda.sdp_kernel = return_null_context torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel torch.nn.DataParallel = DummyDataParallel
@@ -302,12 +353,15 @@ def ipex_hijacks():
torch.nn.functional.group_norm = functional_group_norm torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm torch.nn.functional.layer_norm = functional_layer_norm
torch.nn.functional.linear = functional_linear torch.nn.functional.linear = functional_linear
torch.nn.functional.conv1d = functional_conv1d
torch.nn.functional.conv2d = functional_conv2d torch.nn.functional.conv2d = functional_conv2d
torch.nn.functional.interpolate = interpolate torch.nn.functional.conv3d = functional_conv3d
torch.nn.functional.pad = functional_pad torch.nn.functional.pad = functional_pad
torch.bmm = torch_bmm torch.bmm = torch_bmm
torch.cat = torch_cat torch.fft.fftn = fft_fftn
torch.fft.ifftn = fft_ifftn
if not device_supports_fp64: if not device_supports_fp64:
torch.from_numpy = from_numpy torch.from_numpy = from_numpy
torch.as_tensor = as_tensor torch.as_tensor = as_tensor
return device_supports_fp64, can_allocate_plus_4gb

View File

@@ -13,17 +13,7 @@ import re
import shutil import shutil
import time import time
import typing import typing
from typing import ( from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Sequence,
Tuple,
Union
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
import glob import glob
import math import math
@@ -146,12 +136,13 @@ IMAGE_TRANSFORMS = transforms.Compose(
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz" TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"
def split_train_val( def split_train_val(
paths: List[str], paths: List[str],
sizes: List[Optional[Tuple[int, int]]], sizes: List[Optional[Tuple[int, int]]],
is_training_dataset: bool, is_training_dataset: bool,
validation_split: float, validation_split: float,
validation_seed: int | None validation_seed: int | None,
) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]: ) -> Tuple[List[str], List[Optional[Tuple[int, int]]]]:
""" """
Split the dataset into train and validation Split the dataset into train and validation
@@ -1999,11 +1990,7 @@ class DreamBoothDataset(BaseDataset):
# required for training images dataset of regularization images # required for training images dataset of regularization images
else: else:
img_paths, sizes = split_train_val( img_paths, sizes = split_train_val(
img_paths, img_paths, sizes, self.is_training_dataset, self.validation_split, self.validation_seed
sizes,
self.is_training_dataset,
self.validation_split,
self.validation_seed
) )
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files") logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
@@ -5944,12 +5931,17 @@ def save_sd_model_on_train_end_common(
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor: def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device) -> torch.Tensor:
if min_timestep < max_timestep:
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu")
else:
timesteps = torch.full((b_size,), max_timestep, device="cpu")
timesteps = timesteps.long().to(device) timesteps = timesteps.long().to(device)
return timesteps return timesteps
def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]: def get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents: torch.FloatTensor
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.IntTensor]:
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents, device=latents.device) noise = torch.randn_like(latents, device=latents.device)
if args.noise_offset: if args.noise_offset:
@@ -6459,12 +6451,16 @@ def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tr
if "wandb" in [tracker.name for tracker in accelerator.trackers]: if "wandb" in [tracker.name for tracker in accelerator.trackers]:
import wandb import wandb
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True) wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
# Define specific metrics to handle validation and epochs "steps" # Define specific metrics to handle validation and epochs "steps"
wandb_tracker.define_metric("epoch", hidden=True) wandb_tracker.define_metric("epoch", hidden=True)
wandb_tracker.define_metric("val_step", hidden=True) wandb_tracker.define_metric("val_step", hidden=True)
wandb_tracker.define_metric("global_step", hidden=True)
# endregion # endregion

View File

@@ -26,7 +26,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
super().__init__() super().__init__()
self.sample_prompts_te_outputs = None self.sample_prompts_te_outputs = None
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
# super().assert_extra_args(args, train_dataset_group) # super().assert_extra_args(args, train_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args) # sdxl_train_util.verify_sdxl_training_args(args)
@@ -299,7 +304,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift) noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.training_shift)
return noise_scheduler return noise_scheduler
def encode_images_to_latents(self, args, accelerator, vae, images): def encode_images_to_latents(self, args, vae, images):
return vae.encode(images) return vae.encode(images)
def shift_scale_latents(self, args, latents): def shift_scale_latents(self, args, latents):
@@ -317,7 +322,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
is_train=True is_train=True,
): ):
# Sample noise that we'll add to the latents # Sample noise that we'll add to the latents
noise = torch.randn_like(latents) noise = torch.randn_like(latents)
@@ -445,14 +450,19 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
text_encoder.to(te_weight_dtype) # fp8 text_encoder.to(te_weight_dtype) # fp8
prepare_fp8(text_encoder, weight_dtype) prepare_fp8(text_encoder, weight_dtype)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True):
# drop cached text encoder outputs # drop cached text encoder outputs: in validation, we drop cached outputs deterministically by fixed seed
text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
if text_encoder_outputs_list is not None: if text_encoder_outputs_list is not None:
text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy() text_encodoing_strategy: strategy_sd3.Sd3TextEncodingStrategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list) text_encoder_outputs_list = text_encodoing_strategy.drop_cached_text_encoder_outputs(*text_encoder_outputs_list)
batch["text_encoder_outputs_list"] = text_encoder_outputs_list batch["text_encoder_outputs_list"] = text_encoder_outputs_list
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()
def prepare_unet_with_accelerator( def prepare_unet_with_accelerator(
self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
) -> torch.nn.Module: ) -> torch.nn.Module:

View File

@@ -9,6 +9,7 @@ import random
import time import time
import json import json
from multiprocessing import Value from multiprocessing import Value
import numpy as np
import toml import toml
from tqdm import tqdm from tqdm import tqdm
@@ -100,9 +101,7 @@ class NetworkTrainer:
if ( if (
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None
): # tracking d*lr value of unet. ): # tracking d*lr value of unet.
logs["lr/d*lr"] = ( logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"]
)
else: else:
idx = 0 idx = 0
if not args.network_train_unet_only: if not args.network_train_unet_only:
@@ -115,16 +114,56 @@ class NetworkTrainer:
logs[f"lr/d*lr/group{i}"] = ( logs[f"lr/d*lr/group{i}"] = (
lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"]
) )
if ( if args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None:
args.optimizer_type.lower().endswith("ProdigyPlusScheduleFree".lower()) and optimizer is not None logs[f"lr/d*lr/group{i}"] = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
):
logs[f"lr/d*lr/group{i}"] = (
optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
)
return logs return logs
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]): def step_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
self.accelerator_logging(accelerator, logs, global_step, global_step, epoch)
def epoch_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int):
self.accelerator_logging(accelerator, logs, epoch, global_step, epoch)
def val_logging(self, accelerator: Accelerator, logs: dict, global_step: int, epoch: int, val_step: int):
self.accelerator_logging(accelerator, logs, global_step + val_step, global_step, epoch, val_step)
def accelerator_logging(
self, accelerator: Accelerator, logs: dict, step_value: int, global_step: int, epoch: int, val_step: Optional[int] = None
):
"""
step_value is for tensorboard, other values are for wandb
"""
tensorboard_tracker = None
wandb_tracker = None
other_trackers = []
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
tensorboard_tracker = accelerator.get_tracker("tensorboard")
elif tracker.name == "wandb":
wandb_tracker = accelerator.get_tracker("wandb")
else:
other_trackers.append(accelerator.get_tracker(tracker.name))
if tensorboard_tracker is not None:
tensorboard_tracker.log(logs, step=step_value)
if wandb_tracker is not None:
logs["global_step"] = global_step
logs["epoch"] = epoch
if val_step is not None:
logs["val_step"] = val_step
wandb_tracker.log(logs)
for tracker in other_trackers:
tracker.log(logs, step=step_value)
def assert_extra_args(
self,
args,
train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset],
val_dataset_group: Optional[train_util.DatasetGroup],
):
train_dataset_group.verify_bucket_reso_steps(64) train_dataset_group.verify_bucket_reso_steps(64)
if val_dataset_group is not None: if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64) val_dataset_group.verify_bucket_reso_steps(64)
@@ -219,7 +258,7 @@ class NetworkTrainer:
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
is_train=True is_train=True,
): ):
# Sample noise, sample a random timestep for each image, and add noise to the latents, # Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified # with noise offset and/or multires noise if specified
@@ -309,7 +348,10 @@ class NetworkTrainer:
) -> torch.nn.Module: ) -> torch.nn.Module:
return accelerator.prepare(unet) return accelerator.prepare(unet)
def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype): def on_step_start(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train: bool = True):
pass
def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
pass pass
# endregion # endregion
@@ -330,7 +372,7 @@ class NetworkTrainer:
tokenize_strategy: strategy_base.TokenizeStrategy, tokenize_strategy: strategy_base.TokenizeStrategy,
is_train=True, is_train=True,
train_text_encoder=True, train_text_encoder=True,
train_unet=True train_unet=True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Process a batch for the network Process a batch for the network
@@ -397,7 +439,7 @@ class NetworkTrainer:
network, network,
weight_dtype, weight_dtype,
train_unet, train_unet,
is_train=is_train is_train=is_train,
) )
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler) huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
@@ -900,7 +942,9 @@ class NetworkTrainer:
accelerator.print("running training / 学習開始") accelerator.print("running training / 学習開始")
accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}") accelerator.print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
accelerator.print(f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}") accelerator.print(
f" num validation images * repeats / 学習画像の数×繰り返し回数: {val_dataset_group.num_train_images if val_dataset_group is not None else 0}"
)
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}") accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}") accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}") accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
@@ -1243,12 +1287,6 @@ class NetworkTrainer:
# log empty object to commit the sample images to wandb # log empty object to commit the sample images to wandb
accelerator.log({}, step=0) accelerator.log({}, step=0)
validation_steps = (
min(args.max_validation_steps, len(val_dataloader))
if args.max_validation_steps is not None
else len(val_dataloader)
)
# training loop # training loop
if initial_step > 0: # only if skip_until_initial_step is specified if initial_step > 0: # only if skip_until_initial_step is specified
for skip_epoch in range(epoch_to_start): # skip epochs for skip_epoch in range(epoch_to_start): # skip epochs
@@ -1271,13 +1309,53 @@ class NetworkTrainer:
range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
) )
validation_steps = (
min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
)
NUM_VALIDATION_TIMESTEPS = 4 # 200, 400, 600, 800 TODO make this configurable
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.num_train_timesteps if args.max_timestep is None else args.max_timestep
validation_timesteps = np.linspace(min_timestep, max_timestep, (NUM_VALIDATION_TIMESTEPS + 2), dtype=int)[1:-1]
validation_total_steps = validation_steps * len(validation_timesteps)
original_args_min_timestep = args.min_timestep
original_args_max_timestep = args.max_timestep
def switch_rng_state(seed: int) -> tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]:
cpu_rng_state = torch.get_rng_state()
if accelerator.device.type == "cuda":
gpu_rng_state = torch.cuda.get_rng_state()
elif accelerator.device.type == "xpu":
gpu_rng_state = torch.xpu.get_rng_state()
elif accelerator.device.type == "mps":
gpu_rng_state = torch.cuda.get_rng_state()
else:
gpu_rng_state = None
python_rng_state = random.getstate()
torch.manual_seed(seed)
random.seed(seed)
return (cpu_rng_state, gpu_rng_state, python_rng_state)
def restore_rng_state(rng_states: tuple[torch.ByteTensor, Optional[torch.ByteTensor], tuple]):
cpu_rng_state, gpu_rng_state, python_rng_state = rng_states
torch.set_rng_state(cpu_rng_state)
if gpu_rng_state is not None:
if accelerator.device.type == "cuda":
torch.cuda.set_rng_state(gpu_rng_state)
elif accelerator.device.type == "xpu":
torch.xpu.set_rng_state(gpu_rng_state)
elif accelerator.device.type == "mps":
torch.cuda.set_rng_state(gpu_rng_state)
random.setstate(python_rng_state)
for epoch in range(epoch_to_start, num_train_epochs): for epoch in range(epoch_to_start, num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n") accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}\n")
current_epoch.value = epoch + 1 current_epoch.value = epoch + 1
metadata["ss_epoch"] = str(epoch + 1) metadata["ss_epoch"] = str(epoch + 1)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) # network.train() is called here
# TRAINING # TRAINING
skipped_dataloader = None skipped_dataloader = None
@@ -1294,8 +1372,8 @@ class NetworkTrainer:
with accelerator.accumulate(training_model): with accelerator.accumulate(training_model):
on_step_start_for_network(text_encoder, unet) on_step_start_for_network(text_encoder, unet)
# temporary, for batch processing # preprocess batch for each model
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=True)
loss = self.process_batch( loss = self.process_batch(
batch, batch,
@@ -1312,7 +1390,7 @@ class NetworkTrainer:
tokenize_strategy, tokenize_strategy,
is_train=True, is_train=True,
train_text_encoder=train_text_encoder, train_text_encoder=train_text_encoder,
train_unet=train_unet train_unet=train_unet,
) )
accelerator.backward(loss) accelerator.backward(loss)
@@ -1369,39 +1447,35 @@ class NetworkTrainer:
if args.scale_weight_norms: if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs}) progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking: if is_tracking:
logs = self.generate_step_logs( logs = self.generate_step_logs(
args, args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm
) )
accelerator.log(logs, step=global_step) self.step_logging(accelerator, logs, global_step, epoch + 1)
# VALIDATION PER STEP # VALIDATION PER STEP: global_step is already incremented
should_validate_step = ( # for example, if validate_every_n_steps=100, validate at step 100, 200, 300, ...
args.validate_every_n_steps is not None should_validate_step = args.validate_every_n_steps is not None and global_step % args.validate_every_n_steps == 0
and global_step != 0 # Skip first step
and global_step % args.validate_every_n_steps == 0
)
if accelerator.sync_gradients and validation_steps > 0 and should_validate_step: if accelerator.sync_gradients and validation_steps > 0 and should_validate_step:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm( val_progress_bar = tqdm(
range(validation_steps), smoothing=0, range(validation_total_steps),
smoothing=0,
disable=not accelerator.is_local_main_process, disable=not accelerator.is_local_main_process,
desc="validation steps" desc="validation steps",
) )
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader): for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps: if val_step >= validation_steps:
break break
# temporary, for batch processing for timestep in validation_timesteps:
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
args.min_timestep = args.max_timestep = timestep # dirty hack to change timestep
loss = self.process_batch( loss = self.process_batch(
batch, batch,
@@ -1417,21 +1491,23 @@ class NetworkTrainer:
text_encoding_strategy, text_encoding_strategy,
tokenize_strategy, tokenize_strategy,
is_train=False, is_train=False,
train_text_encoder=False, train_text_encoder=train_text_encoder, # this is needed for validation because Text Encoders must be called if train_text_encoder is True
train_unet=False train_unet=train_unet,
) )
current_loss = loss.detach().item() current_loss = loss.detach().item()
val_step_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_step_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_progress_bar.update(1) val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average }) val_progress_bar.set_postfix(
{"val_avg_loss": val_step_loss_recorder.moving_average, "timestep": timestep}
)
if is_tracking: # if is_tracking:
logs = { # logs = {f"loss/validation/step_current_{timestep}": current_loss}
"loss/validation/step_current": current_loss, # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
"val_step": (epoch * validation_steps) + val_step,
} self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
accelerator.log(logs, step=global_step) val_timesteps_step += 1
if is_tracking: if is_tracking:
loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average loss_validation_divergence = val_step_loss_recorder.moving_average - loss_recorder.moving_average
@@ -1439,31 +1515,45 @@ class NetworkTrainer:
"loss/validation/step_average": val_step_loss_recorder.moving_average, "loss/validation/step_average": val_step_loss_recorder.moving_average,
"loss/validation/step_divergence": loss_validation_divergence, "loss/validation/step_divergence": loss_validation_divergence,
} }
accelerator.log(logs, step=global_step) self.step_logging(accelerator, logs, global_step, epoch=epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
accelerator.unwrap_model(network).train()
progress_bar.unpause()
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
# EPOCH VALIDATION # EPOCH VALIDATION
should_validate_epoch = ( should_validate_epoch = (
(epoch + 1) % args.validate_every_n_epochs == 0 (epoch + 1) % args.validate_every_n_epochs == 0 if args.validate_every_n_epochs is not None else True
if args.validate_every_n_epochs is not None
else True
) )
if should_validate_epoch and len(val_dataloader) > 0: if should_validate_epoch and len(val_dataloader) > 0:
optimizer_eval_fn()
accelerator.unwrap_model(network).eval()
rng_states = switch_rng_state(args.validation_seed if args.validation_seed is not None else args.seed)
val_progress_bar = tqdm( val_progress_bar = tqdm(
range(validation_steps), smoothing=0, range(validation_total_steps),
smoothing=0,
disable=not accelerator.is_local_main_process, disable=not accelerator.is_local_main_process,
desc="epoch validation steps" desc="epoch validation steps",
) )
val_timesteps_step = 0
for val_step, batch in enumerate(val_dataloader): for val_step, batch in enumerate(val_dataloader):
if val_step >= validation_steps: if val_step >= validation_steps:
break break
for timestep in validation_timesteps:
args.min_timestep = args.max_timestep = timestep
# temporary, for batch processing # temporary, for batch processing
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype) self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype, is_train=False)
loss = self.process_batch( loss = self.process_batch(
batch, batch,
@@ -1479,22 +1569,23 @@ class NetworkTrainer:
text_encoding_strategy, text_encoding_strategy,
tokenize_strategy, tokenize_strategy,
is_train=False, is_train=False,
train_text_encoder=False, train_text_encoder=train_text_encoder,
train_unet=False train_unet=train_unet,
) )
current_loss = loss.detach().item() current_loss = loss.detach().item()
val_epoch_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) val_epoch_loss_recorder.add(epoch=epoch, step=val_timesteps_step, loss=current_loss)
val_progress_bar.update(1) val_progress_bar.update(1)
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average }) val_progress_bar.set_postfix(
{"val_epoch_avg_loss": val_epoch_loss_recorder.moving_average, "timestep": timestep}
)
if is_tracking: # if is_tracking:
logs = { # logs = {f"loss/validation/epoch_current_{timestep}": current_loss}
"loss/validation/epoch_current": current_loss, # self.val_logging(accelerator, logs, global_step, epoch + 1, val_step)
"epoch": epoch + 1,
"val_step": (epoch * validation_steps) + val_step self.on_validation_step_end(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
} val_timesteps_step += 1
accelerator.log(logs, step=global_step)
if is_tracking: if is_tracking:
avr_loss: float = val_epoch_loss_recorder.moving_average avr_loss: float = val_epoch_loss_recorder.moving_average
@@ -1502,14 +1593,20 @@ class NetworkTrainer:
logs = { logs = {
"loss/validation/epoch_average": avr_loss, "loss/validation/epoch_average": avr_loss,
"loss/validation/epoch_divergence": loss_validation_divergence, "loss/validation/epoch_divergence": loss_validation_divergence,
"epoch": epoch + 1
} }
accelerator.log(logs, step=global_step) self.epoch_logging(accelerator, logs, global_step, epoch + 1)
restore_rng_state(rng_states)
args.min_timestep = original_args_min_timestep
args.max_timestep = original_args_max_timestep
optimizer_train_fn()
accelerator.unwrap_model(network).train()
progress_bar.unpause()
# END OF EPOCH # END OF EPOCH
if is_tracking: if is_tracking:
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1} logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=global_step) self.epoch_logging(accelerator, logs, global_step, epoch + 1)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -1696,31 +1793,31 @@ def setup_parser() -> argparse.ArgumentParser:
"--validation_seed", "--validation_seed",
type=int, type=int,
default=None, default=None,
help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する" help="Validation seed for shuffling validation dataset, training `--seed` used otherwise / 検証データセットをシャッフルするための検証シード、それ以外の場合はトレーニング `--seed` を使用する",
) )
parser.add_argument( parser.add_argument(
"--validation_split", "--validation_split",
type=float, type=float,
default=0.0, default=0.0,
help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合" help="Split for validation images out of the training dataset / 学習画像から検証画像に分割する割合",
) )
parser.add_argument( parser.add_argument(
"--validate_every_n_steps", "--validate_every_n_steps",
type=int, type=int,
default=None, default=None,
help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます" help="Run validation on validation dataset every N steps. By default, validation will only occur every epoch if a validation dataset is available / 検証データセットの検証をNステップごとに実行します。デフォルトでは、検証データセットが利用可能な場合にのみ、検証はエポックごとに実行されます",
) )
parser.add_argument( parser.add_argument(
"--validate_every_n_epochs", "--validate_every_n_epochs",
type=int, type=int,
default=None, default=None,
help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます" help="Run validation dataset every N epochs. By default, validation will run every epoch if a validation dataset is available / 検証データセットをNエポックごとに実行します。デフォルトでは、検証データセットが利用可能な場合、検証はエポックごとに実行されます",
) )
parser.add_argument( parser.add_argument(
"--max_validation_steps", "--max_validation_steps",
type=int, type=int,
default=None, default=None,
help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します" help="Max number of validation dataset items processed. By default, validation will run the entire validation dataset / 処理される検証データセット項目の最大数。デフォルトでは、検証は検証データセット全体を実行します",
) )
return parser return parser