Mirror of https://github.com/kohya-ss/sd-scripts.git
Synced 2026-04-09 06:45:09 +00:00

Commit: fix for multi gpu training

--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2294,6 +2294,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
   with torch.no_grad():
     with accelerator.autocast():
       for i, prompt in enumerate(prompts):
+        if not accelerator.is_main_process:
+          continue
         prompt = prompt.strip()
         if len(prompt) == 0 or prompt[0] == '#':
           continue
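
Only the main process generates sample images: every rank enters the prompt loop, but non-main ranks skip each iteration immediately. A minimal sketch of the pattern, assuming only that accelerate is installed (toy prompts, no real pipeline):

from accelerate import Accelerator

accelerator = Accelerator()
prompts = ["a cat", "# a commented-out prompt", "", "a dog"]

for i, prompt in enumerate(prompts):
    if not accelerator.is_main_process:
        continue  # non-main ranks do no sampling work
    prompt = prompt.strip()
    if len(prompt) == 0 or prompt[0] == "#":
        continue  # skip blanks and comment lines in the prompt file
    print(f"rank 0 would sample image {i} for: {prompt!r}")
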
@@ -2351,6 +2353,12 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
         if negative_prompt is not None:
           negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
 
+        print(f"prompt: {prompt}")
+        print(f"negative_prompt: {negative_prompt}")
+        print(f"height: {height}")
+        print(f"width: {width}")
+        print(f"sample_steps: {sample_steps}")
+        print(f"scale: {scale}")
         image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]
 
         ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())

--- a/train_network.py
+++ b/train_network.py
@@ -106,6 +106,7 @@ def train(args):
   # prepare the accelerator
   print("prepare accelerator")
   accelerator, unwrap_model = train_util.prepare_accelerator(args)
+  is_main_process = accelerator.is_main_process
 
   # prepare dtypes for mixed precision and cast as needed
   weight_dtype, save_dtype = train_util.prepare_dtype(args)
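
Caching accelerator.is_main_process in a local right after the accelerator is created lets the rest of train() gate console output and saving without repeating the lookup. A standalone sketch of the setup, assuming plain accelerate (the repository wraps accelerator creation in train_util.prepare_accelerator):

from accelerate import Accelerator

accelerator = Accelerator()
is_main_process = accelerator.is_main_process  # cached once, reused by every guard below

if is_main_process:
    print("prepare accelerator")
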
@@ -175,12 +176,13 @@ def train(args):
 
   # calculate the number of training steps
   if args.max_train_epochs is not None:
-    args.max_train_steps = args.max_train_epochs * len(train_dataloader)
-    print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+    args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
+    if is_main_process:
+      print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
   # prepare the lr scheduler
   lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                              num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                              num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
                                               num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
 
   # experimental feature: fp16 training including gradients (cast the entire model to fp16)
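
Two multi-GPU corrections land here. With the dataloader sharded across processes, an epoch takes about ceil(len(train_dataloader) / num_processes) optimizer steps per process, so the epoch-to-steps override divides by the process count. The LR schedule, on the other hand, is advanced on every process, so the commit compensates by scaling its total horizon back up by num_processes. A worked example with hypothetical numbers:

import math

# hypothetical values, for illustration only
num_batches_per_epoch = 1000       # len(train_dataloader) before sharding
max_train_epochs = 10
num_processes = 2                  # accelerator.num_processes
gradient_accumulation_steps = 1

# per-process optimizer steps needed for the requested epochs
max_train_steps = max_train_epochs * math.ceil(num_batches_per_epoch / num_processes)

# scheduler horizon covering the steps of all processes combined
num_training_steps = max_train_steps * num_processes * gradient_accumulation_steps

print(max_train_steps)     # 5000
print(num_training_steps)  # 10000
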
@@ -251,15 +253,17 @@ def train(args):
   # start training
   # TODO: find a way to handle total batch size when there are multiple datasets
   total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-  print("running training / 学習開始")
-  print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-  print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
-  print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
-  print(f" num epochs / epoch数: {num_train_epochs}")
-  print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
-  # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
-  print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
-  print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+  if is_main_process:
+    print("running training / 学習開始")
+    print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+    print(f" num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+    print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+    print(f" num epochs / epoch数: {num_train_epochs}")
+    print(f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}")
+    # print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+    print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+    print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
 
   # TODO refactor metadata creation and move to util
   metadata = {
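
The total_batch_size computed just above is the usual effective-batch-size formula: per-device batch size, times processes, times accumulated micro-batches. For example, with hypothetical values:

train_batch_size = 4               # per-device batch size
num_processes = 2                  # GPUs
gradient_accumulation_steps = 2

total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
print(total_batch_size)            # 16 samples contribute to each optimizer update
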
@@ -461,7 +465,8 @@ def train(args):
   loss_list = []
   loss_total = 0.0
   for epoch in range(num_train_epochs):
-    print(f"epoch {epoch+1}/{num_train_epochs}")
+    if is_main_process:
+      print(f"epoch {epoch+1}/{num_train_epochs}")
     train_dataset_group.set_current_epoch(epoch + 1)
 
     metadata["ss_epoch"] = str(epoch+1)
@@ -573,9 +578,10 @@ def train(args):
         print(f"removing old checkpoint: {old_ckpt_file}")
         os.remove(old_ckpt_file)
 
-    saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
-    if saving and args.save_state:
-      train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
+    if is_main_process:
+      saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+      if saving and args.save_state:
+        train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
 
     train_util.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
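
Gating the end-of-epoch save on is_main_process keeps the other ranks from writing, and then deleting, the same checkpoint files; sample_images stays outside the guard because it does its own rank check (see the first hunk). A minimal sketch of the saving pattern, assuming accelerate; save_checkpoint is a hypothetical stand-in for save_func, and the wait_for_everyone barrier is a common companion rather than part of this commit:

from accelerate import Accelerator

accelerator = Accelerator()

def save_checkpoint(path: str) -> None:
    # hypothetical helper; the real code goes through train_util.save_on_epoch_end
    print(f"saving checkpoint to {path}")

accelerator.wait_for_everyone()    # let all ranks finish the epoch first
if accelerator.is_main_process:
    save_checkpoint("epoch-0001.safetensors")
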
@@ -584,7 +590,6 @@ def train(args):
   metadata["ss_epoch"] = str(num_train_epochs)
   metadata["ss_training_finished_at"] = str(time.time())
 
-  is_main_process = accelerator.is_main_process
   if is_main_process:
     network = unwrap_model(network)
 
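
With is_main_process now set once at the top of train(), the late re-read before the final save is redundant and goes away; the final if is_main_process: block behaves exactly as before. Putting the commit's pieces together, the overall shape of the loop is roughly:

from accelerate import Accelerator

accelerator = Accelerator()
is_main_process = accelerator.is_main_process  # defined once, up front

num_train_epochs = 2  # toy value
for epoch in range(num_train_epochs):
    if is_main_process:
        print(f"epoch {epoch+1}/{num_train_epochs}")
    # ... every rank runs its shard of the training step here ...

if is_main_process:
    print("saving final model")  # rank 0 alone writes the result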