From acf16c063a620c0d775e5c471e0384f5e54a6896 Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Thu, 20 Jul 2023 21:41:16 +0900
Subject: [PATCH] make to work with PyTorch 1.12

---
 library/train_util.py      | 6 +++---
 sdxl_minimal_inference.py  | 3 ++-
 sdxl_train.py              | 3 ++-
 tools/cache_latents.py     | 3 ++-
 train_network.py           | 3 ++-
 train_textual_inversion.py | 3 ++-
 6 files changed, 13 insertions(+), 8 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index f5d5288b..5f6e7d48 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1104,9 +1104,9 @@ class BaseDataset(torch.utils.data.Dataset):
                 # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
                 crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

-            original_sizes_hw.append((original_size[1], original_size[0]))
-            crop_top_lefts.append((crop_left_top[1], crop_left_top[0]))
-            target_sizes_hw.append((target_size[1], target_size[0]))
+            original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
+            crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
+            target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
             flippeds.append(flipped)

             # process the caption and the text encoder output
diff --git a/sdxl_minimal_inference.py b/sdxl_minimal_inference.py
index d441877d..1a950902 100644
--- a/sdxl_minimal_inference.py
+++ b/sdxl_minimal_inference.py
@@ -146,7 +146,8 @@ if __name__ == "__main__":
     text_model2.eval()

     unet.set_use_memory_efficient_attention(True, False)
-    vae.set_use_memory_efficient_attention_xformers(True)
+    if torch.__version__ >= "2.0.0":  # usable only with an xformers build for PyTorch 2.0.0 or later
+        vae.set_use_memory_efficient_attention_xformers(True)

     # Tokenizers
     tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
diff --git a/sdxl_train.py b/sdxl_train.py
index 4dbed79a..7e3a8416 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -174,7 +174,8 @@ def train(args):
         # the Windows build of xformers may fail to train in float, so an option to disable xformers must remain available
         accelerator.print("Disable Diffusers' xformers")
         train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        vae.set_use_memory_efficient_attention_xformers(args.xformers)
+        if torch.__version__ >= "2.0.0":  # usable only with an xformers build for PyTorch 2.0.0 or later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)

     # prepare training
     if cache_latents:
diff --git a/tools/cache_latents.py b/tools/cache_latents.py
index d403d559..b6991ac1 100644
--- a/tools/cache_latents.py
+++ b/tools/cache_latents.py
@@ -104,7 +104,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
     else:
         _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)

-    vae.set_use_memory_efficient_attention_xformers(args.xformers)
+    if torch.__version__ >= "2.0.0":  # usable only with an xformers build for PyTorch 2.0.0 or later
+        vae.set_use_memory_efficient_attention_xformers(args.xformers)
     vae.to(accelerator.device, dtype=vae_dtype)
     vae.requires_grad_(False)
     vae.eval()
diff --git a/train_network.py b/train_network.py
index a55339c4..310f7506 100644
--- a/train_network.py
+++ b/train_network.py
@@ -217,7 +217,8 @@ class NetworkTrainer:
        # add xformers or memory efficient attention to the model
         train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        vae.set_use_memory_efficient_attention_xformers(args.xformers)
+        if torch.__version__ >= "2.0.0":  # usable only with an xformers build for PyTorch 2.0.0 or later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)

         # load the model for additional (differential) training
         sys.path.append(os.path.dirname(__file__))
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index e227a13b..265b244b 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -343,7 +343,8 @@ class TextualInversionTrainer:
         # add xformers or memory efficient attention to the model
         train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        vae.set_use_memory_efficient_attention_xformers(args.xformers)
+        if torch.__version__ >= "2.0.0":  # usable only with an xformers build for PyTorch 2.0.0 or later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)

         # prepare training
         if cache_latents:
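
The same change is applied in every training and inference entry point touched by this patch: enabling xformers attention on the VAE is now gated on the installed PyTorch version. As a rough illustration of the pattern only (a minimal sketch, not code from the repository; the helper name maybe_enable_vae_xformers is hypothetical), the guarded call amounts to:

import torch

def maybe_enable_vae_xformers(vae, use_xformers: bool) -> None:
    # Route the VAE through Diffusers' xformers attention only when the installed
    # PyTorch is 2.0.0 or newer; on PyTorch 1.12 the call is skipped entirely,
    # mirroring the version guard added in each file above.
    if use_xformers and torch.__version__ >= "2.0.0":
        vae.set_use_memory_efficient_attention_xformers(True)

Note that the patch compares torch.__version__ as a plain string; that is enough to tell 1.x builds from 2.x builds, though packaging.version.parse would be a stricter way to compare versions in general.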