new ratio code

This commit is contained in:
rockerBOO
2023-11-05 12:37:37 -05:00
parent 5b19bda85c
commit 33c311ed19

View File

@@ -345,10 +345,48 @@ class NetworkTrainer:
# DataLoaderのプロセス数0はメインプロセスになる
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
def get_indices_without_reg(dataset: torch.utils.data.Dataset):
return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False]
from typing import Sequence, Union
from torch._utils import _accumulate
import warnings
from torch.utils.data.dataset import Subset
def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]):
indices = get_indices_without_reg(dataset)
random.shuffle(indices)
subset_lengths = []
for i, frac in enumerate(lengths):
if frac < 0 or frac > 1:
raise ValueError(f"Fraction at index {i} is not between 0 and 1")
n_items_in_split = int(math.floor(len(indices) * frac))
subset_lengths.append(n_items_in_split)
remainder = len(indices) - sum(subset_lengths)
for i in range(remainder):
idx_to_add_at = i % len(subset_lengths)
subset_lengths[idx_to_add_at] += 1
lengths = subset_lengths
for i, length in enumerate(lengths):
if length == 0:
warnings.warn(f"Length of split at index {i} is 0. "
f"This might result in an empty dataset.")
if sum(lengths) != len(indices):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)]
if args.validation_ratio > 0.0:
train_ratio = 1 - args.validation_ratio
validation_ratio = args.validation_ratio
train, val = torch.utils.data.random_split(
train, val = random_split(
train_dataset_group,
[train_ratio, validation_ratio]
)
@@ -358,6 +396,8 @@ class NetworkTrainer:
train = train_dataset_group
val = []
train_dataloader = torch.utils.data.DataLoader(
train,
batch_size=1,
@@ -898,7 +938,7 @@ class NetworkTrainer:
if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)
accelerator.log(logs)
if global_step >= args.max_train_steps:
break
@@ -973,13 +1013,11 @@ class NetworkTrainer:
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
if len(val_dataloader) > 0:
avr_loss: float = val_loss_recorder.moving_average
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation": avr_loss}
accelerator.log(logs, step=epoch + 1)