mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
add save_every_n_steps option
This commit is contained in:
80
train_db.py
80
train_db.py
@@ -25,6 +25,7 @@ from library.config_util import (
|
||||
import library.custom_train_functions as custom_train_functions
|
||||
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
|
||||
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
train_util.prepare_dataset_args(args, False)
|
||||
@@ -273,18 +274,19 @@ def train(args):
|
||||
# Get the text embedding for conditioning
|
||||
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
|
||||
@@ -335,6 +337,27 @@ def train(args):
|
||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
False,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss": current_loss, "lr": float(lr_scheduler.get_last_lr()[0])}
|
||||
@@ -364,21 +387,24 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if args.save_every_n_epochs is not None:
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end(
|
||||
args,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
if accelerator.is_main_process:
|
||||
# checking for saving is in util
|
||||
src_path = src_stable_diffusion_ckpt if save_stable_diffusion_format else src_diffusers_model_path
|
||||
train_util.save_sd_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
src_path,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
unwrap_model(text_encoder),
|
||||
unwrap_model(unet),
|
||||
vae,
|
||||
)
|
||||
|
||||
train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
@@ -389,7 +415,7 @@ def train(args):
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
if args.save_state:
|
||||
if args.save_state and is_main_process:
|
||||
train_util.save_state_on_train_end(args, accelerator)
|
||||
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
@@ -434,4 +460,4 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
args = train_util.read_config_from_file(args, parser)
|
||||
|
||||
train(args)
|
||||
train(args)
|
||||
|
||||
Reference in New Issue
Block a user