diff --git a/train_network.py b/train_network.py index 3e8f4e7d..710055e0 100644 --- a/train_network.py +++ b/train_network.py @@ -150,7 +150,9 @@ def train(args): # モデルを読み込む text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype) - + # unnecessary, but work on low-ram device + text_encoder.to("cuda") + unet.to("cuda") # モデルに xformers とか memory efficient attention を組み込む train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)