From b2363f1021955c049c98e65676efca130690c40f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 31 May 2024 12:20:20 +0800 Subject: [PATCH 1/4] Final implementation --- library/train_util.py | 11 ++++- train_network.py | 104 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 106 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1f9f3c5d..beb33bf8 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -657,8 +657,15 @@ class BaseDataset(torch.utils.data.Dataset): def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする - self.shuffle_buckets() - self.current_epoch = epoch + if epoch > self.current_epoch: + logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + else: + logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch def set_current_step(self, step): self.current_step = step diff --git a/train_network.py b/train_network.py index b272a6e1..76e6cd8a 100644 --- a/train_network.py +++ b/train_network.py @@ -493,17 +493,24 @@ class NetworkTrainer: # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 - if accelerator.is_main_process or args.deepspeed: + if accelerator.is_main_process: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - if len(weights) > i: - weights.pop(i) + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") + # save current ecpoch and step + train_state_file = os.path.join(output_dir, "train_state.json") + # +1 is needed because the state is saved before current_step is set from global_step + logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}") + with open(train_state_file, "w", encoding="utf-8") as f: + json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f) + + steps_from_state = None + def load_model_hook(models, input_dir): # remove models except network remove_indices = [] @@ -514,6 +521,15 @@ class NetworkTrainer: models.pop(i) # print(f"load model hook: {len(models)} models will be loaded") + # load current epoch and step to + nonlocal steps_from_state + train_state_file = os.path.join(input_dir, "train_state.json") + if os.path.exists(train_state_file): + with open(train_state_file, "r", encoding="utf-8") as f: + data = json.load(f) + steps_from_state = data["current_step"] + logger.info(f"load train state from {train_state_file}: {data}") + accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -757,7 +773,53 @@ class NetworkTrainer: if key in metadata: minimum_metadata[key] = metadata[key] - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + # calculate steps to skip when resuming or starting from a specific step + initial_step = 0 + if args.initial_epoch is not None or args.initial_step is not None: + # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming + if steps_from_state is not None: + logger.warning( + "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます" + ) + if args.initial_step is not None: + initial_step = args.initial_step + else: + # num steps per epoch is calculated by num_processes and gradient_accumulation_steps + initial_step = (args.initial_epoch - 1) * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + else: + # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming + if steps_from_state is not None: + initial_step = steps_from_state + steps_from_state = None + + if initial_step > 0: + assert ( + args.max_train_steps > initial_step + ), f"max_train_steps should be greater than initial step / max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" + + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + + epoch_to_start = 0 + if initial_step > 0: + if args.skip_until_initial_step: + # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used + if not args.resume: + logger.info( + f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります" + ) + logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") + initial_step *= args.gradient_accumulation_steps + else: + # if not, only epoch no is skipped for informative purpose + epoch_to_start = initial_step // math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + initial_step = 0 # do not skip + global_step = 0 noise_scheduler = DDPMScheduler( @@ -816,7 +878,11 @@ class NetworkTrainer: self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for epoch in range(num_train_epochs): + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -824,7 +890,12 @@ class NetworkTrainer: accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) - for step, batch in enumerate(train_dataloader): + skipped_dataloader = None + if initial_step > 0: + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) + initial_step = 1 + + for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) @@ -1126,6 +1197,25 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--skip_until_initial_step", + action="store_true", + help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする", + ) + parser.add_argument( + "--initial_epoch", + type=int, + default=None, + help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`." + + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる", + ) + parser.add_argument( + "--initial_step", + type=int, + default=None, + help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." + + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") From 3eb27ced52e8bf522c7e490c3dacba1f8597f5b1 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 31 May 2024 12:24:15 +0800 Subject: [PATCH 2/4] Skip the final 1 step --- train_network.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train_network.py b/train_network.py index 76e6cd8a..d1f02d53 100644 --- a/train_network.py +++ b/train_network.py @@ -897,6 +897,10 @@ class NetworkTrainer: for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step + if initial_step > 0: + initial_step -= 1 + continue + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) From 4dbcef429b744d0cc101494802448b8c15f4f674 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Jun 2024 21:26:55 +0900 Subject: [PATCH 3/4] update for corner cases --- library/train_util.py | 3 +++ train_network.py | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 102f9f03..4736ff4f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -663,6 +663,7 @@ class BaseDataset(torch.utils.data.Dataset): for _ in range(num_epochs): self.current_epoch += 1 self.shuffle_buckets() + # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader? else: logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) self.current_epoch = epoch @@ -5560,6 +5561,8 @@ class LossRecorder: if epoch == 0: self.loss_list.append(loss) else: + while len(self.loss_list) <= step: + self.loss_list.append(0.0) self.loss_total -= self.loss_list[step] self.loss_list[step] = loss self.loss_total += loss diff --git a/train_network.py b/train_network.py index d1f02d53..7ba07385 100644 --- a/train_network.py +++ b/train_network.py @@ -493,13 +493,15 @@ class NetworkTrainer: # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") # save current ecpoch and step @@ -813,11 +815,12 @@ class NetworkTrainer: ) logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") initial_step *= args.gradient_accumulation_steps + + # set epoch to start to make initial_step less than len(train_dataloader) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) else: # if not, only epoch no is skipped for informative purpose - epoch_to_start = initial_step // math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) initial_step = 0 # do not skip global_step = 0 @@ -878,9 +881,11 @@ class NetworkTrainer: self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for skip_epoch in range(epoch_to_start): # skip epochs - logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") - initial_step -= len(train_dataloader) + if initial_step > 0: # only if skip_until_initial_step is specified + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + global_step = initial_step for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") @@ -892,7 +897,7 @@ class NetworkTrainer: skipped_dataloader = None if initial_step > 0: - skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) initial_step = 1 for step, batch in enumerate(skipped_dataloader or train_dataloader): From 18d7597b0b39cc2204dfbdfdcbf0fead97414be1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Jun 2024 19:51:30 +0900 Subject: [PATCH 4/4] update README --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 52c96339..25aba639 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! +- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf! + - This resolves the issue where the order of data loading from DataSet changes when resuming training. + - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before). + - If `--resume` is specified, the step saved in the state is used. + - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`). + - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. @@ -235,6 +241,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。 +- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。 + - これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。 + - `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます) + - `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。 + - `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。 + - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。