Merge pull request #1012 from kohya-ss/dev

merge dev to main
This commit is contained in:
Kohya S
2023-12-21 22:18:39 +09:00
committed by GitHub
25 changed files with 573 additions and 344 deletions

View File

@@ -249,6 +249,34 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
## Change History
### Dec 21, 2023 / 2023/12/21
- The issues in multi-GPU training are fixed. Thanks to Isotr0py! PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) and [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
- `--ddp_gradient_as_bucket_view` and `--ddp_bucket_view`options are added to `sdxl_train.py`. Please specify these options for multi-GPU training.
- IPEX support is updated. Thanks to Disty0!
- Fixed the bug that the size of the bucket becomes less than `min_bucket_reso`. Thanks to Cauldrath! PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
- `--sample_at_first` option is added to each training script. This option is useful to generate images at the first step, before training. Thanks to shirayu! PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
- `--ss` option is added to the sampling prompt in training. You can specify the scheduler for the sampling like `--ss euler_a`. Thanks to shirayu! PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
- `keep_tokens_separator` is added to the dataset config. This option is useful to keep (prevent from shuffling) the tokens in the captions. See [#975](https://github.com/kohya-ss/sd-scripts/pull/975) for details. Thanks to Linaqruf!
- You can specify the separator with an option like `--keep_tokens_separator "|||"` or with `keep_tokens_separator: "|||"` in `.toml`. The tokens before `|||` are not shuffled.
- Attention processor hook is added. See [#961](https://github.com/kohya-ss/sd-scripts/pull/961) for details. Thanks to rockerBOO!
- The optimizer `PagedAdamW` is added. Thanks to xzuyn! PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
- NaN replacement in SDXL VAE is sped up. Thanks to liubo0902! PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
- Fixed the path error in `finetune/make_captions.py`. Thanks to CjangCjengh! PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
- マルチGPUでの学習の不具合を修正しました。Isotr0py 氏に感謝します。 PR [#989](https://github.com/kohya-ss/sd-scripts/pull/989) および [#1000](https://github.com/kohya-ss/sd-scripts/pull/1000)
- `sdxl_train.py``--ddp_gradient_as_bucket_view``--ddp_bucket_view` オプションが追加されました。マルチGPUでの学習時にはこれらのオプションを指定してください。
- IPEX サポートが更新されました。Disty0 氏に感謝します。
- Aspect Ratio Bucketing で bucket のサイズが `min_bucket_reso` 未満になる不具合を修正しました。Cauldrath 氏に感謝します。 PR [#1008](https://github.com/kohya-ss/sd-scripts/pull/1008)
- 各学習スクリプトに `--sample_at_first` オプションが追加されました。学習前に画像を生成することで、学習結果が比較しやすくなります。shirayu 氏に感謝します。 PR [#907](https://github.com/kohya-ss/sd-scripts/pull/907)
- 学習時のプロンプトに `--ss` オプションが追加されました。`--ss euler_a` のようにスケジューラを指定できます。shirayu 氏に感謝します。 PR [#906](https://github.com/kohya-ss/sd-scripts/pull/906)
- データセット設定に `keep_tokens_separator` が追加されました。キャプション内のトークンをどの位置までシャッフルしないかを指定できます。詳細は [#975](https://github.com/kohya-ss/sd-scripts/pull/975) を参照してください。Linaqruf 氏に感謝します。
- オプションで `--keep_tokens_separator "|||"` のように指定するか、`.toml``keep_tokens_separator: "|||"` のように指定します。`|||` の前のトークンはシャッフルされません。
- Attention processor hook が追加されました。詳細は [#961](https://github.com/kohya-ss/sd-scripts/pull/961) を参照してください。rockerBOO 氏に感謝します。
- オプティマイザ `PagedAdamW` が追加されました。xzuyn 氏に感謝します。 PR [#955](https://github.com/kohya-ss/sd-scripts/pull/955)
- 学習時、SDXL VAE で NaN が発生した時の置き換えが高速化されました。liubo0902 氏に感謝します。 PR [#1009](https://github.com/kohya-ss/sd-scripts/pull/1009)
- `finetune/make_captions.py` で相対パス指定時のエラーが修正されました。CjangCjengh 氏に感謝します。 PR [#986](https://github.com/kohya-ss/sd-scripts/pull/986)
### Dec 3, 2023 / 2023/12/3
- `finetune\tag_images_by_wd14_tagger.py` now supports the separator other than `,` with `--caption_separator` option. Thanks to KohakuBlueleaf! PR [#913](https://github.com/kohya-ss/sd-scripts/pull/913)

View File

@@ -374,6 +374,10 @@ classがひとつで対象が複数の場合、正則化画像フォルダはひ
サンプル出力するステップ数またはエポック数を指定します。この数ごとにサンプル出力します。両方指定するとエポック数が優先されます。
- `--sample_at_first`
学習開始前にサンプル出力します。学習前との比較ができます。
- `--sample_prompts`
サンプル出力用プロンプトのファイルを指定します。

View File

@@ -253,9 +253,6 @@ def train(args):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
# 実験的機能勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
if args.full_fp16:
train_util.patch_accelerator_for_fp16_training(accelerator)
@@ -298,6 +295,9 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")

View File

@@ -13,7 +13,7 @@ import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
sys.path.append(os.path.dirname(__file__))
from blip.blip import blip_decoder
from blip.blip import blip_decoder, is_url
import library.train_util as train_util
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -76,6 +76,8 @@ def main(args):
cwd = os.getcwd()
print("Current Working Directory is: ", cwd)
os.chdir("finetune")
if not is_url(args.caption_weights) and not os.path.isfile(args.caption_weights):
args.caption_weights = os.path.join("..", args.caption_weights)
print(f"load images from {args.train_data_dir}")
train_data_dir_path = Path(args.train_data_dir)

View File

@@ -53,6 +53,7 @@ class BaseSubsetParams:
shuffle_caption: bool = False
caption_separator: str = ',',
keep_tokens: int = 0
keep_tokens_separator: str = None,
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
@@ -160,6 +161,7 @@ class ConfigSanitizer:
"random_crop": bool,
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
"caption_prefix": str,
@@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}

View File

@@ -4,13 +4,12 @@ import contextlib
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
from .hijacks import ipex_hijacks
from .attention import attention_init
# pylint: disable=protected-access, missing-function-docstring, line-too-long
def ipex_init(): # pylint: disable=too-many-statements
try:
#Replace cuda with xpu:
# Replace cuda with xpu:
torch.cuda.current_device = torch.xpu.current_device
torch.cuda.current_stream = torch.xpu.current_stream
torch.cuda.device = torch.xpu.device
@@ -30,6 +29,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.Tensor.cuda = torch.Tensor.xpu
torch.Tensor.is_cuda = torch.Tensor.is_xpu
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
@@ -90,9 +90,9 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.__file__ = torch.xpu.__file__
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
#torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing
#Memory:
# Memory:
torch.cuda.memory = torch.xpu.memory
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
@@ -112,7 +112,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.memory_stats_as_nested_dict = torch.xpu.memory_stats_as_nested_dict
torch.cuda.reset_accumulated_memory_stats = torch.xpu.reset_accumulated_memory_stats
#RNG:
# RNG:
torch.cuda.get_rng_state = torch.xpu.get_rng_state
torch.cuda.get_rng_state_all = torch.xpu.get_rng_state_all
torch.cuda.set_rng_state = torch.xpu.set_rng_state
@@ -123,7 +123,7 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed
#AMP:
# AMP:
torch.cuda.amp = torch.xpu.amp
if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
@@ -138,12 +138,12 @@ def ipex_init(): # pylint: disable=too-many-statements
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
#C
# C
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentStream
ipex._C._DeviceProperties.major = 2023
ipex._C._DeviceProperties.minor = 2
#Fix functions with ipex:
# Fix functions with ipex:
torch.cuda.mem_get_info = lambda device=None: [(torch.xpu.get_device_properties(device).total_memory - torch.xpu.memory_reserved(device)), torch.xpu.get_device_properties(device).total_memory]
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
@@ -156,20 +156,14 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.get_device_properties.minor = 7
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0
if hasattr(torch.xpu, 'getDeviceIdListForCard'):
torch.cuda.getDeviceIdListForCard = torch.xpu.getDeviceIdListForCard
torch.cuda.get_device_id_list_per_card = torch.xpu.getDeviceIdListForCard
else:
torch.cuda.getDeviceIdListForCard = torch.xpu.get_device_id_list_per_card
torch.cuda.get_device_id_list_per_card = torch.xpu.get_device_id_list_per_card
ipex_hijacks()
attention_init()
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
if not torch.xpu.has_fp64_dtype():
try:
from .diffusers import ipex_diffusers
ipex_diffusers()
except Exception: # pylint: disable=broad-exception-caught
pass
except Exception as e:
return False, e
return True, None

View File

@@ -4,11 +4,8 @@ import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unuse
# pylint: disable=protected-access, missing-function-docstring, line-too-long
original_torch_bmm = torch.bmm
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
def torch_bmm_32_bit(input, mat2, *, out=None):
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
block_multiply = input.element_size()
slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
@@ -17,28 +14,27 @@ def torch_bmm(input, mat2, *, out=None):
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
#Find something divisible with the input_tokens
# Find something divisible with the input_tokens
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
split_2_slice_size = input_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size_2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
do_split_2 = True
# Find something divisible with the input_tokens
while (split_2_slice_size * slice_block_size_2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
else:
do_split = False
split_2_slice_size = input_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the input_tokens
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(input.shape[0], input.shape[1], mat2.shape[2], device=input.device, dtype=input.dtype)
for i in range(batch_size_attention // split_slice_size):
@@ -64,45 +60,54 @@ def torch_bmm(input, mat2, *, out=None):
return hidden_states
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
#ARC GPUs can't allocate more than 4GB to a single block, Slice it:
def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
# ARC GPUs can't allocate more than 4GB to a single block, Slice it:
if len(query.shape) == 3:
batch_size_attention, query_tokens, shape_four = query.shape
shape_one = 1
no_shape_one = True
batch_size_attention, query_tokens, shape_three = query.shape
shape_four = 1
else:
shape_one, batch_size_attention, query_tokens, shape_four = query.shape
no_shape_one = False
batch_size_attention, query_tokens, shape_three, shape_four = query.shape
block_multiply = query.element_size()
slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
slice_block_size = query_tokens * shape_three * shape_four / 1024 / 1024 * block_multiply
block_size = batch_size_attention * slice_block_size
split_slice_size = batch_size_attention
if block_size > 4:
do_split = True
#Find something divisible with the shape_one
# Find something divisible with the batch_size_attention
while (split_slice_size * slice_block_size) > 4:
split_slice_size = split_slice_size // 2
if split_slice_size <= 1:
split_slice_size = 1
break
split_2_slice_size = query_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size_2 = split_slice_size * shape_three * shape_four / 1024 / 1024 * block_multiply
do_split_2 = True
# Find something divisible with the query_tokens
while (split_2_slice_size * slice_block_size_2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
split_3_slice_size = shape_three
if split_2_slice_size * slice_block_size_2 > 4:
slice_block_size_3 = split_slice_size * split_2_slice_size * shape_four / 1024 / 1024 * block_multiply
do_split_3 = True
# Find something divisible with the shape_three
while (split_3_slice_size * slice_block_size_3) > 4:
split_3_slice_size = split_3_slice_size // 2
if split_3_slice_size <= 1:
split_3_slice_size = 1
break
else:
do_split_3 = False
else:
do_split_2 = False
else:
do_split = False
split_2_slice_size = query_tokens
if split_slice_size * slice_block_size > 4:
slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
do_split_2 = True
#Find something divisible with the batch_size_attention
while (split_2_slice_size * slice_block_size2) > 4:
split_2_slice_size = split_2_slice_size // 2
if split_2_slice_size <= 1:
split_2_slice_size = 1
break
else:
do_split_2 = False
if do_split:
hidden_states = torch.zeros(query.shape, device=query.device, dtype=query.dtype)
for i in range(batch_size_attention // split_slice_size):
@@ -112,7 +117,18 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
for i2 in range(query_tokens // split_2_slice_size): # pylint: disable=invalid-name
start_idx_2 = i2 * split_2_slice_size
end_idx_2 = (i2 + 1) * split_2_slice_size
if no_shape_one:
if do_split_3:
for i3 in range(shape_three // split_3_slice_size): # pylint: disable=invalid-name
start_idx_3 = i3 * split_3_slice_size
end_idx_3 = (i3 + 1) * split_3_slice_size
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
key[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
value[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3],
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2, start_idx_3:end_idx_3] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[start_idx:end_idx, start_idx_2:end_idx_2],
key[start_idx:end_idx, start_idx_2:end_idx_2],
@@ -120,38 +136,16 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
attn_mask=attn_mask[start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx, start_idx_2:end_idx_2] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx, start_idx_2:end_idx_2],
key[:, start_idx:end_idx, start_idx_2:end_idx_2],
value[:, start_idx:end_idx, start_idx_2:end_idx_2],
attn_mask=attn_mask[:, start_idx:end_idx, start_idx_2:end_idx_2] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
if no_shape_one:
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
hidden_states[:, start_idx:end_idx] = original_scaled_dot_product_attention(
query[:, start_idx:end_idx],
key[:, start_idx:end_idx],
value[:, start_idx:end_idx],
attn_mask=attn_mask[:, start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
hidden_states[start_idx:end_idx] = original_scaled_dot_product_attention(
query[start_idx:end_idx],
key[start_idx:end_idx],
value[start_idx:end_idx],
attn_mask=attn_mask[start_idx:end_idx] if attn_mask is not None else attn_mask,
dropout_p=dropout_p, is_causal=is_causal
)
else:
return original_scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
)
return hidden_states
def attention_init():
#ARC GPUs can't allocate more than 4GB to a single block:
torch.bmm = torch_bmm
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

View File

@@ -1,6 +1,6 @@
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import diffusers #0.21.1 # pylint: disable=import-error
import diffusers #0.24.0 # pylint: disable=import-error
from diffusers.models.attention_processor import Attention
# pylint: disable=protected-access, missing-function-docstring, line-too-long

View File

@@ -5,6 +5,7 @@ import intel_extension_for_pytorch._C as core # pylint: disable=import-error, un
# pylint: disable=protected-access, missing-function-docstring, line-too-long
device_supports_fp64 = torch.xpu.has_fp64_dtype()
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state
@@ -96,7 +97,10 @@ def unscale_(self, optimizer):
# FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
assert self._scale is not None
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
if device_supports_fp64:
inv_scale = self._scale.double().reciprocal().float()
else:
inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
found_inf = torch.full(
(1,), 0.0, dtype=torch.float32, device=self._scale.device
)

View File

@@ -89,6 +89,7 @@ def ipex_autocast(*args, **kwargs):
else:
return original_autocast(*args, **kwargs)
# Embedding BF16
original_torch_cat = torch.cat
def torch_cat(tensor, *args, **kwargs):
if len(tensor) == 3 and (tensor[0].dtype != tensor[1].dtype or tensor[2].dtype != tensor[1].dtype):
@@ -96,6 +97,7 @@ def torch_cat(tensor, *args, **kwargs):
else:
return original_torch_cat(tensor, *args, **kwargs)
# Latent antialias:
original_interpolate = torch.nn.functional.interpolate
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
if antialias or align_corners is not None:
@@ -115,19 +117,54 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
else:
return original_linalg_solve(A, B, *args, **kwargs)
if torch.xpu.has_fp64_dtype():
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 64 bit attention workarounds for Alchemist:
try:
from .attention import torch_bmm_32_bit as original_torch_bmm
from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
except Exception: # pylint: disable=broad-exception-caught
original_torch_bmm = torch.bmm
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
# dtype errors:
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
return original_torch_bmm(input, mat2, out=out)
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
if query.dtype != key.dtype:
key = key.to(dtype=query.dtype)
if query.dtype != value.dtype:
value = value.to(dtype=query.dtype)
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
@property
def is_cuda(self):
return self.device.type == 'xpu'
def ipex_hijacks():
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.Tensor.to',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.Tensor.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.UntypedStorage.__init__',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.UntypedStorage.cuda',
lambda orig_func, self, device=None, *args, **kwargs: orig_func(self, return_xpu(device), *args, **kwargs),
lambda orig_func, self, device=None, *args, **kwargs: check_device(device))
CondFunc('torch.empty',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, *args, map_location=None, **kwargs: orig_func(*args, return_xpu(map_location), **kwargs),
lambda orig_func, *args, map_location=None, **kwargs: map_location is None or check_device(map_location))
CondFunc('torch.randn',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
@@ -137,17 +174,23 @@ def ipex_hijacks():
CondFunc('torch.zeros',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.tensor',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.linspace',
lambda orig_func, *args, device=None, **kwargs: orig_func(*args, device=return_xpu(device), **kwargs),
lambda orig_func, *args, device=None, **kwargs: check_device(device))
CondFunc('torch.load',
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs:
orig_func(orig_func, f, map_location=return_xpu(map_location), pickle_module=pickle_module, weights_only=weights_only, mmap=mmap, **kwargs),
lambda orig_func, f, map_location=None, pickle_module=None, *, weights_only=False, mmap=None, **kwargs: check_device(map_location))
if hasattr(torch.xpu, "Generator"):
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(return_xpu(device)),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
else:
CondFunc('torch.Generator',
lambda orig_func, device=None: orig_func(return_xpu(device)),
lambda orig_func, device=None: check_device(device))
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device != torch.device("cpu") and device != "cpu")
# TiledVAE and ControlNet:
CondFunc('torch.batch_norm',
lambda orig_func, input, weight, bias, *args, **kwargs: orig_func(input,
weight if weight is not None else torch.ones(input.size()[1], device=input.device),
@@ -159,38 +202,51 @@ def ipex_hijacks():
bias if bias is not None else torch.zeros(input.size()[1], device=input.device), *args, **kwargs),
lambda orig_func, input, *args, **kwargs: input.device != torch.device("cpu"))
#Functions with dtype errors:
# Functions with dtype errors:
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# Training:
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
# BF16:
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
# SwinIR BF16:
CondFunc('torch.nn.functional.pad',
lambda orig_func, input, pad, mode='constant', value=None: orig_func(input.to(torch.float32), pad, mode=mode, value=value).to(dtype=torch.bfloat16),
lambda orig_func, input, pad, mode='constant', value=None: mode == 'reflect' and input.dtype == torch.bfloat16)
#Diffusers Float64 (ARC GPUs doesn't support double or Float64):
# Diffusers Float64 (Alchemist GPUs doesn't support 64 bit):
if not torch.xpu.has_fp64_dtype():
CondFunc('torch.from_numpy',
lambda orig_func, ndarray: orig_func(ndarray.astype('float32')),
lambda orig_func, ndarray: ndarray.dtype == float)
#Broken functions when torch.cuda.is_available is True:
# Broken functions when torch.cuda.is_available is True:
# Pin Memory:
CondFunc('torch.utils.data.dataloader._BaseDataLoaderIter.__init__',
lambda orig_func, *args, **kwargs: ipex_no_cuda(orig_func, *args, **kwargs),
lambda orig_func, *args, **kwargs: True)
#Functions that make compile mad with CondFunc:
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
# Functions that make compile mad with CondFunc:
torch.nn.DataParallel = DummyDataParallel
torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
torch.autocast = ipex_autocast
torch.cat = torch_cat
torch.linalg.solve = linalg_solve
torch.nn.functional.interpolate = interpolate
torch.backends.cuda.sdp_kernel = return_null_context
torch.UntypedStorage.is_cuda = is_cuda
torch.nn.functional.interpolate = interpolate
torch.linalg.solve = linalg_solve
torch.bmm = torch_bmm
torch.cat = torch_cat
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

View File

@@ -1307,19 +1307,19 @@ def load_vae(vae_id, dtype):
def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
max_width, max_height = max_reso
max_area = (max_width // divisible) * (max_height // divisible)
max_area = max_width * max_height
resos = set()
size = int(math.sqrt(max_area)) * divisible
resos.add((size, size))
width = int(math.sqrt(max_area) // divisible) * divisible
resos.add((width, width))
size = min_size
while size <= max_size:
width = size
height = min(max_size, (max_area // (width // divisible)) * divisible)
resos.add((width, height))
resos.add((height, width))
width = min_size
while width <= max_size:
height = min(max_size, int((max_area // width) // divisible) * divisible)
if height >= min_size:
resos.add((width, height))
resos.add((height, width))
# # make additional resos
# if width >= height and width - divisible >= min_size:
@@ -1329,7 +1329,7 @@ def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64)
# resos.add((width, height - divisible))
# resos.add((height - divisible, width))
size += divisible
width += divisible
resos = list(resos)
resos.sort()

View File

@@ -586,6 +586,9 @@ class CrossAttention(nn.Module):
self.use_memory_efficient_attention_mem_eff = False
self.use_sdpa = False
# Attention processor
self.processor = None
def set_use_memory_efficient_attention(self, xformers, mem_eff):
self.use_memory_efficient_attention_xformers = xformers
self.use_memory_efficient_attention_mem_eff = mem_eff
@@ -607,7 +610,28 @@ class CrossAttention(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def forward(self, hidden_states, context=None, mask=None):
def set_processor(self):
return self.processor
def get_processor(self):
return self.processor
def forward(self, hidden_states, context=None, mask=None, **kwargs):
if self.processor is not None:
(
hidden_states,
encoder_hidden_states,
attention_mask,
) = translate_attention_names_from_diffusers(
hidden_states=hidden_states, context=context, mask=mask, **kwargs
)
return self.processor(
attn=self,
hidden_states=hidden_states,
encoder_hidden_states=context,
attention_mask=mask,
**kwargs
)
if self.use_memory_efficient_attention_xformers:
return self.forward_memory_efficient_xformers(hidden_states, context, mask)
if self.use_memory_efficient_attention_mem_eff:
@@ -720,6 +744,21 @@ class CrossAttention(nn.Module):
out = self.to_out[0](out)
return out
def translate_attention_names_from_diffusers(
hidden_states: torch.FloatTensor,
context: Optional[torch.FloatTensor] = None,
mask: Optional[torch.FloatTensor] = None,
# HF naming
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None
):
# translate from hugging face diffusers
context = context if context is not None else encoder_hidden_states
# translate from hugging face diffusers
mask = mask if mask is not None else attention_mask
return hidden_states, context, mask
# feedforward
class GEGLU(nn.Module):
@@ -1350,7 +1389,7 @@ class UNet2DConditionModel(nn.Module):
self.out_channels = OUT_CHANNELS
self.sample_size = sample_size
self.prepare_config()
self.prepare_config(sample_size=sample_size)
# state_dictの書式が変わるのでmoduleの持ち方は変えられない
@@ -1437,8 +1476,8 @@ class UNet2DConditionModel(nn.Module):
self.conv_out = nn.Conv2d(BLOCK_OUT_CHANNELS[0], OUT_CHANNELS, kernel_size=3, padding=1)
# region diffusers compatibility
def prepare_config(self):
self.config = SimpleNamespace()
def prepare_config(self, *args, **kwargs):
self.config = SimpleNamespace(**kwargs)
@property
def dtype(self) -> torch.dtype:

View File

@@ -133,6 +133,12 @@ def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length):
# logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す
logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None)
# temporary workaround for text_projection.weight.weight for Playground-v2
if "text_projection.weight.weight" in new_sd:
print(f"convert_sdxl_text_encoder_2_checkpoint: convert text_projection.weight.weight to text_projection.weight")
new_sd["text_projection.weight"] = new_sd["text_projection.weight.weight"]
del new_sd["text_projection.weight.weight"]
return new_sd, logit_scale
@@ -258,7 +264,7 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty
te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
elif k.startswith("conditioner.embedders.1.model."):
te2_sd[k] = state_dict.pop(k)
# 一部のposition_idsがないモデルへの対応 / add position_ids for some models
if "text_model.embeddings.position_ids" not in te1_sd:
te1_sd["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0)

View File

@@ -51,8 +51,6 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder1, text_encoder2, unet = train_util.transform_models_if_DDP([text_encoder1, text_encoder2, unet])
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

View File

@@ -19,7 +19,7 @@ from typing import (
Tuple,
Union,
)
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs
import gc
import glob
import math
@@ -351,6 +351,7 @@ class BaseSubset:
shuffle_caption: bool,
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
@@ -368,6 +369,7 @@ class BaseSubset:
self.shuffle_caption = shuffle_caption
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
@@ -395,6 +397,7 @@ class DreamBoothSubset(BaseSubset):
shuffle_caption,
caption_separator: str,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -415,6 +418,7 @@ class DreamBoothSubset(BaseSubset):
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -449,6 +453,7 @@ class FineTuningSubset(BaseSubset):
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -469,6 +474,7 @@ class FineTuningSubset(BaseSubset):
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -500,6 +506,7 @@ class ControlNetSubset(BaseSubset):
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -520,6 +527,7 @@ class ControlNetSubset(BaseSubset):
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
@@ -654,15 +662,33 @@ class BaseDataset(torch.utils.data.Dataset):
caption = ""
else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
fixed_tokens = []
flex_tokens = []
if (
hasattr(subset, "keep_tokens_separator")
and subset.keep_tokens_separator
and subset.keep_tokens_separator in caption
):
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
math.floor(
(self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step))
)
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
flex_tokens = flex_tokens[:tokens_len]
def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
@@ -673,12 +699,6 @@ class BaseDataset(torch.utils.data.Dataset):
l.append(token)
return l
fixed_tokens = []
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]
if subset.shuffle_caption:
random.shuffle(flex_tokens)
@@ -1722,7 +1742,9 @@ class ControlNetDataset(BaseDataset):
subset.caption_extension,
subset.num_repeats,
subset.shuffle_caption,
subset.caption_separator,
subset.keep_tokens,
subset.keep_tokens_separator,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
@@ -2665,7 +2687,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser):
"--optimizer_type",
type=str,
default="",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor",
)
# backward compatibility
@@ -2877,6 +2899,16 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="DDP timeout (min, None for default of accelerate) / DDPのタイムアウト分、Noneでaccelerateのデフォルト",
)
parser.add_argument(
"--ddp_gradient_as_bucket_view",
action="store_true",
help="enable gradient_as_bucket_view for DDP / DDPでgradient_as_bucket_viewを有効にする",
)
parser.add_argument(
"--ddp_static_graph",
action="store_true",
help="enable static_graph for DDP / DDPでstatic_graphを有効にする",
)
parser.add_argument(
"--clip_skip",
type=int,
@@ -2979,6 +3011,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
)
parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する")
parser.add_argument(
"--sample_every_n_epochs",
type=int,
@@ -3112,12 +3145,8 @@ def add_dataset_arguments(
):
# dataset common
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
parser.add_argument(
"--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
)
parser.add_argument(
"--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字"
)
parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする")
parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字")
parser.add_argument(
"--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
)
@@ -3133,6 +3162,13 @@ def add_dataset_arguments(
default=0,
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残すトークンはカンマ区切りの各部分を意味する",
)
parser.add_argument(
"--keep_tokens_separator",
type=str,
default="",
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens."
+ " / captionを固定部分と可変部分に分けるためのカスタム区切り文字。この区切り文字より前のトークンはシャッフルされない。指定しない場合、'--keep_tokens'が固定部分のトークン数として使用される。",
)
parser.add_argument(
"--caption_prefix",
type=str,
@@ -3384,7 +3420,7 @@ def resume_from_local_or_hf_if_specified(accelerator, args):
def get_optimizer(args, trainable_params):
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
# "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor"
optimizer_type = args.optimizer_type
if args.use_8bit_adam:
@@ -3488,6 +3524,20 @@ def get_optimizer(args, trainable_params):
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "PagedAdamW".lower():
print(f"use PagedAdamW optimizer | {optimizer_kwargs}")
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
try:
optimizer_class = bnb.optim.PagedAdamW
except AttributeError:
raise AttributeError(
"No PagedAdamW. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedAdamWが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください"
)
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)
elif optimizer_type == "PagedAdamW32bit".lower():
print(f"use 32-bit PagedAdamW optimizer | {optimizer_kwargs}")
try:
@@ -3821,8 +3871,12 @@ def prepare_accelerator(args: argparse.Namespace):
wandb.login(key=args.wandb_api_key)
kwargs_handlers = (
None if args.ddp_timeout is None else [InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout)) if args.ddp_timeout else None,
DistributedDataParallelKwargs(gradient_as_bucket_view=args.ddp_gradient_as_bucket_view, static_graph=args.ddp_static_graph)
if args.ddp_gradient_as_bucket_view or args.ddp_static_graph
else None,
)
kwargs_handlers = list(filter(lambda x: x is not None, kwargs_handlers))
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -3897,17 +3951,6 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
return text_encoder, vae, unet, load_stable_diffusion_format
# TODO remove this function in the future
def transform_if_model_is_DDP(text_encoder, unet, network=None):
# Transform text_encoder, unet and network from DistributedDataParallel
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
def transform_models_if_DDP(models):
# Transform text_encoder, unet and network from DistributedDataParallel
return [model.module if type(model) == DDP else model for model in models if model is not None]
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
# load models for each process
for pi in range(accelerator.state.num_processes):
@@ -3931,8 +3974,6 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
torch.cuda.empty_cache()
accelerator.wait_for_everyone()
text_encoder, unet = transform_if_model_is_DDP(text_encoder, unet)
return text_encoder, vae, unet, load_stable_diffusion_format
@@ -4044,6 +4085,7 @@ def get_hidden_states_sdxl(
text_encoder1: CLIPTextModel,
text_encoder2: CLIPTextModelWithProjection,
weight_dtype: Optional[str] = None,
accelerator: Optional[Accelerator] = None,
):
# input_ids: b,n,77 -> b*n, 77
b_size = input_ids1.size()[0]
@@ -4059,7 +4101,8 @@ def get_hidden_states_sdxl(
hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer
# pool2 = enc_out["text_embeds"]
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
n_size = 1 if max_token_length is None else max_token_length // 75
@@ -4448,13 +4491,119 @@ SCHEDULER_TIMESTEPS = 1000
SCHEDLER_SCHEDULE = "scaled_linear"
def get_my_scheduler(
*,
sample_sampler: str,
v_parameterization: bool,
):
sched_init_args = {}
if sample_sampler == "ddim":
scheduler_cls = DDIMScheduler
elif sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
scheduler_cls = DDPMScheduler
elif sample_sampler == "pndm":
scheduler_cls = PNDMScheduler
elif sample_sampler == "lms" or sample_sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
elif sample_sampler == "euler" or sample_sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
elif sample_sampler == "euler_a" or sample_sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
elif sample_sampler == "dpmsolver" or sample_sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = sample_sampler
elif sample_sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif sample_sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif sample_sampler == "dpm_2" or sample_sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
elif sample_sampler == "dpm_2_a" or sample_sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
else:
scheduler_cls = DDIMScheduler
if v_parameterization:
sched_init_args["prediction_type"] = "v_prediction"
scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
)
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
scheduler.config.clip_sample = True
return scheduler
def sample_images(*args, **kwargs):
return sample_images_common(StableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
def line_to_prompt_dict(line: str) -> dict:
# subset of gen_img_diffusers
prompt_args = line.split(" --")
prompt_dict = {}
prompt_dict["prompt"] = prompt_args[0]
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
prompt_dict["width"] = int(m.group(1))
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
prompt_dict["height"] = int(m.group(1))
continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
prompt_dict["seed"] = int(m.group(1))
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
prompt_dict["scale"] = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
prompt_dict["negative_prompt"] = m.group(1)
continue
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["sample_sampler"] = m.group(1)
continue
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["controlnet_image"] = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
return prompt_dict
def sample_images_common(
pipe_class,
accelerator,
accelerator: Accelerator,
args: argparse.Namespace,
epoch,
steps,
@@ -4469,15 +4618,19 @@ def sample_images_common(
"""
StableDiffusionLongPromptWeightingPipelineの改造版を使うようにしたので、clip skipおよびプロンプトの重みづけに対応した
"""
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
if steps == 0:
if not args.sample_at_first:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
return
if args.sample_every_n_epochs is not None:
# sample_every_n_steps は無視する
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
print(f"\ngenerating sample images at step / サンプル画像生成 ステップ: {steps}")
if not os.path.isfile(args.sample_prompts):
@@ -4487,6 +4640,13 @@ def sample_images_common(
org_vae_device = vae.device # CPUにいるはず
vae.to(device)
# unwrap unet and text_encoder(s)
unet = accelerator.unwrap_model(unet)
if isinstance(text_encoder, (list, tuple)):
text_encoder = [accelerator.unwrap_model(te) for te in text_encoder]
else:
text_encoder = accelerator.unwrap_model(text_encoder)
# read prompts
# with open(args.sample_prompts, "rt", encoding="utf-8") as f:
@@ -4504,56 +4664,19 @@ def sample_images_common(
with open(args.sample_prompts, "r", encoding="utf-8") as f:
prompts = json.load(f)
# schedulerを用意する
sched_init_args = {}
if args.sample_sampler == "ddim":
scheduler_cls = DDIMScheduler
elif args.sample_sampler == "ddpm": # ddpmはおかしくなるのでoptionから外してある
scheduler_cls = DDPMScheduler
elif args.sample_sampler == "pndm":
scheduler_cls = PNDMScheduler
elif args.sample_sampler == "lms" or args.sample_sampler == "k_lms":
scheduler_cls = LMSDiscreteScheduler
elif args.sample_sampler == "euler" or args.sample_sampler == "k_euler":
scheduler_cls = EulerDiscreteScheduler
elif args.sample_sampler == "euler_a" or args.sample_sampler == "k_euler_a":
scheduler_cls = EulerAncestralDiscreteScheduler
elif args.sample_sampler == "dpmsolver" or args.sample_sampler == "dpmsolver++":
scheduler_cls = DPMSolverMultistepScheduler
sched_init_args["algorithm_type"] = args.sample_sampler
elif args.sample_sampler == "dpmsingle":
scheduler_cls = DPMSolverSinglestepScheduler
elif args.sample_sampler == "heun":
scheduler_cls = HeunDiscreteScheduler
elif args.sample_sampler == "dpm_2" or args.sample_sampler == "k_dpm_2":
scheduler_cls = KDPM2DiscreteScheduler
elif args.sample_sampler == "dpm_2_a" or args.sample_sampler == "k_dpm_2_a":
scheduler_cls = KDPM2AncestralDiscreteScheduler
else:
scheduler_cls = DDIMScheduler
if args.v_parameterization:
sched_init_args["prediction_type"] = "v_prediction"
scheduler = scheduler_cls(
num_train_timesteps=SCHEDULER_TIMESTEPS,
beta_start=SCHEDULER_LINEAR_START,
beta_end=SCHEDULER_LINEAR_END,
beta_schedule=SCHEDLER_SCHEDULE,
**sched_init_args,
schedulers: dict = {}
default_scheduler = get_my_scheduler(
sample_sampler=args.sample_sampler,
v_parameterization=args.v_parameterization,
)
# clip_sample=Trueにする
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
# print("set clip_sample to True")
scheduler.config.clip_sample = True
schedulers[args.sample_sampler] = default_scheduler
pipeline = pipe_class(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
scheduler=default_scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
@@ -4569,78 +4692,37 @@ def sample_images_common(
with torch.no_grad():
# with accelerator.autocast():
for i, prompt in enumerate(prompts):
for i, prompt_dict in enumerate(prompts):
if not accelerator.is_main_process:
continue
if isinstance(prompt, dict):
negative_prompt = prompt.get("negative_prompt")
sample_steps = prompt.get("sample_steps", 30)
width = prompt.get("width", 512)
height = prompt.get("height", 512)
scale = prompt.get("scale", 7.5)
seed = prompt.get("seed")
controlnet_image = prompt.get("controlnet_image")
prompt = prompt.get("prompt")
else:
# prompt = prompt.strip()
# if len(prompt) == 0 or prompt[0] == "#":
# continue
if isinstance(prompt_dict, str):
prompt_dict = line_to_prompt_dict(prompt_dict)
# subset of gen_img_diffusers
prompt_args = prompt.split(" --")
prompt = prompt_args[0]
negative_prompt = None
sample_steps = 30
width = height = 512
scale = 7.5
seed = None
controlnet_image = None
for parg in prompt_args:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
continue
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
continue
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
if m:
seed = int(m.group(1))
continue
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
if m: # steps
sample_steps = max(1, min(1000, int(m.group(1))))
continue
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
if m: # scale
scale = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
continue
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
if m: # negative prompt
controlnet_image = m.group(1)
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
assert isinstance(prompt_dict, dict)
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 30)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 7.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
scheduler = schedulers.get(sampler_name)
if scheduler is None:
scheduler = get_my_scheduler(
sample_sampler=sampler_name,
v_parameterization=args.v_parameterization,
)
schedulers[sampler_name] = scheduler
pipeline.scheduler = scheduler
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
@@ -4658,6 +4740,9 @@ def sample_images_common(
print(f"width: {width}")
print(f"sample_steps: {sample_steps}")
print(f"scale: {scale}")
print(f"sample_sampler: {sampler_name}")
if seed is not None:
print(f"seed: {seed}")
with accelerator.autocast():
latents = pipeline(
prompt=prompt,

View File

@@ -397,13 +397,13 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい
if train_unet:
unet = accelerator.prepare(unet)
(unet,) = train_util.transform_models_if_DDP([unet])
if train_text_encoder1:
# freeze last layer and final_layer_norm in te1 since we use the output of the penultimate layer
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)
text_encoder1 = accelerator.prepare(text_encoder1)
(text_encoder1,) = train_util.transform_models_if_DDP([text_encoder1])
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
(text_encoder2,) = train_util.transform_models_if_DDP([text_encoder2])
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
@@ -461,6 +461,11 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
sdxl_train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, [tokenizer1, tokenizer2], [text_encoder1, text_encoder2], unet
)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -482,7 +487,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:
@@ -503,6 +508,7 @@ def train(args):
# else:
input_ids1 = input_ids1.to(accelerator.device)
input_ids2 = input_ids2.to(accelerator.device)
# unwrap_model is fine for models not wrapped by accelerator
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
args.max_token_length,
input_ids1,
@@ -512,6 +518,7 @@ def train(args):
text_encoder1,
text_encoder2,
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)

View File

@@ -283,9 +283,6 @@ def train(args):
# acceleratorがなんかよろしくやってくれるらしい
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare (train_network here only)
unet = train_util.transform_models_if_DDP([unet])[0]
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
@@ -397,7 +394,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:

View File

@@ -254,9 +254,6 @@ def train(args):
)
network: control_net_lllite.ControlNetLLLite
# transform DDP after prepare (train_network here only)
unet, network = train_util.transform_models_if_DDP([unet, network])
if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
else:
@@ -366,7 +363,7 @@ def train(args):
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * sdxl_model_util.VAE_SCALE_FACTOR
if "text_encoder_outputs1_list" not in batch or batch["text_encoder_outputs1_list"] is None:

View File

@@ -1,9 +1,12 @@
import argparse
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -123,6 +126,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
else:
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)

View File

@@ -64,6 +64,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
text_encoders[0],
text_encoders[1],
None if not args.full_fp16 else weight_dtype,
accelerator=accelerator,
)
return encoder_hidden_states1, encoder_hidden_states2, pool2

View File

@@ -11,10 +11,13 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -335,7 +338,9 @@ def train(args):
init_kwargs = {}
if args.log_tracker_config is not None:
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
accelerator.init_trackers(
"controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs
)
loss_recorder = train_util.LossRecorder()
del train_dataset_group
@@ -371,6 +376,11 @@ def train(args):
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
train_util.sample_images(
accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, controlnet=controlnet
)
# training loop
for epoch in range(num_train_epochs):
if is_main_process:

View File

@@ -112,6 +112,7 @@ def train(args):
# mixed precisionに対応した型を用意しておき適宜castする
weight_dtype, save_dtype = train_util.prepare_dtype(args)
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
# モデルを読み込む
text_encoder, vae, unet, load_stable_diffusion_format = train_util.load_target_model(args, weight_dtype, accelerator)
@@ -136,7 +137,7 @@ def train(args):
# 学習を準備する
if cache_latents:
vae.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=vae_dtype)
vae.requires_grad_(False)
vae.eval()
with torch.no_grad():
@@ -225,9 +226,6 @@ def train(args):
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
if not train_text_encoder:
text_encoder.to(accelerator.device, dtype=weight_dtype) # to avoid 'cpu' vs 'cuda' error
@@ -274,6 +272,9 @@ def train(args):
init_kwargs = toml.load(args.log_tracker_config)
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs)
# For --sample_at_first
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -482,6 +483,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="steps to stop text encoder training, -1 for no training / Text Encoderの学習を止めるステップ数、-1で最初から学習しない",
)
parser.add_argument(
"--no_half_vae",
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
return parser

View File

@@ -12,6 +12,7 @@ import toml
from tqdm import tqdm
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
try:
import intel_extension_for_pytorch as ipex
@@ -127,6 +128,11 @@ class NetworkTrainer:
noise_pred = unet(noisy_latents, timesteps, text_conds).sample
return noise_pred
def all_reduce_network(self, accelerator, network):
for param in network.parameters():
if param.grad is not None:
param.grad = accelerator.reduce(param.grad, reduction="mean")
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
@@ -390,46 +396,20 @@ class NetworkTrainer:
# acceleratorがなんかよろしくやってくれるらしい
# TODO めちゃくちゃ冗長なのでコードを整理する
if train_unet and train_text_encoder:
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
unet, t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
else:
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = accelerator.prepare(text_encoder)
text_encoders = [text_encoder]
elif train_unet:
unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, network, optimizer, train_dataloader, lr_scheduler
)
else:
for t_enc in text_encoders:
t_enc.to(accelerator.device, dtype=weight_dtype)
elif train_text_encoder:
if len(text_encoders) > 1:
t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], network, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoders = [t_enc1, t_enc2]
del t_enc1, t_enc2
else:
text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, network, optimizer, train_dataloader, lr_scheduler
)
text_encoders = [text_encoder]
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
else:
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
network, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare (train_network here only)
text_encoders = train_util.transform_models_if_DDP(text_encoders)
unet, network = train_util.transform_models_if_DDP([unet, network])
network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(network, optimizer, train_dataloader, lr_scheduler)
if args.gradient_checkpointing:
# according to TI example in Diffusers, train is required
@@ -451,7 +431,7 @@ class NetworkTrainer:
del t_enc
network.prepare_grad_etc(text_encoder, unet)
accelerator.unwrap_model(network).prepare_grad_etc(text_encoder, unet)
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
vae.requires_grad_(False)
@@ -714,8 +694,8 @@ class NetworkTrainer:
del train_dataset_group
# callback for step start
if hasattr(network, "on_step_start"):
on_step_start = network.on_step_start
if hasattr(accelerator.unwrap_model(network), "on_step_start"):
on_step_start = accelerator.unwrap_model(network).on_step_start
else:
on_step_start = lambda *args, **kwargs: None
@@ -743,6 +723,9 @@ class NetworkTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
@@ -750,7 +733,7 @@ class NetworkTrainer:
metadata["ss_epoch"] = str(epoch + 1)
network.on_epoch_start(text_encoder, unet)
accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
@@ -767,7 +750,7 @@ class NetworkTrainer:
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = torch.nan_to_num(latents, 0, out=latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
@@ -823,8 +806,9 @@ class NetworkTrainer:
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
self.all_reduce_network(accelerator, network) # sync DDP grad manually
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
params_to_clip = network.get_trainable_params()
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
@@ -832,7 +816,7 @@ class NetworkTrainer:
optimizer.zero_grad(set_to_none=True)
if args.scale_weight_norms:
keys_scaled, mean_norm, maximum_norm = network.apply_max_norm_regularization(
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}

View File

@@ -7,10 +7,13 @@ import toml
from tqdm import tqdm
import torch
try:
import intel_extension_for_pytorch as ipex
if torch.xpu.is_available():
from library.ipex import ipex_init
ipex_init()
except Exception:
pass
@@ -415,15 +418,11 @@ class TextualInversionTrainer:
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder_or_list, unet = train_util.transform_if_model_is_DDP(text_encoder_or_list, unet)
elif len(text_encoders) == 2:
text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder1, text_encoder2, unet = train_util.transform_if_model_is_DDP(text_encoder1, text_encoder2, unet)
text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
@@ -529,6 +528,20 @@ class TextualInversionTrainer:
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
os.remove(old_ckpt_file)
# For --sample_at_first
self.sample_images(
accelerator,
args,
0,
global_step,
accelerator.device,
vae,
tokenizer_or_list,
text_encoder_or_list,
unet,
prompt_replacement,
)
# training loop
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")

View File

@@ -333,9 +333,6 @@ def train(args):
text_encoder, optimizer, train_dataloader, lr_scheduler
)
# transform DDP after prepare
text_encoder, unet = train_util.transform_if_model_is_DDP(text_encoder, unet)
index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
# print(len(index_no_updates), torch.sum(index_no_updates))
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.detach().clone()