Merge branch 'dev' into min-SNR

This commit is contained in:
Kohya S
2023-03-26 17:10:53 +09:00
committed by GitHub
7 changed files with 2713 additions and 2347 deletions

View File

@@ -4,6 +4,7 @@ import gc
import math
import os
import toml
from multiprocessing import Value
from tqdm import tqdm
import torch
@@ -73,10 +74,6 @@ imagenet_style_templates_small = [
]
def collate_fn(examples):
    """Return the first element of ``examples`` unchanged.

    Used as the ``collate_fn`` of a DataLoader that is created with
    ``batch_size=1``, so ``examples`` is expected to be a one-element
    sequence; only its first item is passed through.

    Raises:
        IndexError: if ``examples`` is an empty sequence.
    """
    return examples[0]
def train(args):
if args.output_name is None:
args.output_name = args.token_string
@@ -187,6 +184,10 @@ def train(args):
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
current_epoch = Value('i',0)
current_step = Value('i',0)
collater = train_util.collater_class(current_epoch,current_step)
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
if use_template:
print("use template for training captions. is object: {args.use_object_template}")
@@ -252,7 +253,7 @@ def train(args):
train_dataset_group,
batch_size=1,
shuffle=True,
collate_fn=collate_fn,
collate_fn=collater,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
@@ -262,6 +263,9 @@ def train(args):
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# データセット側にも学習ステップを送信
train_dataset_group.set_max_train_steps(args.max_train_steps)
# lr schedulerを用意する
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
@@ -333,12 +337,14 @@ def train(args):
for epoch in range(num_train_epochs):
print(f"epoch {epoch+1}/{num_train_epochs}")
train_dataset_group.set_current_epoch(epoch + 1)
current_epoch.value = epoch+1
text_encoder.train()
loss_total = 0
for step, batch in enumerate(train_dataloader):
current_step.value = global_step
with accelerator.accumulate(text_encoder):
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None: