mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
fix bucketing
This commit is contained in:
@@ -141,7 +141,6 @@ def train(args):
|
||||
controlnet = ControlNetModel.from_pretrained(filename)
|
||||
|
||||
|
||||
|
||||
# モデルに xformers とか memory efficient attention を組み込む
|
||||
train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
|
||||
|
||||
@@ -168,11 +167,11 @@ def train(args):
|
||||
controlnet.enable_gradient_checkpointing()
|
||||
|
||||
# 学習に必要なクラスを準備する
|
||||
print("prepare optimizer, data loader etc.")
|
||||
accelerator.print("prepare optimizer, data loader etc.")
|
||||
|
||||
trainable_params = controlnet.parameters()
|
||||
|
||||
optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(
|
||||
_, _, optimizer = train_util.get_optimizer(
|
||||
args, trainable_params
|
||||
)
|
||||
|
||||
@@ -198,10 +197,9 @@ def train(args):
|
||||
/ accelerator.num_processes
|
||||
/ args.gradient_accumulation_steps
|
||||
)
|
||||
if is_main_process:
|
||||
print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
accelerator.print(
|
||||
f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
|
||||
)
|
||||
|
||||
# データセット側にも学習ステップを送信
|
||||
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||
@@ -216,7 +214,7 @@ def train(args):
|
||||
assert (
|
||||
args.mixed_precision == "fp16"
|
||||
), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
|
||||
print("enable full fp16 training.")
|
||||
accelerator.print("enable full fp16 training.")
|
||||
controlnet.to(weight_dtype)
|
||||
|
||||
# acceleratorがなんかよろしくやってくれるらしい
|
||||
@@ -258,23 +256,21 @@ def train(args):
|
||||
|
||||
# 学習する
|
||||
# TODO: find a way to handle total batch size when there are multiple datasets
|
||||
|
||||
if is_main_process:
|
||||
print("running training / 学習開始")
|
||||
print(
|
||||
f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
|
||||
)
|
||||
print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
print(
|
||||
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
|
||||
)
|
||||
print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
accelerator.print("running training / 学習開始")
|
||||
accelerator.print(
|
||||
f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}"
|
||||
)
|
||||
accelerator.print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
|
||||
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
|
||||
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
|
||||
accelerator.print(
|
||||
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
|
||||
)
|
||||
# print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
|
||||
accelerator.print(
|
||||
f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
|
||||
)
|
||||
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(args.max_train_steps),
|
||||
@@ -303,11 +299,11 @@ def train(args):
|
||||
del train_dataset_group
|
||||
|
||||
# function for saving/removing
|
||||
def save_model(ckpt_name, model, steps, epoch_no, force_sync_upload=False):
|
||||
def save_model(ckpt_name, model, force_sync_upload=False):
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
accelerator.print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
|
||||
state_dict = model_util.convert_controlnet_state_dict_to_sd(model.state_dict())
|
||||
|
||||
@@ -332,13 +328,13 @@ def train(args):
|
||||
def remove_model(old_ckpt_name):
|
||||
old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
|
||||
if os.path.exists(old_ckpt_file):
|
||||
print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
accelerator.print(f"removing old checkpoint: {old_ckpt_file}")
|
||||
os.remove(old_ckpt_file)
|
||||
|
||||
# training loop
|
||||
for epoch in range(num_train_epochs):
|
||||
if is_main_process:
|
||||
print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
|
||||
current_epoch.value = epoch + 1
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
@@ -470,7 +466,7 @@ def train(args):
|
||||
args, "." + args.save_model_as, global_step
|
||||
)
|
||||
save_model(
|
||||
ckpt_name, unwrap_model(controlnet), global_step, epoch
|
||||
ckpt_name, unwrap_model(controlnet),
|
||||
)
|
||||
|
||||
if args.save_state:
|
||||
@@ -520,7 +516,7 @@ def train(args):
|
||||
ckpt_name = train_util.get_epoch_ckpt_name(
|
||||
args, "." + args.save_model_as, epoch + 1
|
||||
)
|
||||
save_model(ckpt_name, unwrap_model(controlnet), global_step, epoch + 1)
|
||||
save_model(ckpt_name, unwrap_model(controlnet))
|
||||
|
||||
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
|
||||
if remove_epoch_no is not None:
|
||||
@@ -561,7 +557,7 @@ def train(args):
|
||||
if is_main_process:
|
||||
ckpt_name = train_util.get_last_ckpt_name(args, "." + args.save_model_as)
|
||||
save_model(
|
||||
ckpt_name, controlnet, global_step, num_train_epochs, force_sync_upload=True
|
||||
ckpt_name, controlnet, force_sync_upload=True
|
||||
)
|
||||
|
||||
print("model saved.")
|
||||
|
||||
Reference in New Issue
Block a user