From 33c311ed19821c9be7094ba89371777d7478b028 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Sun, 5 Nov 2023 12:37:37 -0500 Subject: [PATCH] new ratio code --- train_network.py | 48 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 5 deletions(-) diff --git a/train_network.py b/train_network.py index 58767b6f..967c95fb 100644 --- a/train_network.py +++ b/train_network.py @@ -345,10 +345,48 @@ class NetworkTrainer: # DataLoaderのプロセス数:0はメインプロセスになる n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで + def get_indices_without_reg(dataset: torch.utils.data.Dataset): + return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False] + + from typing import Sequence, Union + from torch._utils import _accumulate + import warnings + from torch.utils.data.dataset import Subset + + def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]): + indices = get_indices_without_reg(dataset) + random.shuffle(indices) + + subset_lengths = [] + + for i, frac in enumerate(lengths): + if frac < 0 or frac > 1: + raise ValueError(f"Fraction at index {i} is not between 0 and 1") + n_items_in_split = int(math.floor(len(indices) * frac)) + subset_lengths.append(n_items_in_split) + + remainder = len(indices) - sum(subset_lengths) + + for i in range(remainder): + idx_to_add_at = i % len(subset_lengths) + subset_lengths[idx_to_add_at] += 1 + + lengths = subset_lengths + for i, length in enumerate(lengths): + if length == 0: + warnings.warn(f"Length of split at index {i} is 0. " + f"This might result in an empty dataset.") + + if sum(lengths) != len(indices): + raise ValueError("Sum of input lengths does not equal the length of the input dataset!") + + return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)] + + if args.validation_ratio > 0.0: train_ratio = 1 - args.validation_ratio validation_ratio = args.validation_ratio - train, val = torch.utils.data.random_split( + train, val = random_split( train_dataset_group, [train_ratio, validation_ratio] ) @@ -358,6 +396,8 @@ class NetworkTrainer: train = train_dataset_group val = [] + + train_dataloader = torch.utils.data.DataLoader( train, batch_size=1, @@ -898,7 +938,7 @@ class NetworkTrainer: if args.logging_dir is not None: logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) - accelerator.log(logs, step=global_step) + accelerator.log(logs) if global_step >= args.max_train_steps: break @@ -973,13 +1013,11 @@ class NetworkTrainer: loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし current_loss = loss.detach().item() - val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss) if len(val_dataloader) > 0: - avr_loss: float = val_loss_recorder.moving_average - if args.logging_dir is not None: + avr_loss: float = val_loss_recorder.moving_average logs = {"loss/validation": avr_loss} accelerator.log(logs, step=epoch + 1)