mirror of https://github.com/kohya-ss/sd-scripts.git
new ratio code
@@ -345,10 +345,48 @@ class NetworkTrainer:

         # number of DataLoader processes; 0 means everything runs in the main process
         n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1)  # cpu_count - 1, but at most the specified number

+        def get_indices_without_reg(dataset: torch.utils.data.Dataset):
+            # indices of the non-regularization images only
+            return [idx for idx, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False]
+
+        from typing import Sequence, Union
+        import warnings
+        from torch._utils import _accumulate
+        from torch.utils.data.dataset import Subset
+
+        def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]):
+            # shuffle the non-regularization indices, then carve them into subsets
+            # sized by the given fractions
+            indices = get_indices_without_reg(dataset)
+            random.shuffle(indices)
+
+            subset_lengths = []
+            for i, frac in enumerate(lengths):
+                if frac < 0 or frac > 1:
+                    raise ValueError(f"Fraction at index {i} is not between 0 and 1")
+                n_items_in_split = int(math.floor(len(indices) * frac))
+                subset_lengths.append(n_items_in_split)
+
+            # flooring can leave items unassigned; hand them out round-robin
+            remainder = len(indices) - sum(subset_lengths)
+            for i in range(remainder):
+                idx_to_add_at = i % len(subset_lengths)
+                subset_lengths[idx_to_add_at] += 1
+
+            lengths = subset_lengths
+            for i, length in enumerate(lengths):
+                if length == 0:
+                    warnings.warn(f"Length of split at index {i} is 0. This might result in an empty dataset.")
+
+            if sum(lengths) != len(indices):
+                raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
+
+            return [Subset(dataset, indices[offset - length : offset]) for offset, length in zip(_accumulate(lengths), lengths)]
+
         if args.validation_ratio > 0.0:
             train_ratio = 1 - args.validation_ratio
             validation_ratio = args.validation_ratio
-            train, val = torch.utils.data.random_split(
+            train, val = random_split(
                 train_dataset_group,
                 [train_ratio, validation_ratio]
             )
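The only non-obvious step in the new random_split above is turning fractional lengths into item counts: each fraction is floored, then the leftover items are handed out round-robin so every index lands in exactly one split. A minimal sketch of just that arithmetic, using a made-up pool of 9 non-regularization images and a [0.8, 0.2] train/validation split:

import math

indices = list(range(9))  # stand-in for get_indices_without_reg(dataset)
fractions = [0.8, 0.2]

# floor each fraction's share: 9 * 0.8 -> 7, 9 * 0.2 -> 1
subset_lengths = [int(math.floor(len(indices) * frac)) for frac in fractions]

# 9 - (7 + 1) = 1 item is left over; distribute it round-robin
remainder = len(indices) - sum(subset_lengths)
for i in range(remainder):
    subset_lengths[i % len(subset_lengths)] += 1

print(subset_lengths)  # [8, 1] -- sums to len(indices), so no image is dropped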
@@ -358,6 +396,8 @@ class NetworkTrainer:
             train = train_dataset_group
+            val = []
+
         train_dataloader = torch.utils.data.DataLoader(
             train,
             batch_size=1,
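For context, a minimal self-contained sketch of how the resulting Subsets feed the loaders; the TensorDataset, index ranges, and shuffle flags are made-up stand-ins rather than code from this commit (the real training loader also passes a collate function and num_workers):

import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# made-up stand-ins for the train/val Subsets produced by random_split above
toy = TensorDataset(torch.arange(10).float())
train, val = Subset(toy, list(range(8))), Subset(toy, [8, 9])

train_dataloader = DataLoader(train, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val, batch_size=1, shuffle=False)

print(len(train_dataloader), len(val_dataloader))  # 8 2 -- len(val_dataloader) > 0 gates the validation loop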
@@ -898,7 +938,7 @@ class NetworkTrainer:

                     if args.logging_dir is not None:
                         logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
-                        accelerator.log(logs, step=global_step)
+                        accelerator.log(logs)

                     if global_step >= args.max_train_steps:
                         break
@@ -973,13 +1013,11 @@ class NetworkTrainer:
                         loss = loss.mean()  # already the mean, so no need to divide by batch_size

                         current_loss = loss.detach().item()
                         val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)

                 if len(val_dataloader) > 0:
+                    avr_loss: float = val_loss_recorder.moving_average
                     if args.logging_dir is not None:
-                        avr_loss: float = val_loss_recorder.moving_average
                         logs = {"loss/validation": avr_loss}
                         accelerator.log(logs, step=epoch + 1)
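The validation logging above only relies on val_loss_recorder exposing add(epoch=..., step=..., loss=...) and a moving_average property, both visible in the hunk. A minimal stand-in with that interface (the real recorder in train_util may average differently, e.g. over a fixed window):

class SimpleLossRecorder:
    # records per-step losses and reports their running mean
    def __init__(self):
        self.loss_list = []

    def add(self, *, epoch: int, step: int, loss: float):
        self.loss_list.append(loss)

    @property
    def moving_average(self) -> float:
        return sum(self.loss_list) / len(self.loss_list)

val_loss_recorder = SimpleLossRecorder()
for val_step, loss in enumerate([1.0, 0.5, 0.75]):
    val_loss_recorder.add(epoch=0, step=val_step, loss=loss)
print({"loss/validation": val_loss_recorder.moving_average})  # {'loss/validation': 0.75}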