mirror of https://github.com/kohya-ss/sd-scripts.git
Make it work with PyTorch 1.12
@@ -1104,9 +1104,9 @@ class BaseDataset(torch.utils.data.Dataset):
                     # crop_ltrb[2] is right, so target_size[0] - crop_ltrb[2] is left in flipped image
                     crop_left_top = (target_size[0] - crop_ltrb[2], crop_ltrb[1])

-            original_sizes_hw.append((original_size[1], original_size[0]))
-            crop_top_lefts.append((crop_left_top[1], crop_left_top[0]))
-            target_sizes_hw.append((target_size[1], target_size[0]))
+            original_sizes_hw.append((int(original_size[1]), int(original_size[0])))
+            crop_top_lefts.append((int(crop_left_top[1]), int(crop_left_top[0])))
+            target_sizes_hw.append((int(target_size[1]), int(target_size[0])))
             flippeds.append(flipped)

             # process the caption and text encoder output
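The only functional change in this hunk is wrapping each size component in int(). A hedged sketch of the failure mode this likely guards against, assuming the sizes are read out of a numpy array (for example a bucket-resolution table), so the tuple elements are numpy integer scalars rather than Python ints; `bucket_resos` below is a hypothetical stand-in, not a name from the diff:

import json
import numpy as np

bucket_resos = np.array([[512, 768], [640, 640]])  # hypothetical bucket table
target_size = (bucket_resos[0][0], bucket_resos[0][1])
print(type(target_size[0]))  # numpy.int64 on most platforms, not int

try:
    json.dumps({"target_size_wh": target_size})
except TypeError as err:
    print("without int():", err)  # numpy scalars are not JSON serializable

safe_size = (int(target_size[0]), int(target_size[1]))
print(json.dumps({"target_size_wh": safe_size}))  # plain ints always serialize

Normalizing to plain int at the point where the tuples are built also keeps batch collation and tensor construction behaving the same across PyTorch versions.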
@@ -146,7 +146,8 @@ if __name__ == "__main__":
     text_model2.eval()

     unet.set_use_memory_efficient_attention(True, False)
-    vae.set_use_memory_efficient_attention_xformers(True)
+    if torch.__version__ >= "2.0.0":  # usable only with xformers builds that support PyTorch 2.0.0 or later
+        vae.set_use_memory_efficient_attention_xformers(True)

     # Tokenizers
     tokenizer1 = CLIPTokenizer.from_pretrained(text_encoder_1_name)
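Every remaining hunk repeats this same two-line guard. A minimal sketch (not part of the commit) that factors it into a helper; note that torch.__version__ is a string, so the >= "2.0.0" comparison in the diff is lexicographic. That happens to work for distinguishing 1.x from 2.x, but parsing the version is the robust form:

import torch
from packaging import version  # available wherever pip is installed

def enable_vae_xformers(vae, enable: bool) -> None:
    # per the comment in the diff, the VAE's xformers attention hook is only
    # usable with xformers builds targeting PyTorch 2.0.0 or later
    if version.parse(torch.__version__) >= version.parse("2.0.0"):
        vae.set_use_memory_efficient_attention_xformers(enable)

Called as enable_vae_xformers(vae, args.xformers), this would replace each guarded pair below.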
@@ -174,7 +174,8 @@ def train(args):
         # the Windows build of xformers sometimes cannot train in float, so a setting that disables xformers must also be available
         accelerator.print("Disable Diffusers' xformers")
         train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        vae.set_use_memory_efficient_attention_xformers(args.xformers)
+        if torch.__version__ >= "2.0.0":  # usable only with xformers builds that support PyTorch 2.0.0 or later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)

     # prepare for training
     if cache_latents:
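replace_unet_modules takes three backend flags (mem_eff_attn, xformers, sdpa). Its internals are not shown in this diff; the sketch below only illustrates the kind of dispatch those flags imply, with a plain-attention fallback that still runs on PyTorch 1.12:

import torch
import torch.nn.functional as F

def attention(q, k, v, use_sdpa: bool) -> torch.Tensor:
    # q, k, v: (batch, heads, seq_len, head_dim)
    if use_sdpa and hasattr(F, "scaled_dot_product_attention"):
        return F.scaled_dot_product_attention(q, k, v)  # fused kernel, PyTorch >= 2.0
    scale = q.shape[-1] ** -0.5
    weights = (q @ k.transpose(-2, -1) * scale).softmax(dim=-1)
    return weights @ v  # naive fallback for older PyTorch such as 1.12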
@@ -104,7 +104,8 @@ def cache_to_disk(args: argparse.Namespace) -> None:
     else:
         _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator)

-    vae.set_use_memory_efficient_attention_xformers(args.xformers)
+    if torch.__version__ >= "2.0.0":  # usable only with xformers builds that support PyTorch 2.0.0 or later
+        vae.set_use_memory_efficient_attention_xformers(args.xformers)
     vae.to(accelerator.device, dtype=vae_dtype)
     vae.requires_grad_(False)
     vae.eval()
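The three unguarded context lines at the end of this hunk put the VAE into inference-only mode before latents are cached: move it once in the caching dtype, drop gradient tracking, and fix eval-mode behavior. A hedged sketch of the encode step this enables (the diffusers AutoencoderKL API is assumed, since that is what these scripts load; the function and argument names are hypothetical):

import torch

@torch.no_grad()  # the VAE is frozen, so skip autograd bookkeeping entirely
def encode_to_latents(vae, images: torch.Tensor) -> torch.Tensor:
    # images: float tensor in [-1, 1], shape (B, 3, H, W), on the VAE's device/dtype
    latents = vae.encode(images).latent_dist.sample()
    return latents * vae.config.scaling_factor  # SD convention, 0.18215 for SD VAEs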
@@ -217,7 +217,8 @@ class NetworkTrainer:

         # incorporate xformers or memory efficient attention into the model
         train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        vae.set_use_memory_efficient_attention_xformers(args.xformers)
+        if torch.__version__ >= "2.0.0":  # usable only with xformers builds that support PyTorch 2.0.0 or later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)

         # load the model for incremental (difference) training
         sys.path.append(os.path.dirname(__file__))
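The sys.path.append context line at the end of this hunk is what lets the trainer import network code (LoRA and friends) by dotted name. A sketch of that dynamic-import pattern, assuming a module argument like sd-scripts' --network_module networks.lora; the helper name is hypothetical:

import importlib
import os
import sys

def load_network_module(name: str):
    # make sibling packages next to this script importable, then load by dotted name
    sys.path.append(os.path.dirname(__file__))
    return importlib.import_module(name)  # e.g. load_network_module("networks.lora")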
@@ -343,7 +343,8 @@ class TextualInversionTrainer:

         # incorporate xformers or memory efficient attention into the model
         train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)
-        vae.set_use_memory_efficient_attention_xformers(args.xformers)
+        if torch.__version__ >= "2.0.0":  # usable only with xformers builds that support PyTorch 2.0.0 or later
+            vae.set_use_memory_efficient_attention_xformers(args.xformers)

         # prepare for training
         if cache_latents: