diff --git a/fine_tune.py b/fine_tune.py index 0f42741b..637a729a 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -65,7 +65,8 @@ def train(args): current_epoch = Value("i", 0) current_step = Value("i", 0) - collater = train_util.collater_class(current_epoch, current_step) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) diff --git a/library/train_util.py b/library/train_util.py index 55b5101b..e1a8e922 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3057,12 +3057,20 @@ class ImageLoadingDataset(torch.utils.data.Dataset): # collate_fn用 epoch,stepはmultiprocessing.Value class collater_class: - def __init__(self, epoch, step): + def __init__(self, epoch, step, dataset): self.current_epoch = epoch self.current_step = step + self.dataset = dataset # not used if worker_info is not None, in case of multiprocessing def __call__(self, examples): - dataset = torch.utils.data.get_worker_info().dataset + worker_info = torch.utils.data.get_worker_info() + # worker_info is None in the main process + if worker_info is not None: + dataset = worker_info.dataset + else: + dataset = self.dataset + + # set epoch and step dataset.set_current_epoch(self.current_epoch.value) dataset.set_current_step(self.current_step.value) return examples[0] diff --git a/train_db.py b/train_db.py index f441d5d6..b3eead94 100644 --- a/train_db.py +++ b/train_db.py @@ -23,7 +23,8 @@ from library.config_util import ( BlueprintGenerator, ) import library.custom_train_functions as custom_train_functions -from library.custom_train_functions import apply_snr_weight +from library.custom_train_functions import apply_snr_weight + def train(args): train_util.verify_training_args(args) @@ -57,9 +58,10 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) - current_epoch = Value('i',0) - current_step = Value('i',0) - collater = train_util.collater_class(current_epoch,current_step) + current_epoch = Value("i", 0) + current_step = Value("i", 0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch, current_step, ds_for_collater) if args.no_token_padding: train_dataset_group.disable_token_padding() @@ -161,7 +163,9 @@ def train(args): # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) + args.max_train_steps = args.max_train_epochs * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # データセット側にも学習ステップを送信 @@ -234,7 +238,7 @@ def train(args): loss_total = 0.0 for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - current_epoch.value = epoch+1 + current_epoch.value = epoch + 1 # 指定したステップ数までText Encoderを学習する:epoch最初の状態 unet.train() @@ -298,8 +302,7 @@ def train(args): loss = loss * loss_weights if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) - + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 20ad2c4d..423649ee 100644 --- a/train_network.py +++ b/train_network.py @@ -101,7 +101,8 @@ def train(args): current_epoch = Value('i',0) current_step = Value('i',0) - collater = train_util.collater_class(current_epoch,current_step) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) if args.debug_dataset: train_util.debug_dataset(train_dataset_group) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 681bc628..f279370a 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -186,7 +186,8 @@ def train(args): current_epoch = Value('i',0) current_step = Value('i',0) - collater = train_util.collater_class(current_epoch,current_step) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: