support SD3.5L, fix final saving

This commit is contained in:
Kohya S
2024-10-24 21:28:42 +09:00
parent e3c43bda49
commit 0286114bd2

View File

@@ -321,7 +321,7 @@ def train(args):
accelerator.wait_for_everyone()
# now we can delete Text Encoders to free memory
if args.use_t5xxl_cache_only:
if not args.use_t5xxl_cache_only:
clip_l = None
clip_g = None
t5xxl = None
@@ -330,6 +330,7 @@ def train(args):
# load VAE for caching latents
if sd3_state_dict is None:
logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}")
sd3_state_dict = utils.load_safetensors(
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
)
@@ -360,11 +361,6 @@ def train(args):
# attn_mode == "torch"
# ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
# SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying.
logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}")
device_to_load = accelerator.device if args.lowram else "cpu"
sd3_state_dict = utils.load_safetensors(args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors)
if args.gradient_checkpointing:
mmdit.enable_gradient_checkpointing()
@@ -555,7 +551,7 @@ def train(args):
# clip_l.text_model.encoder.layers[-1].requires_grad_(False)
# clip_l.text_model.final_layer_norm.requires_grad_(False)
# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
# move Text Encoders to GPU if not caching outputs
if not args.cache_text_encoder_outputs:
# make sure Text Encoders are on GPU
# TODO support CPU for text encoders
@@ -817,6 +813,13 @@ def train(args):
# log empty object to commit the sample images to wandb
accelerator.log({}, step=0)
# show model device and dtype
logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None")
logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None")
logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None")
logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None")
logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None")
loss_recorder = train_util.LossRecorder()
epoch = 0 # avoid error when max_train_steps is 0
for epoch in range(num_train_epochs):
@@ -1055,10 +1058,10 @@ def train(args):
save_dtype,
epoch,
global_step,
accelerator.unwrap_model(clip_l) if train_clip else None,
accelerator.unwrap_model(clip_g) if train_clip else None,
accelerator.unwrap_model(t5xxl) if train_t5xxl else None,
accelerator.unwrap_model(mmdit) if train_mmdit else None,
clip_l if train_clip else None,
clip_g if train_clip else None,
t5xxl if train_t5xxl else None,
mmdit if train_mmdit else None,
vae,
)
logger.info("model saved.")
@@ -1153,6 +1156,16 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
)
parser.add_argument(
"--blocks_to_swap",
type=int,
default=None,
help="[EXPERIMENTAL] "
"Sets the number of blocks (~640MB) to swap during the forward and backward passes."
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
" / 順伝播および逆伝播中にスワップするブロック約640MBの数を設定します。"
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度s/itも低下します。",
)
parser.add_argument(
"--num_last_block_to_freeze",
type=int,