mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
new ratio code
This commit is contained in:
@@ -345,10 +345,48 @@ class NetworkTrainer:
|
|||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
|
|
||||||
|
def get_indices_without_reg(dataset: torch.utils.data.Dataset):
|
||||||
|
return [id for id, (key, item) in enumerate(dataset.image_data.items()) if item.is_reg is False]
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
from torch._utils import _accumulate
|
||||||
|
import warnings
|
||||||
|
from torch.utils.data.dataset import Subset
|
||||||
|
|
||||||
|
def random_split(dataset: torch.utils.data.Dataset, lengths: Sequence[Union[int, float]]):
|
||||||
|
indices = get_indices_without_reg(dataset)
|
||||||
|
random.shuffle(indices)
|
||||||
|
|
||||||
|
subset_lengths = []
|
||||||
|
|
||||||
|
for i, frac in enumerate(lengths):
|
||||||
|
if frac < 0 or frac > 1:
|
||||||
|
raise ValueError(f"Fraction at index {i} is not between 0 and 1")
|
||||||
|
n_items_in_split = int(math.floor(len(indices) * frac))
|
||||||
|
subset_lengths.append(n_items_in_split)
|
||||||
|
|
||||||
|
remainder = len(indices) - sum(subset_lengths)
|
||||||
|
|
||||||
|
for i in range(remainder):
|
||||||
|
idx_to_add_at = i % len(subset_lengths)
|
||||||
|
subset_lengths[idx_to_add_at] += 1
|
||||||
|
|
||||||
|
lengths = subset_lengths
|
||||||
|
for i, length in enumerate(lengths):
|
||||||
|
if length == 0:
|
||||||
|
warnings.warn(f"Length of split at index {i} is 0. "
|
||||||
|
f"This might result in an empty dataset.")
|
||||||
|
|
||||||
|
if sum(lengths) != len(indices):
|
||||||
|
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")
|
||||||
|
|
||||||
|
return [Subset(dataset, indices[offset - length: offset]) for offset, length in zip(_accumulate(lengths), lengths)]
|
||||||
|
|
||||||
|
|
||||||
if args.validation_ratio > 0.0:
|
if args.validation_ratio > 0.0:
|
||||||
train_ratio = 1 - args.validation_ratio
|
train_ratio = 1 - args.validation_ratio
|
||||||
validation_ratio = args.validation_ratio
|
validation_ratio = args.validation_ratio
|
||||||
train, val = torch.utils.data.random_split(
|
train, val = random_split(
|
||||||
train_dataset_group,
|
train_dataset_group,
|
||||||
[train_ratio, validation_ratio]
|
[train_ratio, validation_ratio]
|
||||||
)
|
)
|
||||||
@@ -358,6 +396,8 @@ class NetworkTrainer:
|
|||||||
train = train_dataset_group
|
train = train_dataset_group
|
||||||
val = []
|
val = []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train,
|
train,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
@@ -898,7 +938,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
||||||
accelerator.log(logs, step=global_step)
|
accelerator.log(logs)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
@@ -973,13 +1013,11 @@ class NetworkTrainer:
|
|||||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
|
|
||||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||||
|
|
||||||
if len(val_dataloader) > 0:
|
if len(val_dataloader) > 0:
|
||||||
avr_loss: float = val_loss_recorder.moving_average
|
|
||||||
|
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
|
avr_loss: float = val_loss_recorder.moving_average
|
||||||
logs = {"loss/validation": avr_loss}
|
logs = {"loss/validation": avr_loss}
|
||||||
accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=epoch + 1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user