make each script consistent, fix to work w/o DeepSpeed

Kohya S
2024-03-25 22:28:46 +09:00
parent c24422fb9d
commit a2b8531627
9 changed files with 30 additions and 18 deletions


@@ -471,8 +471,7 @@ class NetworkTrainer:
 vae.to(accelerator.device, dtype=vae_dtype)

 # experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
-if args.full_fp16 and not args.deepspeed:
-    # During DeepSpeed training, accelerate does not handle fp16/bf16 mixed precision directly via the scaler; the DeepSpeed engine does.
+if args.full_fp16:
     train_util.patch_accelerator_for_fp16_training(accelerator)

 # resume training
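
In effect, this hunk drops the "and not args.deepspeed" guard (and the DeepSpeed comment with it): the fp16 patch is now applied whenever full_fp16 is requested, with or without DeepSpeed. A minimal sketch of the gating change, assuming only that args carries boolean full_fp16 and deepspeed flags (SimpleNamespace stands in for the parsed args; this is an illustration, not the trainer's setup code):

from types import SimpleNamespace

def should_patch_fp16(args) -> bool:
    # before this commit: args.full_fp16 and not args.deepspeed
    # after this commit: patch whenever full fp16 training is requested
    return args.full_fp16

assert should_patch_fp16(SimpleNamespace(full_fp16=True, deepspeed=True))
assert should_patch_fp16(SimpleNamespace(full_fp16=True, deepspeed=False))
assert not should_patch_fp16(SimpleNamespace(full_fp16=False, deepspeed=True))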
@@ -781,11 +780,11 @@ class NetworkTrainer:
 on_step_start(text_encoder, unet)

 if "latents" in batch and batch["latents"] is not None:
-    latents = batch["latents"].to(accelerator.device)
+    latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
 else:
     with torch.no_grad():
         # convert images to latents
-        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample()
+        latents = vae.encode(batch["images"].to(dtype=vae_dtype)).latent_dist.sample().to(dtype=weight_dtype)

 # if latents contain NaN, show a warning and replace them with zeros
 if torch.any(torch.isnan(latents)):
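
The second hunk adds a .to(dtype=weight_dtype) cast in both branches, so latents always match the precision the network is trained in, whether they come from the latent cache or fresh from the VAE. A minimal, self-contained sketch of the dtype flow, with assumed values (fp32 VAE, fp16 training weights) rather than the trainer's real configuration:

import torch

weight_dtype = torch.float16  # assumed training precision (mixed_precision "fp16")
vae_dtype = torch.float32     # assumed: VAE kept in fp32 for numerical stability

# latents leave the VAE (or the latent cache) in vae_dtype...
latents = torch.randn(1, 4, 64, 64, dtype=vae_dtype)

# ...and are cast before entering the network, mirroring the
# .to(dtype=weight_dtype) added in this hunk
latents = latents.to(dtype=weight_dtype)
assert latents.dtype == weight_dtype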