fix to work with num_workers=0

This commit is contained in:
Kohya S
2023-03-28 19:42:47 +09:00
parent 99eaf1fd65
commit 4f70e5dca6
5 changed files with 27 additions and 13 deletions

View File

@@ -65,7 +65,8 @@ def train(args):
     current_epoch = Value("i", 0)
     current_step = Value("i", 0)
-    collater = train_util.collater_class(current_epoch, current_step)
+    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
 
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)

View File

@@ -3057,12 +3057,20 @@ class ImageLoadingDataset(torch.utils.data.Dataset):
 
 # collate_fn用 epoch,stepはmultiprocessing.Value
 class collater_class:
-    def __init__(self, epoch, step):
+    def __init__(self, epoch, step, dataset):
         self.current_epoch = epoch
         self.current_step = step
+        self.dataset = dataset  # not used if worker_info is not None, in case of multiprocessing
 
     def __call__(self, examples):
-        dataset = torch.utils.data.get_worker_info().dataset
+        worker_info = torch.utils.data.get_worker_info()
+        # worker_info is None in the main process
+        if worker_info is not None:
+            dataset = worker_info.dataset
+        else:
+            dataset = self.dataset
+
+        # set epoch and step
         dataset.set_current_epoch(self.current_epoch.value)
         dataset.set_current_step(self.current_step.value)
         return examples[0]

View File

@@ -23,7 +23,8 @@ from library.config_util import (
     BlueprintGenerator,
 )
 import library.custom_train_functions as custom_train_functions
-from library.custom_train_functions import apply_snr_weight
+from library.custom_train_functions import apply_snr_weight
+
 
 def train(args):
     train_util.verify_training_args(args)
@@ -57,9 +58,10 @@ def train(args):
     blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
     train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
 
-    current_epoch = Value('i',0)
-    current_step = Value('i',0)
-    collater = train_util.collater_class(current_epoch,current_step)
+    current_epoch = Value("i", 0)
+    current_step = Value("i", 0)
+    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collater = train_util.collater_class(current_epoch, current_step, ds_for_collater)
 
     if args.no_token_padding:
         train_dataset_group.disable_token_padding()
@@ -161,7 +163,9 @@ def train(args):
     # 学習ステップ数を計算する
     if args.max_train_epochs is not None:
-        args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
+        args.max_train_steps = args.max_train_epochs * math.ceil(
+            len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps
+        )
         print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
 
     # データセット側にも学習ステップを送信
@@ -234,7 +238,7 @@ def train(args):
     loss_total = 0.0
     for epoch in range(num_train_epochs):
         print(f"epoch {epoch+1}/{num_train_epochs}")
-        current_epoch.value = epoch+1
+        current_epoch.value = epoch + 1
 
         # 指定したステップ数までText Encoderを学習するepoch最初の状態
         unet.train()
@@ -298,8 +302,7 @@ def train(args):
                 loss = loss * loss_weights
 
                 if args.min_snr_gamma:
-                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
-
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
 
                 loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし

View File

@@ -101,7 +101,8 @@ def train(args):
     current_epoch = Value('i',0)
     current_step = Value('i',0)
-    collater = train_util.collater_class(current_epoch,current_step)
+    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
 
     if args.debug_dataset:
         train_util.debug_dataset(train_dataset_group)

View File

@@ -186,7 +186,8 @@ def train(args):
     current_epoch = Value('i',0)
     current_step = Value('i',0)
-    collater = train_util.collater_class(current_epoch,current_step)
+    ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
+    collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
 
     # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
     if use_template: