mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 06:28:48 +00:00
add save_state_on_train end, fix reg imgs repeats
This commit is contained in:
@@ -76,7 +76,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.flip_aug = flip_aug
|
||||
self.color_aug = color_aug
|
||||
self.debug_dataset = debug_dataset
|
||||
self.padding_disabled = False
|
||||
self.token_padding_disabled = False
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
|
||||
|
||||
@@ -102,8 +102,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.image_data: dict[str, ImageInfo] = {}
|
||||
|
||||
def disable_padding(self):
|
||||
self.padding_disabled = True
|
||||
def disable_token_padding(self):
|
||||
self.token_padding_disabled = True
|
||||
|
||||
def process_caption(self, caption):
|
||||
if self.shuffle_caption:
|
||||
@@ -412,13 +412,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
caption = self.process_caption(image_info.caption)
|
||||
captions.append(caption)
|
||||
if not self.padding_disabled: # this option might be omitted in future
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
input_ids_list.append(self.get_input_ids(caption))
|
||||
|
||||
example = {}
|
||||
example['loss_weights'] = torch.FloatTensor(loss_weights)
|
||||
|
||||
if self.padding_disabled:
|
||||
if self.token_padding_disabled:
|
||||
# padding=True means pad in the batch
|
||||
example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
|
||||
else:
|
||||
@@ -540,13 +540,20 @@ class DreamBoothDataset(BaseDataset):
|
||||
if num_reg_images == 0:
|
||||
print("no regularization images / 正則化画像が見つかりませんでした")
|
||||
else:
|
||||
# num_repeatsを計算する:どうせ大した数ではないのでループで処理する
|
||||
n = 0
|
||||
first_loop = True
|
||||
while n < num_train_images:
|
||||
for info in reg_infos:
|
||||
self.register_image(info)
|
||||
n += info.num_repeats
|
||||
if n >= num_train_images: # reg画像にnum_repeats>1のときはまずありえないので考慮しない
|
||||
if first_loop:
|
||||
self.register_image(info)
|
||||
n += info.num_repeats
|
||||
else:
|
||||
info.num_repeats += 1
|
||||
n += 1
|
||||
if n >= num_train_images:
|
||||
break
|
||||
first_loop = False
|
||||
|
||||
self.num_reg_images = num_reg_images
|
||||
|
||||
@@ -1253,7 +1260,6 @@ def save_on_epoch_end(args: argparse.Namespace, save_func, remove_old_func, epoc
|
||||
saving = epoch_no % args.save_every_n_epochs == 0 and epoch_no < num_train_epochs
|
||||
remove_epoch_no = None
|
||||
if saving:
|
||||
print("saving checkpoint.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
save_func()
|
||||
|
||||
@@ -1270,6 +1276,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
|
||||
if save_stable_diffusion_format:
|
||||
def save_sd():
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
||||
src_path, epoch_no, global_step, save_dtype, vae)
|
||||
|
||||
@@ -1277,6 +1284,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
|
||||
_, old_ckpt_name = get_epoch_ckpt_name(args, use_safetensors, old_epoch_no)
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
save_func = save_sd
|
||||
@@ -1284,6 +1292,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
|
||||
else:
|
||||
def save_du():
|
||||
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, epoch_no))
|
||||
print(f"saving model: {out_dir}")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
||||
src_path, vae=vae, use_safetensors=use_safetensors)
|
||||
@@ -1291,6 +1300,7 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
|
||||
def remove_du(old_epoch_no):
|
||||
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
|
||||
if os.path.exists(out_dir_old):
|
||||
print(f"removing old model: {out_dir_old}")
|
||||
shutil.rmtree(out_dir_old)
|
||||
|
||||
save_func = save_du
|
||||
@@ -1298,19 +1308,17 @@ def save_sd_model_on_epoch_end(args: argparse.Namespace, accelerator, src_path:
|
||||
|
||||
saving, remove_epoch_no = save_on_epoch_end(args, save_func, remove_old_func, epoch_no, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
||||
if remove_epoch_no is not None:
|
||||
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
||||
if os.path.exists(state_dir_old):
|
||||
shutil.rmtree(state_dir_old)
|
||||
save_state_on_epoch_end(args, accelerator, model_name, epoch_no, remove_epoch_no)
|
||||
|
||||
|
||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||
print("saving last state.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
||||
def save_state_on_epoch_end(args: argparse.Namespace, accelerator, model_name, epoch_no, remove_epoch_no):
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, epoch_no)))
|
||||
if remove_epoch_no is not None:
|
||||
state_dir_old = os.path.join(args.output_dir, EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
||||
if os.path.exists(state_dir_old):
|
||||
print(f"removing old state: {state_dir_old}")
|
||||
shutil.rmtree(state_dir_old)
|
||||
|
||||
|
||||
def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_stable_diffusion_format: bool, use_safetensors: bool, save_dtype: torch.dtype, epoch: int, global_step: int, text_encoder, unet, vae):
|
||||
@@ -1326,12 +1334,19 @@ def save_sd_model_on_train_end(args: argparse.Namespace, src_path: str, save_sta
|
||||
model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
|
||||
src_path, epoch, global_step, save_dtype, vae)
|
||||
else:
|
||||
print(f"save trained model as Diffusers to {args.output_dir}")
|
||||
|
||||
out_dir = os.path.join(args.output_dir, model_name)
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
print(f"save trained model as Diffusers to {out_dir}")
|
||||
model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet,
|
||||
src_path, vae=vae, use_safetensors=use_safetensors)
|
||||
|
||||
|
||||
def save_state_on_train_end(args: argparse.Namespace, accelerator):
|
||||
print("saving last state.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
model_name = DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
|
||||
accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME.format(model_name)))
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -302,22 +302,19 @@ def train(args):
|
||||
def save_func():
|
||||
ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"saving checkpoint: {ckpt_file}")
|
||||
unwrap_model(network).save_weights(ckpt_file, save_dtype)
|
||||
|
||||
def remove_old_func(old_epoch_no):
|
||||
old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
saving, remove_epoch_no = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
|
||||
if saving and args.save_state:
|
||||
print("saving state.")
|
||||
accelerator.save_state(os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, epoch + 1)))
|
||||
if remove_epoch_no is not None:
|
||||
state_dir_old = os.path.join(args.output_dir, train_util.EPOCH_STATE_NAME.format(model_name, remove_epoch_no))
|
||||
if os.path.exists(state_dir_old):
|
||||
shutil.rmtree(state_dir_old)
|
||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1, remove_epoch_no)
|
||||
|
||||
# end of epoch
|
||||
|
||||
|
||||
Reference in New Issue
Block a user