mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'dev' into gradual_latent_hires_fix
This commit is contained in:
24
README.md
24
README.md
@@ -281,6 +281,30 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
|
||||
|
||||
## Change History
|
||||
|
||||
### Dec 3, 2023 / 2023/12/3
|
||||
|
||||
- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
|
||||
- Min SNR Gamma with V-predicition (SD 2.1) is fixed. Thanks to feffy380! PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934)
|
||||
- See [#673](https://github.com/kohya-ss/sd-scripts/issues/673) for details.
|
||||
- `--min_diff` and `--clamp_quantile` options are added to `networks/extract_lora_from_models.py`. Thanks to wkpark! PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
|
||||
- The default values are same as the previous version.
|
||||
- Deep Shrink hires fix is supported in `sdxl_gen_img.py` and `gen_img_diffusers.py`.
|
||||
- `--ds_timesteps_1` and `--ds_timesteps_2` options denote the timesteps of the Deep Shrink for the first and second stages.
|
||||
- `--ds_depth_1` and `--ds_depth_2` options denote the depth (block index) of the Deep Shrink for the first and second stages.
|
||||
- `--ds_ratio` option denotes the ratio of the Deep Shrink. `0.5` means the half of the original latent size for the Deep Shrink.
|
||||
- `--dst1`, `--dst2`, `--dsd1`, `--dsd2` and `--dsr` prompt options are also available.
|
||||
|
||||
- `finetune\tag_images_by_wd14_tagger.py` で `--caption_separator` オプションでカンマ以外の区切り文字を指定できるようになりました。KohakuBlueleaf 氏に感謝します。 PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)
|
||||
- V-predicition (SD 2.1) での Min SNR Gamma が修正されました。feffy380 氏に感謝します。 PR[#934](https://github.com/kohya-ss/sd-scripts/pull/934)
|
||||
- 詳細は [#673](https://github.com/kohya-ss/sd-scripts/issues/673) を参照してください。
|
||||
- `networks/extract_lora_from_models.py` に `--min_diff` と `--clamp_quantile` オプションが追加されました。wkpark 氏に感謝します。 PR [#936](https://github.com/kohya-ss/sd-scripts/pull/936)
|
||||
- デフォルト値は前のバージョンと同じです。
|
||||
- `sdxl_gen_img.py` と `gen_img_diffusers.py` で Deep Shrink hires fix をサポートしました。
|
||||
- `--ds_timesteps_1` と `--ds_timesteps_2` オプションは Deep Shrink の第一段階と第二段階の timesteps を指定します。
|
||||
- `--ds_depth_1` と `--ds_depth_2` オプションは Deep Shrink の第一段階と第二段階の深さ(ブロックの index)を指定します。
|
||||
- `--ds_ratio` オプションは Deep Shrink の比率を指定します。`0.5` を指定すると Deep Shrink 適用時の latent は元のサイズの半分になります。
|
||||
- `--dst1`、`--dst2`、`--dsd1`、`--dsd2`、`--dsr` プロンプトオプションも使用できます。
|
||||
|
||||
### Nov 5, 2023 / 2023/11/5
|
||||
|
||||
- `sdxl_train.py` now supports different learning rates for each Text Encoder.
|
||||
|
||||
@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
|
||||
|
||||
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
|
||||
|
||||
- `--sample_at_first`
|
||||
|
||||
学習開始前にサンプル出力します。学習前との比較ができます。
|
||||
|
||||
- `--sample_prompts`
|
||||
|
||||
サンプル出力用プロンプトのファイルを指定します。
|
||||
|
||||
@@ -253,9 +253,6 @@ def train(args):
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||
if args.full_fp16:
|
||||
train_util.patch_accelerator_for_fp16_training(accelerator)
|
||||
@@ -298,6 +295,9 @@ def train(args):
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
|
||||
@@ -4,13 +4,12 @@ import contextlib
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
from .hijacks import ipex_hijacks
|
||||
from .attention import attention_init
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
def ipex_init(): # pylint: disable=too-many-statements
|
||||
try:
|
||||
#Replace cuda with xpu:
|
||||
# Replace cuda with xpu:
|
||||
torch.cuda.current_device = torch.xpu.current_device
|
||||
torch.cuda.current_stream = torch.xpu.current_stream
|
||||
torch.cuda.device = torch.xpu.device
|
||||
@@ -30,6 +29,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.FloatTensor = torch.xpu.FloatTensor
|
||||
torch.Tensor.cuda = torch.Tensor.xpu
|
||||
torch.Tensor.is_cuda = torch.Tensor.is_xpu
|
||||
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
|
||||
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
|
||||
torch.cuda._initialized = torch.xpu.lazy_init._initialized
|
||||
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
|
||||
@@ -90,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.CharStorage = torch.xpu.CharStorage
|
||||
torch.cuda.__file__ = torch.xpu.__file__
|
||||
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
|
||||
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
|
||||
|
||||
#Memory:
|
||||
# Memory:
|
||||
torch.cuda.memory = torch.xpu.memory
|
||||
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
|
||||
torch.xpu.empty_cache = lambda: None
|
||||
@@ -112,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
|
||||
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
|
||||
|
||||
#RNG:
|
||||
# RNG:
|
||||
torch.cuda.get_rng_state = torch.xpu.get_rng_state
|
||||
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
|
||||
torch.cuda.set_rng_state = torch.xpu.set_rng_state
|
||||
@@ -123,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.seed_all = torch.xpu.seed_all
|
||||
torch.cuda.initial_seed = torch.xpu.initial_seed
|
||||
|
||||
#AMP:
|
||||
# AMP:
|
||||
torch.cuda.amp = torch.xpu.amp
|
||||
if not hasattr(torch.cuda.amp, "common"):
|
||||
torch.cuda.amp.common = contextlib.nullcontext()
|
||||
@@ -138,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
|
||||
|
||||
#C
|
||||
# C
|
||||
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
|
||||
ipex._C._DeviceProperties.major = 2023
|
||||
ipex._C._DeviceProperties.minor = 2
|
||||
|
||||
#Fix functions with ipex:
|
||||
# Fix functions with ipex:
|
||||
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
|
||||
torch._utils._get_available_device_type = lambda: "xpu"
|
||||
torch.has_cuda = True
|
||||
@@ -164,12 +164,17 @@ def ipex_init(): # pylint: disable=too-many-statements
|
||||
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
|
||||
|
||||
ipex_hijacks()
|
||||
attention_init()
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
try:
|
||||
from .attention import attention_init
|
||||
attention_init()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
try:
|
||||
from .diffusers import ipex_diffusers
|
||||
ipex_diffusers()
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass
|
||||
except Exception as e:
|
||||
return False, e
|
||||
return True, None
|
||||
|
||||
@@ -74,6 +74,11 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
|
||||
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
|
||||
no_shape_one = False
|
||||
|
||||
if query.dtype != key.dtype:
|
||||
key = key.to(dtype=query.dtype)
|
||||
if query.dtype != value.dtype:
|
||||
value = value.to(dtype=query.dtype)
|
||||
|
||||
block_multiply = query.element_size()
|
||||
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
|
||||
block_size = batch_size_attention * slice_block_size
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
|
||||
import diffusers #0.21.1 # pylint: disable=import-error
|
||||
import diffusers #0.24.0 # pylint: disable=import-error
|
||||
from diffusers.models.attention_processor import Attention
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
|
||||
|
||||
# pylint: disable=protected-access, missing-function-docstring, line-too-long
|
||||
|
||||
device_supports_fp64 = torch.xpu.has_fp64_dtype()
|
||||
OptState = ipex.cpu.autocast._grad_scaler.OptState
|
||||
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
|
||||
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
|
||||
@@ -96,7 +97,10 @@ def unscale_(self, optimizer):
|
||||
|
||||
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
|
||||
assert self._scale is not None
|
||||
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
||||
if device_supports_fp64:
|
||||
inv_scale = self._scale.double().reciprocal().float()
|
||||
else:
|
||||
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
|
||||
found_inf = torch.full(
|
||||
(1,), 0.0, dtype=torch.float32, device=self._scale.device
|
||||
)
|
||||
|
||||
@@ -89,6 +89,7 @@ def ipex_autocast(*args, **kwargs):
|
||||
else:
|
||||
return original_autocast(*args, **kwargs)
|
||||
|
||||
# Embedding BF16
|
||||
original_torch_cat = torch.cat
|
||||
def torch_cat(tensor, *args, **kwargs):
|
||||
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
|
||||
@@ -96,6 +97,7 @@ def torch_cat(tensor, *args, **kwargs):
|
||||
else:
|
||||
return original_torch_cat(tensor, *args, **kwargs)
|
||||
|
||||
# Latent antialias:
|
||||
original_interpolate = torch.nn.functional.interpolate
|
||||
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
|
||||
if antialias or align_corners is not None:
|
||||
@@ -115,19 +117,29 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
|
||||
else:
|
||||
return original_linalg_solve(A, B, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def is_cuda(self):
|
||||
return self.device.type == 'xpu'
|
||||
|
||||
def ipex_hijacks():
|
||||
CondFunc('torch.tensor',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.Tensor.to',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.Tensor.cuda',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.UntypedStorage.__init__',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.UntypedStorage.cuda',
|
||||
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
|
||||
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
|
||||
CondFunc('torch.empty',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.load',
|
||||
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
|
||||
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
|
||||
CondFunc('torch.randn',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
@@ -137,17 +149,19 @@ def ipex_hijacks():
|
||||
CondFunc('torch.zeros',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.tensor',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.linspace',
|
||||
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
|
||||
lambda orig_func, *args, device=None, **kwargs: check_device(device))
|
||||
CondFunc('torch.load',
|
||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
|
||||
orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
|
||||
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
|
||||
|
||||
CondFunc('torch.Generator',
|
||||
lambda orig_func, device=None: torch.xpu.Generator(device),
|
||||
lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
|
||||
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
|
||||
|
||||
# TiledVAE and ControlNet:
|
||||
CondFunc('torch.batch_norm',
|
||||
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
|
||||
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
|
||||
@@ -159,38 +173,46 @@ def ipex_hijacks():
|
||||
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
|
||||
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
|
||||
|
||||
#Functions with dtype errors:
|
||||
# Functions with dtype errors:
|
||||
CondFunc('torch.nn.modules.GroupNorm.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# Training:
|
||||
CondFunc('torch.nn.modules.linear.Linear.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
CondFunc('torch.nn.modules.conv.Conv2d.forward',
|
||||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
|
||||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
|
||||
# BF16:
|
||||
CondFunc('torch.nn.functional.layer_norm',
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
|
||||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
|
||||
weight is not None and input.dtype != weight.data.dtype)
|
||||
# SwinIR BF16:
|
||||
CondFunc('torch.nn.functional.pad',
|
||||
lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
|
||||
lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
|
||||
|
||||
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
|
||||
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
|
||||
if not torch.xpu.has_fp64_dtype():
|
||||
CondFunc('torch.from_numpy',
|
||||
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
|
||||
lambda orig_func, ndarray: ndarray.dtype == float)
|
||||
|
||||
#Broken functions when torch.cuda.is_available is True:
|
||||
# Broken functions when torch.cuda.is_available is True:
|
||||
# Pin Memory:
|
||||
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
|
||||
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
|
||||
lambda orig_func, *args, **kwargs: True)
|
||||
|
||||
#Functions that make compile mad with CondFunc:
|
||||
# Functions that make compile mad with CondFunc:
|
||||
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
|
||||
torch.nn.DataParallel = DummyDataParallel
|
||||
torch.autocast = ipex_autocast
|
||||
torch.cat = torch_cat
|
||||
torch.linalg.solve = linalg_solve
|
||||
torch.UntypedStorage.is_cuda = is_cuda
|
||||
torch.nn.functional.interpolate = interpolate
|
||||
torch.backends.cuda.sdp_kernel = return_null_context
|
||||
|
||||
@@ -586,6 +586,9 @@ class CrossAttention(nn.Module):
|
||||
self.use_memory_efficient_attention_mem_eff = False
|
||||
self.use_sdpa = False
|
||||
|
||||
# Attention processor
|
||||
self.processor = None
|
||||
|
||||
def set_use_memory_efficient_attention(self, xformers, mem_eff):
|
||||
self.use_memory_efficient_attention_xformers = xformers
|
||||
self.use_memory_efficient_attention_mem_eff = mem_eff
|
||||
@@ -607,7 +610,28 @@ class CrossAttention(nn.Module):
|
||||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def forward(self, hidden_states, context=None, mask=None):
|
||||
def set_processor(self):
|
||||
return self.processor
|
||||
|
||||
def get_processor(self):
|
||||
return self.processor
|
||||
|
||||
def forward(self, hidden_states, context=None, mask=None, **kwargs):
|
||||
if self.processor is not None:
|
||||
(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
attention_mask,
|
||||
) = translate_attention_names_from_diffusers(
|
||||
hidden_states=hidden_states, context=context, mask=mask, **kwargs
|
||||
)
|
||||
return self.processor(
|
||||
attn=self,
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=context,
|
||||
attention_mask=mask,
|
||||
**kwargs
|
||||
)
|
||||
if self.use_memory_efficient_attention_xformers:
|
||||
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
|
||||
if self.use_memory_efficient_attention_mem_eff:
|
||||
@@ -720,6 +744,21 @@ class CrossAttention(nn.Module):
|
||||
out = self.to_out[0](out)
|
||||
return out
|
||||
|
||||
def translate_attention_names_from_diffusers(
|
||||
hidden_states: torch.FloatTensor,
|
||||
context: Optional[torch.FloatTensor] = None,
|
||||
mask: Optional[torch.FloatTensor] = None,
|
||||
# HF naming
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None
|
||||
):
|
||||
# translate from hugging face diffusers
|
||||
context = context if context is not None else encoder_hidden_states
|
||||
|
||||
# translate from hugging face diffusers
|
||||
mask = mask if mask is not None else attention_mask
|
||||
|
||||
return hidden_states, context, mask
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
@@ -1350,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):
|
||||
self.out_channels = OUT_CHANNELS
|
||||
|
||||
self.sample_size = sample_size
|
||||
self.prepare_config()
|
||||
self.prepare_config(sample_size=sample_size)
|
||||
|
||||
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
|
||||
|
||||
@@ -1437,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
|
||||
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
|
||||
|
||||
# region diffusers compatibility
|
||||
def prepare_config(self):
|
||||
self.config = SimpleNamespace()
|
||||
def prepare_config(self, *args, **kwargs):
|
||||
self.config = SimpleNamespace(**kwargs)
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
|
||||
@@ -133,6 +133,12 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
|
||||
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
|
||||
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
|
||||
|
||||
# temporary workaround for text_projection.weight.weight for Playground-v2
|
||||
if "text_projection.weight.weight" in new_sd:
|
||||
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
|
||||
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
|
||||
del new_sd["text_projection.weight.weight"]
|
||||
|
||||
return new_sd, logit_scale
|
||||
|
||||
|
||||
@@ -258,7 +264,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
|
||||
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
|
||||
elif k.startswith("conditioner.embedders.1.model."):
|
||||
te2_sd[k] = state_dict.pop(k)
|
||||
|
||||
|
||||
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
|
||||
if "text_model.embeddings.position_ids" not in te1_sd:
|
||||
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)
|
||||
|
||||
@@ -51,8 +51,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
|
||||
|
||||
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
|
||||
|
||||
|
||||
|
||||
@@ -2665,7 +2665,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
|
||||
"--optimizer_type",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
|
||||
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
|
||||
)
|
||||
|
||||
# backward compatibility
|
||||
@@ -2979,6 +2979,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
||||
parser.add_argument(
|
||||
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_every_n_epochs",
|
||||
type=int,
|
||||
@@ -3384,7 +3387,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
|
||||
|
||||
|
||||
def get_optimizer(args, trainable_params):
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
|
||||
|
||||
optimizer_type = args.optimizer_type
|
||||
if args.use_8bit_adam:
|
||||
@@ -3488,6 +3491,20 @@ def get_optimizer(args, trainable_params):
|
||||
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "PagedAdamW".lower():
|
||||
print(f"use PagedAdamW optimizer | {optimizer_kwargs}")
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
|
||||
try:
|
||||
optimizer_class = bnb.optim.PagedAdamW
|
||||
except AttributeError:
|
||||
raise AttributeError(
|
||||
"No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
|
||||
)
|
||||
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
|
||||
|
||||
elif optimizer_type == "PagedAdamW32bit".lower():
|
||||
print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
|
||||
try:
|
||||
@@ -3897,17 +3914,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
# TODO remove this function in the future
|
||||
def transform_if_model_is_DDP(text_encoder, unet, network=None):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
|
||||
|
||||
|
||||
def transform_models_if_DDP(models):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return [model.module if type(model) == DDP else model for model in models if model is not None]
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
@@ -3931,8 +3937,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
torch.cuda.empty_cache()
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
@@ -4447,11 +4451,118 @@ SCHEDULER_LINEAR_END = 0.0120
|
||||
SCHEDULER_TIMESTEPS = 1000
|
||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||
|
||||
def get_my_scheduler(
|
||||
*,
|
||||
sample_sampler: str,
|
||||
v_parameterization: bool,
|
||||
):
|
||||
sched_init_args = {}
|
||||
if sample_sampler == "ddim":
|
||||
scheduler_cls = DDIMScheduler
|
||||
elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
||||
scheduler_cls = DDPMScheduler
|
||||
elif sample_sampler == "pndm":
|
||||
scheduler_cls = PNDMScheduler
|
||||
elif sample_sampler == "lms" or sample_sampler == "k_lms":
|
||||
scheduler_cls = LMSDiscreteScheduler
|
||||
elif sample_sampler == "euler" or sample_sampler == "k_euler":
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
|
||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||
elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
|
||||
scheduler_cls = DPMSolverMultistepScheduler
|
||||
sched_init_args["algorithm_type"] = sample_sampler
|
||||
elif sample_sampler == "dpmsingle":
|
||||
scheduler_cls = DPMSolverSinglestepScheduler
|
||||
elif sample_sampler == "heun":
|
||||
scheduler_cls = HeunDiscreteScheduler
|
||||
elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
|
||||
scheduler_cls = KDPM2DiscreteScheduler
|
||||
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
|
||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||
else:
|
||||
scheduler_cls = DDIMScheduler
|
||||
|
||||
if v_parameterization:
|
||||
sched_init_args["prediction_type"] = "v_prediction"
|
||||
|
||||
scheduler = scheduler_cls(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
beta_end=SCHEDULER_LINEAR_END,
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
**sched_init_args,
|
||||
)
|
||||
|
||||
# clip_sample=Trueにする
|
||||
if (
|
||||
hasattr(scheduler.config, "clip_sample")
|
||||
and scheduler.config.clip_sample is False
|
||||
):
|
||||
# print("set clip_sample to True")
|
||||
scheduler.config.clip_sample = True
|
||||
|
||||
return scheduler
|
||||
|
||||
|
||||
def sample_images(*args, **kwargs):
|
||||
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
|
||||
|
||||
|
||||
def line_to_prompt_dict(line: str) -> dict:
|
||||
# subset of gen_img_diffusers
|
||||
prompt_args = line.split(" --")
|
||||
prompt_dict = {}
|
||||
prompt_dict['prompt'] = prompt_args[0]
|
||||
|
||||
for parg in prompt_args:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict['width'] = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict['height'] = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict['seed'] = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1))))
|
||||
continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
prompt_dict['scale'] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
prompt_dict['negative_prompt'] = m.group(1)
|
||||
continue
|
||||
|
||||
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict['sample_sampler'] = m.group(1)
|
||||
continue
|
||||
|
||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
prompt_dict['controlnet_image'] = m.group(1)
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
|
||||
return prompt_dict
|
||||
|
||||
def sample_images_common(
|
||||
pipe_class,
|
||||
accelerator,
|
||||
@@ -4469,15 +4580,19 @@ def sample_images_common(
|
||||
"""
|
||||
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
|
||||
"""
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
# sample_every_n_steps は無視する
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
|
||||
return
|
||||
if args.sample_every_n_epochs is not None:
|
||||
# sample_every_n_steps は無視する
|
||||
if epoch is None or epoch % args.sample_every_n_epochs != 0:
|
||||
return
|
||||
else:
|
||||
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
|
||||
return
|
||||
|
||||
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
|
||||
if not os.path.isfile(args.sample_prompts):
|
||||
@@ -4504,56 +4619,19 @@ def sample_images_common(
|
||||
with open(args.sample_prompts, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
# schedulerを用意する
|
||||
sched_init_args = {}
|
||||
if args.sample_sampler == "ddim":
|
||||
scheduler_cls = DDIMScheduler
|
||||
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
|
||||
scheduler_cls = DDPMScheduler
|
||||
elif args.sample_sampler == "pndm":
|
||||
scheduler_cls = PNDMScheduler
|
||||
elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
|
||||
scheduler_cls = LMSDiscreteScheduler
|
||||
elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
|
||||
scheduler_cls = EulerDiscreteScheduler
|
||||
elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
|
||||
scheduler_cls = EulerAncestralDiscreteScheduler
|
||||
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
|
||||
scheduler_cls = DPMSolverMultistepScheduler
|
||||
sched_init_args["algorithm_type"] = args.sample_sampler
|
||||
elif args.sample_sampler == "dpmsingle":
|
||||
scheduler_cls = DPMSolverSinglestepScheduler
|
||||
elif args.sample_sampler == "heun":
|
||||
scheduler_cls = HeunDiscreteScheduler
|
||||
elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
|
||||
scheduler_cls = KDPM2DiscreteScheduler
|
||||
elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
|
||||
scheduler_cls = KDPM2AncestralDiscreteScheduler
|
||||
else:
|
||||
scheduler_cls = DDIMScheduler
|
||||
|
||||
if args.v_parameterization:
|
||||
sched_init_args["prediction_type"] = "v_prediction"
|
||||
|
||||
scheduler = scheduler_cls(
|
||||
num_train_timesteps=SCHEDULER_TIMESTEPS,
|
||||
beta_start=SCHEDULER_LINEAR_START,
|
||||
beta_end=SCHEDULER_LINEAR_END,
|
||||
beta_schedule=SCHEDLER_SCHEDULE,
|
||||
**sched_init_args,
|
||||
schedulers: dict = {}
|
||||
default_scheduler = get_my_scheduler(
|
||||
sample_sampler=args.sample_sampler,
|
||||
v_parameterization=args.v_parameterization,
|
||||
)
|
||||
|
||||
# clip_sample=Trueにする
|
||||
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||
# print("set clip_sample to True")
|
||||
scheduler.config.clip_sample = True
|
||||
schedulers[args.sample_sampler] = default_scheduler
|
||||
|
||||
pipeline = pipe_class(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
scheduler=default_scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
@@ -4569,78 +4647,34 @@ def sample_images_common(
|
||||
|
||||
with torch.no_grad():
|
||||
# with accelerator.autocast():
|
||||
for i, prompt in enumerate(prompts):
|
||||
for i, prompt_dict in enumerate(prompts):
|
||||
if not accelerator.is_main_process:
|
||||
continue
|
||||
|
||||
if isinstance(prompt, dict):
|
||||
negative_prompt = prompt.get("negative_prompt")
|
||||
sample_steps = prompt.get("sample_steps", 30)
|
||||
width = prompt.get("width", 512)
|
||||
height = prompt.get("height", 512)
|
||||
scale = prompt.get("scale", 7.5)
|
||||
seed = prompt.get("seed")
|
||||
controlnet_image = prompt.get("controlnet_image")
|
||||
prompt = prompt.get("prompt")
|
||||
else:
|
||||
# prompt = prompt.strip()
|
||||
# if len(prompt) == 0 or prompt[0] == "#":
|
||||
# continue
|
||||
if isinstance(prompt_dict, str):
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
|
||||
# subset of gen_img_diffusers
|
||||
prompt_args = prompt.split(" --")
|
||||
prompt = prompt_args[0]
|
||||
negative_prompt = None
|
||||
sample_steps = 30
|
||||
width = height = 512
|
||||
scale = 7.5
|
||||
seed = None
|
||||
controlnet_image = None
|
||||
for parg in prompt_args:
|
||||
try:
|
||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
width = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
height = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||
if m:
|
||||
seed = int(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||
if m: # steps
|
||||
sample_steps = max(1, min(1000, int(m.group(1))))
|
||||
continue
|
||||
|
||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # scale
|
||||
scale = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
negative_prompt = m.group(1)
|
||||
continue
|
||||
|
||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
controlnet_image = m.group(1)
|
||||
continue
|
||||
|
||||
except ValueError as ex:
|
||||
print(f"Exception in parsing / 解析エラー: {parg}")
|
||||
print(ex)
|
||||
assert isinstance(prompt_dict, dict)
|
||||
negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 30)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 7.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scheduler = schedulers.get(sampler_name)
|
||||
if scheduler is None:
|
||||
scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,)
|
||||
schedulers[sampler_name] = scheduler
|
||||
pipeline.scheduler = scheduler
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
@@ -4658,6 +4692,9 @@ def sample_images_common(
|
||||
print(f"width: {width}")
|
||||
print(f"sample_steps: {sample_steps}")
|
||||
print(f"scale: {scale}")
|
||||
print(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
print(f"seed: {seed}")
|
||||
with accelerator.autocast():
|
||||
latents = pipeline(
|
||||
prompt=prompt,
|
||||
|
||||
@@ -515,7 +515,8 @@ class PipelineLike:
|
||||
uncond_embeddings = tes_uncond_embs[0]
|
||||
for i in range(1, len(tes_text_embs)):
|
||||
text_embeddings = torch.cat([text_embeddings, tes_text_embs[i]], dim=2) # n,77,2048
|
||||
uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048
|
||||
if do_classifier_free_guidance:
|
||||
uncond_embeddings = torch.cat([uncond_embeddings, tes_uncond_embs[i]], dim=2) # n,77,2048
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_scale is None:
|
||||
@@ -578,9 +579,11 @@ class PipelineLike:
|
||||
text_pool = clip_vision_embeddings # replace: same as ComfyUI (?)
|
||||
|
||||
c_vector = torch.cat([text_pool, c_vector], dim=1)
|
||||
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
|
||||
|
||||
vector_embeddings = torch.cat([uc_vector, c_vector])
|
||||
if do_classifier_free_guidance:
|
||||
uc_vector = torch.cat([uncond_pool, uc_vector], dim=1)
|
||||
vector_embeddings = torch.cat([uc_vector, c_vector])
|
||||
else:
|
||||
vector_embeddings = c_vector
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, self.device)
|
||||
|
||||
@@ -397,13 +397,10 @@ def train(args):
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
(unet,) = train_util.transform_models_if_DDP([unet])
|
||||
if train_text_encoder1:
|
||||
text_encoder1 = accelerator.prepare(text_encoder1)
|
||||
(text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
|
||||
if train_text_encoder2:
|
||||
text_encoder2 = accelerator.prepare(text_encoder2)
|
||||
(text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
|
||||
|
||||
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
@@ -461,6 +458,11 @@ def train(args):
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
sdxl_train_util.sample_images(
|
||||
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
|
||||
@@ -283,9 +283,6 @@ def train(args):
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare (train_network here only)
|
||||
unet = train_util.transform_models_if_DDP([unet])[0]
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
else:
|
||||
|
||||
@@ -254,9 +254,6 @@ def train(args):
|
||||
)
|
||||
network: control_net_lllite.ControlNetLLLite
|
||||
|
||||
# transform DDP after prepare (train_network here only)
|
||||
unet, network = train_util.transform_models_if_DDP([unet, network])
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
|
||||
else:
|
||||
|
||||
@@ -11,10 +11,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -335,7 +338,9 @@ def train(args):
|
||||
init_kwargs = {}
|
||||
if args.log_tracker_config is not None:
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
accelerator.init_trackers(
|
||||
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
|
||||
)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
del train_dataset_group
|
||||
@@ -371,6 +376,11 @@ def train(args):
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(
|
||||
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
|
||||
)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
|
||||
14
train_db.py
14
train_db.py
@@ -112,6 +112,7 @@ def train(args):
|
||||
|
||||
# mixed precisionに対応した型を用意しておき適宜castする
|
||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||
|
||||
# モデルを読み込む
|
||||
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
|
||||
@@ -136,7 +137,7 @@ def train(args):
|
||||
|
||||
# 学習を準備する
|
||||
if cache_latents:
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=vae_dtype)
|
||||
vae.requires_grad_(False)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
@@ -225,9 +226,6 @@ def train(args):
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
if not train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
|
||||
|
||||
@@ -274,6 +272,9 @@ def train(args):
|
||||
init_kwargs = toml.load(args.log_tracker_config)
|
||||
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
|
||||
|
||||
# For --sample_at_first
|
||||
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
@@ -482,6 +483,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_half_vae",
|
||||
action="store_true",
|
||||
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
@@ -127,6 +128,11 @@ class NetworkTrainer:
|
||||
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
|
||||
return noise_pred
|
||||
|
||||
def all_reduce_network(self, accelerator, network):
|
||||
for param in network.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad = accelerator.reduce(param.grad, reduction="mean")
|
||||
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
@@ -390,46 +396,20 @@ class NetworkTrainer:
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
# TODO めちゃくちゃ冗長なのでコードを整理する
|
||||
if train_unet and train_text_encoder:
|
||||
if train_unet:
|
||||
unet = accelerator.prepare(unet)
|
||||
else:
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
if train_text_encoder:
|
||||
if len(text_encoders) > 1:
|
||||
unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
text_encoder = text_encoders = [t_enc1, t_enc2]
|
||||
del t_enc1, t_enc2
|
||||
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
|
||||
else:
|
||||
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
text_encoder = accelerator.prepare(text_encoder)
|
||||
text_encoders = [text_encoder]
|
||||
elif train_unet:
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
for t_enc in text_encoders:
|
||||
t_enc.to(accelerator.device, dtype=weight_dtype)
|
||||
elif train_text_encoder:
|
||||
if len(text_encoders) > 1:
|
||||
t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
text_encoder = text_encoders = [t_enc1, t_enc2]
|
||||
del t_enc1, t_enc2
|
||||
else:
|
||||
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder, network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
text_encoders = [text_encoder]
|
||||
|
||||
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
|
||||
else:
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
network, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# transform DDP after prepare (train_network here only)
|
||||
text_encoders = train_util.transform_models_if_DDP(text_encoders)
|
||||
unet, network = train_util.transform_models_if_DDP([unet, network])
|
||||
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
# according to TI example in Diffusers, train is required
|
||||
@@ -451,7 +431,7 @@ class NetworkTrainer:
|
||||
|
||||
del t_enc
|
||||
|
||||
network.prepare_grad_etc(text_encoder, unet)
|
||||
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
|
||||
|
||||
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||
vae.requires_grad_(False)
|
||||
@@ -714,8 +694,8 @@ class NetworkTrainer:
|
||||
del train_dataset_group
|
||||
|
||||
# callback for step start
|
||||
if hasattr(network, "on_step_start"):
|
||||
on_step_start = network.on_step_start
|
||||
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
|
||||
on_step_start = accelerator.unwrap_model(network).on_step_start
|
||||
else:
|
||||
on_step_start = lambda *args, **kwargs: None
|
||||
|
||||
@@ -743,6 +723,9 @@ class NetworkTrainer:
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
@@ -750,7 +733,7 @@ class NetworkTrainer:
|
||||
|
||||
metadata["ss_epoch"] = str(epoch + 1)
|
||||
|
||||
network.on_epoch_start(text_encoder, unet)
|
||||
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
current_step.value = global_step
|
||||
@@ -823,8 +806,9 @@ class NetworkTrainer:
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
accelerator.backward(loss)
|
||||
self.all_reduce_network(accelerator, network) # sync DDP grad manually
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = network.get_trainable_params()
|
||||
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
@@ -832,7 +816,7 @@ class NetworkTrainer:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
if args.scale_weight_norms:
|
||||
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
|
||||
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
|
||||
args.scale_weight_norms, accelerator.device
|
||||
)
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
|
||||
@@ -7,10 +7,13 @@ import toml
|
||||
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
if torch.xpu.is_available():
|
||||
from library.ipex import ipex_init
|
||||
|
||||
ipex_init()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -415,15 +418,11 @@ class TextualInversionTrainer:
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
# transform DDP after prepare
|
||||
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
|
||||
|
||||
elif len(text_encoders) == 2:
|
||||
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
# transform DDP after prepare
|
||||
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
|
||||
|
||||
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
|
||||
|
||||
@@ -529,6 +528,20 @@ class TextualInversionTrainer:
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# For --sample_at_first
|
||||
self.sample_images(
|
||||
accelerator,
|
||||
args,
|
||||
0,
|
||||
global_step,
|
||||
accelerator.device,
|
||||
vae,
|
||||
tokenizer_or_list,
|
||||
text_encoder_or_list,
|
||||
unet,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
|
||||
@@ -333,9 +333,6 @@ def train(args):
|
||||
text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# transform DDP after prepare
|
||||
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
|
||||
|
||||
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
|
||||
# print(len(index_no_updates), torch.sum(index_no_updates))
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()
|
||||
|
||||
Reference in New Issue
Block a user