diff --git a/library/train_util.py b/library/train_util.py
index 9ce129bd..2d28ab11 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -5387,7 +5387,7 @@ class LossRecorder:
         self.loss_total: float = 0.0
 
     def add(self, *, epoch: int, step: int, loss: float) -> None:
-        if epoch == 0:
+        if epoch == 0 or step >= len(self.loss_list):
             self.loss_list.append(loss)
         else:
             self.loss_total -= self.loss_list[step]
diff --git a/train_network.py b/train_network.py
index c99d3724..7528932f 100644
--- a/train_network.py
+++ b/train_network.py
@@ -483,6 +483,15 @@ class NetworkTrainer:
                 weights.pop(i)
             # print(f"save model hook: {len(weights)} weights will be saved")
 
+            # save current epoch and step
+            train_state_file = os.path.join(output_dir, "train_state.json")
+            # +1 is needed because the state is saved before current_step is set from global_step
+            logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}")
+            with open(train_state_file, "w", encoding="utf-8") as f:
+                json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f)
+
+        steps_from_state = None
+
         def load_model_hook(models, input_dir):
             # remove models except network
             remove_indices = []
@@ -493,6 +502,15 @@ class NetworkTrainer:
                 models.pop(i)
             # print(f"load model hook: {len(models)} models will be loaded")
 
+            # load current epoch and step
+            nonlocal steps_from_state
+            train_state_file = os.path.join(input_dir, "train_state.json")
+            if os.path.exists(train_state_file):
+                with open(train_state_file, "r", encoding="utf-8") as f:
+                    data = json.load(f)
+                steps_from_state = data["current_step"]  # already +1'd at save time, so no further offset
+                logger.info(f"load train state from {train_state_file}: {data}")
+
         accelerator.register_save_state_pre_hook(save_model_hook)
         accelerator.register_load_state_pre_hook(load_model_hook)
 
@@ -736,7 +754,53 @@ class NetworkTrainer:
             if key in metadata:
                 minimum_metadata[key] = metadata[key]
 
-        progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+        # calculate steps to skip when resuming or starting from a specific step
+        initial_step = 0
+        if args.initial_epoch is not None or args.initial_step is not None:
+            # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
+            if steps_from_state is not None:
+                logger.warning(
+                    "steps from the state are ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます"
+                )
+            if args.initial_step is not None:
+                initial_step = args.initial_step
+            else:
+                # num steps per epoch is calculated from num_processes and gradient_accumulation_steps
+                initial_step = (args.initial_epoch - 1) * math.ceil(
+                    len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+                )
+        else:
+            # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming
+            if steps_from_state is not None:
+                initial_step = steps_from_state
+            steps_from_state = None
+
+        if initial_step > 0:
+            assert (
+                args.max_train_steps > initial_step
+            ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}"
+
+        progress_bar = tqdm(
+            range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps"
+        )
+
+        epoch_to_start = 0
+        if initial_step > 0:
+            if args.skip_until_initial_step:
+                # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used
+                if not args.resume:
+                    logger.info(
+                        "initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります"
+                    )
+                logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします")
+                initial_step *= accelerator.num_processes * args.gradient_accumulation_steps
+            else:
+                # if not, only the epoch number is skipped, for informational purposes
+                epoch_to_start = initial_step // math.ceil(
+                    len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+                )
+                initial_step = 0  # do not skip
+
         global_step = 0
 
         noise_scheduler = DDPMScheduler(
@@ -793,16 +857,32 @@ class NetworkTrainer:
         self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
 
         # training loop
-        for epoch in range(num_train_epochs):
+        if initial_step > 0:
+            # set the starting global_step calculated from initial_step, because skipping steps doesn't increment global_step
+            global_step = initial_step // (accelerator.num_processes * args.gradient_accumulation_steps)
+
+        for epoch in range(epoch_to_start, num_train_epochs):
             accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
             current_epoch.value = epoch + 1
 
+            if initial_step > len(train_dataloader):
+                logger.info(f"skipping epoch {epoch+1} because initial_step (multiplied) is {initial_step}")
+                initial_step -= len(train_dataloader)
+                continue
+
             metadata["ss_epoch"] = str(epoch + 1)
 
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
             for step, batch in enumerate(train_dataloader):
                 current_step.value = global_step
+
+                if initial_step > 0:
+                    # logger.info(f"skipping step {step+1} because initial_step (multiplied) is {initial_step}")
+                    loss_recorder.add(epoch=epoch, step=step, loss=0)  # add dummy loss
+                    initial_step -= 1
+                    continue
+
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)
@@ -1101,6 +1181,25 @@ def setup_parser() -> argparse.ArgumentParser:
         action="store_true",
         help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
     )
+    parser.add_argument(
+        "--skip_until_initial_step",
+        action="store_true",
+        help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする",
+    )
+    parser.add_argument(
+        "--initial_epoch",
+        type=int,
+        default=None,
+        help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect the lr scheduler, which means the lr scheduler will start from 0 without `--resume`."
+        + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる",
+    )
+    parser.add_argument(
+        "--initial_step",
+        type=int,
+        default=None,
+        help="initial step number including all epochs, 0 means first step (same as not specifying). overrides initial_epoch."
+        + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする",
+    )
     return parser
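
A note on the `LossRecorder` change: resuming into an epoch past the first one starts with an empty `loss_list`, so the old `epoch == 0` condition would route `add()` to `self.loss_list[step]` and raise `IndexError`; the dummy-loss calls in the skip path also rely on the new append branch. Below is a minimal sketch of the fixed behavior, assuming the rest of the class matches the surrounding code in `library/train_util.py` (the trailing `self.loss_total += loss` is not visible in the hunk above):

```python
from typing import List


# Trimmed copy of LossRecorder for illustration; the guard added above appends
# when `step` points past the recorded history instead of indexing into it.
class LossRecorder:
    def __init__(self) -> None:
        self.loss_list: List[float] = []
        self.loss_total: float = 0.0

    def add(self, *, epoch: int, step: int, loss: float) -> None:
        if epoch == 0 or step >= len(self.loss_list):  # the fixed condition
            self.loss_list.append(loss)
        else:
            self.loss_total -= self.loss_list[step]  # replace the previous epoch's value
            self.loss_list[step] = loss
        self.loss_total += loss


recorder = LossRecorder()
for step in range(3):
    recorder.add(epoch=3, step=step, loss=0.5)  # resumed run: no epoch-0 history, appends
recorder.add(epoch=4, step=0, loss=0.3)  # later epochs overwrite in place again
assert abs(recorder.loss_total - (0.5 + 0.5 + 0.3)) < 1e-9
```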
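The save/load hooks round-trip a small `train_state.json` next to the Accelerate state. Here is a standalone sketch of that round trip; the helper names `save_train_state`/`load_train_state` are hypothetical, since in the patch this logic lives inside the hooks:

```python
import json
import os
import tempfile


def save_train_state(output_dir: str, current_epoch: int, current_step: int) -> None:
    # +1 mirrors the save hook: it fires before current_step is updated from
    # global_step, so the stored value is the number of completed optimizer steps
    with open(os.path.join(output_dir, "train_state.json"), "w", encoding="utf-8") as f:
        json.dump({"current_epoch": current_epoch, "current_step": current_step + 1}, f)


def load_train_state(input_dir: str):
    # returns None for older state dirs that predate train_state.json
    path = os.path.join(input_dir, "train_state.json")
    if not os.path.exists(path):
        return None
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)["current_step"]


with tempfile.TemporaryDirectory() as state_dir:
    save_train_state(state_dir, current_epoch=2, current_step=499)  # saved during step 500
    assert load_train_state(state_dir) == 500  # resume skips the 500 completed steps
```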
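And a worked example of the skip arithmetic, with toy numbers that are assumptions rather than anything from the patch: `initial_step` counts optimizer steps, while the epoch and step loops consume dataloader batches, and the two differ by a factor of `num_processes * gradient_accumulation_steps`:

```python
import math

# toy numbers (assumptions): 2 processes, grad accumulation 4, 1000 batches per epoch
num_processes = 2
gradient_accumulation_steps = 4
batches_per_epoch = 1000  # stands in for len(train_dataloader)
initial_step = 300        # optimizer steps already completed

# --initial_epoch N is converted to optimizer steps with this same formula
steps_per_epoch = math.ceil(batches_per_epoch / num_processes / gradient_accumulation_steps)  # 125

# with --skip_until_initial_step: convert optimizer steps to batches to discard
batches_to_skip = initial_step * num_processes * gradient_accumulation_steps  # 2400

# global_step is restored by dividing the multiplied value back down, as in the training loop
global_step = batches_to_skip // (num_processes * gradient_accumulation_steps)
assert global_step == initial_step

# whole epochs are dropped first (the `initial_step > len(train_dataloader)` check),
# then the remainder is discarded batch by batch inside the next epoch
epochs_skipped = 0
while batches_to_skip > batches_per_epoch:
    batches_to_skip -= batches_per_epoch
    epochs_skipped += 1
assert (epochs_skipped, batches_to_skip) == (2, 400)

# without --skip_until_initial_step: only the epoch counter fast-forwards
epoch_to_start = initial_step // steps_per_epoch
assert epoch_to_start == 2  # 300 // 125
```

In practice this means a crashed run can be resumed with `--resume <state_dir> --skip_until_initial_step` (the step count comes from `train_state.json`), or pinned explicitly with `--initial_step 300 --skip_until_initial_step`; without `--resume`, the lr scheduler still starts from zero, as the new help text warns.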