remove workaround for accelerator=0.15, fix XTI

This commit is contained in:
ykume
2023-06-11 18:32:14 +09:00
parent 33a6234b52
commit 0315611b11
7 changed files with 153 additions and 159 deletions

View File

@@ -150,7 +150,7 @@ def train(args):
# acceleratorを準備する
print("preparing accelerator")
accelerator, unwrap_model = train_util.prepare_accelerator(args)
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process
# mixed precisionに対応した型を用意しておき適宜castする
@@ -702,7 +702,7 @@ def train(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
ckpt_name = train_util.get_step_ckpt_name(args, "." + args.save_model_as, global_step)
save_model(ckpt_name, unwrap_model(network), global_step, epoch)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch)
if args.save_state:
train_util.save_and_remove_state_stepwise(args, accelerator, global_step)
@@ -744,7 +744,7 @@ def train(args):
saving = (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs
if is_main_process and saving:
ckpt_name = train_util.get_epoch_ckpt_name(args, "." + args.save_model_as, epoch + 1)
save_model(ckpt_name, unwrap_model(network), global_step, epoch + 1)
save_model(ckpt_name, accelerator.unwrap_model(network), global_step, epoch + 1)
remove_epoch_no = train_util.get_remove_epoch_no(args, epoch + 1)
if remove_epoch_no is not None:
@@ -762,7 +762,7 @@ def train(args):
metadata["ss_training_finished_at"] = str(time.time())
if is_main_process:
network = unwrap_model(network)
network = accelerator.unwrap_model(network)
accelerator.end_training()