mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
データセットにepoch、stepが通達されないバグ修正
This commit is contained in:
@@ -8,6 +8,7 @@ import random
|
|||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import toml
|
import toml
|
||||||
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -24,9 +25,18 @@ from library.config_util import (
|
|||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class collater_class:
|
||||||
def collate_fn(examples):
|
def __init__(self,epoch,step):
|
||||||
return examples[0]
|
self.current_epoch=epoch
|
||||||
|
self.current_step=step
|
||||||
|
def __call__(self, examples):
|
||||||
|
dataset = torch.utils.data.get_worker_info().dataset
|
||||||
|
dataset.set_current_epoch(self.current_epoch.value)
|
||||||
|
dataset.set_current_step(self.current_step.value)
|
||||||
|
# print("self.current_step:%d"%self.current_step)
|
||||||
|
# print("dataset_lengh:%d"%len(dataset))
|
||||||
|
print("id(self)(collate):%d"%id(self))
|
||||||
|
return examples[0]
|
||||||
|
|
||||||
|
|
||||||
# TODO 他のスクリプトと共通化する
|
# TODO 他のスクリプトと共通化する
|
||||||
@@ -101,6 +111,10 @@ def train(args):
|
|||||||
config_util.blueprint_args_conflict(args,blueprint)
|
config_util.blueprint_args_conflict(args,blueprint)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
|
current_epoch = Value('i',0)
|
||||||
|
current_step = Value('i',0)
|
||||||
|
collater = collater_class(current_epoch,current_step)
|
||||||
|
|
||||||
if args.debug_dataset:
|
if args.debug_dataset:
|
||||||
train_util.debug_dataset(train_dataset_group)
|
train_util.debug_dataset(train_dataset_group)
|
||||||
return
|
return
|
||||||
@@ -186,11 +200,12 @@ def train(args):
|
|||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0はメインプロセスになる
|
# DataLoaderのプロセス数:0はメインプロセスになる
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count() - 1) # cpu_count-1 ただし最大で指定された数まで
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
train_dataset_group,
|
train_dataset_group,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collater,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
@@ -498,17 +513,18 @@ def train(args):
|
|||||||
|
|
||||||
loss_list = []
|
loss_list = []
|
||||||
loss_total = 0.0
|
loss_total = 0.0
|
||||||
|
del train_dataset_group
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
current_epoch.value = epoch+1
|
||||||
train_dataset_group.set_current_step(global_step)
|
|
||||||
|
|
||||||
metadata["ss_epoch"] = str(epoch + 1)
|
metadata["ss_epoch"] = str(epoch + 1)
|
||||||
|
|
||||||
network.on_epoch_start(text_encoder, unet)
|
network.on_epoch_start(text_encoder, unet)
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(network):
|
with accelerator.accumulate(network):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user