mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add no_half_vae option
This commit is contained in:
@@ -173,6 +173,7 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
# mixed precisionに対応した型を用意しておき適宜castする
|
# mixed precisionに対応した型を用意しておき適宜castする
|
||||||
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
weight_dtype, save_dtype = train_util.prepare_dtype(args)
|
||||||
|
vae_dtype = torch.float32 if args.no_half_vae else weight_dtype
|
||||||
|
|
||||||
# モデルを読み込む
|
# モデルを読み込む
|
||||||
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
model_version, text_encoder_or_list, vae, unet = self.load_target_model(args, weight_dtype, accelerator)
|
||||||
@@ -351,7 +352,7 @@ class TextualInversionTrainer:
|
|||||||
|
|
||||||
# 学習を準備する
|
# 学習を準備する
|
||||||
if cache_latents:
|
if cache_latents:
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -447,10 +448,10 @@ class TextualInversionTrainer:
|
|||||||
else:
|
else:
|
||||||
unet.eval()
|
unet.eval()
|
||||||
|
|
||||||
if not cache_latents:
|
if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する
|
||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
vae.to(accelerator.device, dtype=weight_dtype)
|
vae.to(accelerator.device, dtype=vae_dtype)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
# 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
@@ -529,7 +530,7 @@ class TextualInversionTrainer:
|
|||||||
latents = batch["latents"].to(accelerator.device)
|
latents = batch["latents"].to(accelerator.device)
|
||||||
else:
|
else:
|
||||||
# latentに変換
|
# latentに変換
|
||||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
|
||||||
latents = latents * self.vae_scale_factor
|
latents = latents * self.vae_scale_factor
|
||||||
|
|
||||||
# Get the text embedding for conditioning
|
# Get the text embedding for conditioning
|
||||||
@@ -744,6 +745,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
|
help="ignore caption and use default templates for stype / キャプションは使わずデフォルトのスタイル用テンプレートで学習する",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no_half_vae",
|
||||||
|
action="store_true",
|
||||||
|
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user