fix for multi gpu training

This commit is contained in:
ddPn08
2023-03-03 00:21:18 +09:00
parent 8d5ba29363
commit 87846c043f
2 changed files with 30 additions and 17 deletions

View File

@@ -2294,6 +2294,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
    with torch.no_grad():
        with accelerator.autocast():
            for i, prompt in enumerate(prompts):
+               if not accelerator.is_main_process:
+                   continue
                prompt = prompt.strip()
                if len(prompt) == 0 or prompt[0] == '#':
                    continue
@@ -2351,6 +2353,12 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
                if negative_prompt is not None:
                    negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+               print(f"prompt: {prompt}")
+               print(f"negative_prompt: {negative_prompt}")
+               print(f"height: {height}")
+               print(f"width: {width}")
+               print(f"sample_steps: {sample_steps}")
+               print(f"scale: {scale}")
                image = pipeline(prompt, height, width, sample_steps, scale, negative_prompt).images[0]

                ts_str = time.strftime('%Y%m%d%H%M%S', time.localtime())

View File

@@ -106,6 +106,7 @@ def train(args):
    # acceleratorを準備する
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)
+   is_main_process = accelerator.is_main_process

    # mixed precisionに対応した型を用意しておき適宜castする
    weight_dtype, save_dtype = train_util.prepare_dtype(args)
@@ -175,12 +176,13 @@ def train(args):
    # 学習ステップ数を計算する
    if args.max_train_epochs is not None:
-       args.max_train_steps = args.max_train_epochs * len(train_dataloader)
-       print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
+       args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes)
+       if is_main_process:
+           print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # lr schedulerを用意する
    lr_scheduler = train_util.get_scheduler_fix(args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
-                                               num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+                                               num_training_steps=args.max_train_steps * accelerator.num_processes * args.gradient_accumulation_steps,
                                                num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする # 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
@@ -251,6 +253,8 @@ def train(args):
    # 学習する
    # TODO: find a way to handle total batch size when there are multiple datasets
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
-   print("running training / 学習開始")
-   print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
-   print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
+   if is_main_process:
+       print("running training / 学習開始")
+       print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset_group.num_train_images}")
+       print(f"  num reg images / 正則化画像の数: {train_dataset_group.num_reg_images}")
@@ -461,6 +465,7 @@ def train(args):
    loss_list = []
    loss_total = 0.0
    for epoch in range(num_train_epochs):
-       print(f"epoch {epoch+1}/{num_train_epochs}")
+       if is_main_process:
+           print(f"epoch {epoch+1}/{num_train_epochs}")

        train_dataset_group.set_current_epoch(epoch + 1)
@@ -573,6 +578,7 @@ def train(args):
                print(f"removing old checkpoint: {old_ckpt_file}")
                os.remove(old_ckpt_file)

-       saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
-       if saving and args.save_state:
-           train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
+       if is_main_process:
+           saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
+           if saving and args.save_state:
+               train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
@@ -584,7 +590,6 @@ def train(args):
    metadata["ss_epoch"] = str(num_train_epochs)
    metadata["ss_training_finished_at"] = str(time.time())

-   is_main_process = accelerator.is_main_process
    if is_main_process:
        network = unwrap_model(network)