mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
support SD3.5L, fix final saving
This commit is contained in:
35
sd3_train.py
35
sd3_train.py
@@ -321,7 +321,7 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# now we can delete Text Encoders to free memory
|
||||
if args.use_t5xxl_cache_only:
|
||||
if not args.use_t5xxl_cache_only:
|
||||
clip_l = None
|
||||
clip_g = None
|
||||
t5xxl = None
|
||||
@@ -330,6 +330,7 @@ def train(args):
|
||||
|
||||
# load VAE for caching latents
|
||||
if sd3_state_dict is None:
|
||||
logger.info(f"load state dict for MMDiT and VAE from {args.pretrained_model_name_or_path}")
|
||||
sd3_state_dict = utils.load_safetensors(
|
||||
args.pretrained_model_name_or_path, "cpu", args.disable_mmap_load_safetensors, model_dtype
|
||||
)
|
||||
@@ -360,11 +361,6 @@ def train(args):
|
||||
# attn_mode == "torch"
|
||||
# ), f"attn_mode {attn_mode} is not supported yet. Please use `--sdpa` instead of `--xformers`. / attn_mode {attn_mode} はサポートされていません。`--xformers`の代わりに`--sdpa`を使ってください。"
|
||||
|
||||
# SD3 state dict may contain multiple models, so we need to load it and extract one by one. annoying.
|
||||
logger.info(f"Loading SD3 models from {args.pretrained_model_name_or_path}")
|
||||
device_to_load = accelerator.device if args.lowram else "cpu"
|
||||
sd3_state_dict = utils.load_safetensors(args.pretrained_model_name_or_path, device_to_load, args.disable_mmap_load_safetensors)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
mmdit.enable_gradient_checkpointing()
|
||||
|
||||
@@ -555,7 +551,7 @@ def train(args):
|
||||
# clip_l.text_model.encoder.layers[-1].requires_grad_(False)
|
||||
# clip_l.text_model.final_layer_norm.requires_grad_(False)
|
||||
|
||||
# TextEncoderの出力をキャッシュするときには、すでに出力を取得済みなのでCPUへ移動する
|
||||
# move Text Encoders to GPU if not caching outputs
|
||||
if not args.cache_text_encoder_outputs:
|
||||
# make sure Text Encoders are on GPU
|
||||
# TODO support CPU for text encoders
|
||||
@@ -817,6 +813,13 @@ def train(args):
|
||||
# log empty object to commit the sample images to wandb
|
||||
accelerator.log({}, step=0)
|
||||
|
||||
# show model device and dtype
|
||||
logger.info(f"mmdit device: {mmdit.device}, dtype: {mmdit.dtype}" if mmdit else "mmdit is None")
|
||||
logger.info(f"clip_l device: {clip_l.device}, dtype: {clip_l.dtype}" if clip_l else "clip_l is None")
|
||||
logger.info(f"clip_g device: {clip_g.device}, dtype: {clip_g.dtype}" if clip_g else "clip_g is None")
|
||||
logger.info(f"t5xxl device: {t5xxl.device}, dtype: {t5xxl.dtype}" if t5xxl else "t5xxl is None")
|
||||
logger.info(f"vae device: {vae.device}, dtype: {vae.dtype}" if vae is not None else "vae is None")
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
epoch = 0 # avoid error when max_train_steps is 0
|
||||
for epoch in range(num_train_epochs):
|
||||
@@ -1055,10 +1058,10 @@ def train(args):
|
||||
save_dtype,
|
||||
epoch,
|
||||
global_step,
|
||||
accelerator.unwrap_model(clip_l) if train_clip else None,
|
||||
accelerator.unwrap_model(clip_g) if train_clip else None,
|
||||
accelerator.unwrap_model(t5xxl) if train_t5xxl else None,
|
||||
accelerator.unwrap_model(mmdit) if train_mmdit else None,
|
||||
clip_l if train_clip else None,
|
||||
clip_g if train_clip else None,
|
||||
t5xxl if train_t5xxl else None,
|
||||
mmdit if train_mmdit else None,
|
||||
vae,
|
||||
)
|
||||
logger.info("model saved.")
|
||||
@@ -1153,6 +1156,16 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--blocks_to_swap",
|
||||
type=int,
|
||||
default=None,
|
||||
help="[EXPERIMENTAL] "
|
||||
"Sets the number of blocks (~640MB) to swap during the forward and backward passes."
|
||||
"Increasing this number lowers the overall VRAM used during training at the expense of training speed (s/it)."
|
||||
" / 順伝播および逆伝播中にスワップするブロック(約640MB)の数を設定します。"
|
||||
"この数を増やすと、トレーニング中のVRAM使用量が減りますが、トレーニング速度(s/it)も低下します。",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_last_block_to_freeze",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user