mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
disabled sampling (for now)
This commit is contained in:
@@ -794,8 +794,6 @@ class PipelineLike:
|
|||||||
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings]))
|
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings]))
|
||||||
else:
|
else:
|
||||||
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]))
|
text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]))
|
||||||
|
|
||||||
|
|
||||||
text_embeddings = torch.stack(text_embeddings_concat)
|
text_embeddings = torch.stack(text_embeddings_concat)
|
||||||
else:
|
else:
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
@@ -803,7 +801,6 @@ class PipelineLike:
|
|||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||||
else:
|
else:
|
||||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
|
text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])
|
||||||
|
|
||||||
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
|
text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings(
|
||||||
pipe=self,
|
pipe=self,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import gc
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import toml
|
import toml
|
||||||
|
from multiprocessing import Value
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import torch
|
import torch
|
||||||
@@ -17,7 +18,8 @@ from library.config_util import (
|
|||||||
ConfigSanitizer,
|
ConfigSanitizer,
|
||||||
BlueprintGenerator,
|
BlueprintGenerator,
|
||||||
)
|
)
|
||||||
|
import library.custom_train_functions as custom_train_functions
|
||||||
|
from library.custom_train_functions import apply_snr_weight
|
||||||
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
|
||||||
|
|
||||||
imagenet_templates_small = [
|
imagenet_templates_small = [
|
||||||
@@ -73,10 +75,6 @@ imagenet_style_templates_small = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def collate_fn(examples):
|
|
||||||
return examples[0]
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
if args.output_name is None:
|
if args.output_name is None:
|
||||||
args.output_name = args.token_string
|
args.output_name = args.token_string
|
||||||
@@ -195,6 +193,10 @@ def train(args):
|
|||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
||||||
|
current_epoch = Value('i',0)
|
||||||
|
current_step = Value('i',0)
|
||||||
|
ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None
|
||||||
|
collater = train_util.collater_class(current_epoch,current_step, ds_for_collater)
|
||||||
|
|
||||||
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
# make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装
|
||||||
if use_template:
|
if use_template:
|
||||||
@@ -207,14 +209,14 @@ def train(args):
|
|||||||
train_dataset_group.add_replacement("", captions)
|
train_dataset_group.add_replacement("", captions)
|
||||||
|
|
||||||
if args.num_vectors_per_token > 1:
|
if args.num_vectors_per_token > 1:
|
||||||
prompt_replacement = [args.token_string, replace_to]
|
prompt_replacement = (args.token_string, replace_to)
|
||||||
else:
|
else:
|
||||||
prompt_replacement = None
|
prompt_replacement = None
|
||||||
else:
|
else:
|
||||||
if args.num_vectors_per_token > 1:
|
if args.num_vectors_per_token > 1:
|
||||||
replace_to = " ".join(token_strings)
|
replace_to = " ".join(token_strings)
|
||||||
train_dataset_group.add_replacement(args.token_string, replace_to)
|
train_dataset_group.add_replacement(args.token_string, replace_to)
|
||||||
prompt_replacement = [args.token_string, replace_to]
|
prompt_replacement = (args.token_string, replace_to)
|
||||||
else:
|
else:
|
||||||
prompt_replacement = None
|
prompt_replacement = None
|
||||||
|
|
||||||
@@ -264,16 +266,19 @@ def train(args):
|
|||||||
train_dataset_group,
|
train_dataset_group,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collater,
|
||||||
num_workers=n_workers,
|
num_workers=n_workers,
|
||||||
persistent_workers=args.persistent_data_loader_workers,
|
persistent_workers=args.persistent_data_loader_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 学習ステップ数を計算する
|
# 学習ステップ数を計算する
|
||||||
if args.max_train_epochs is not None:
|
if args.max_train_epochs is not None:
|
||||||
args.max_train_steps = args.max_train_epochs * len(train_dataloader)
|
args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps)
|
||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
|
# データセット側にも学習ステップを送信
|
||||||
|
train_dataset_group.set_max_train_steps(args.max_train_steps)
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
|
||||||
|
|
||||||
@@ -345,12 +350,14 @@ def train(args):
|
|||||||
|
|
||||||
for epoch in range(num_train_epochs):
|
for epoch in range(num_train_epochs):
|
||||||
print(f"epoch {epoch+1}/{num_train_epochs}")
|
print(f"epoch {epoch+1}/{num_train_epochs}")
|
||||||
train_dataset_group.set_current_epoch(epoch + 1)
|
current_epoch.value = epoch+1
|
||||||
|
|
||||||
text_encoder.train()
|
text_encoder.train()
|
||||||
|
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
|
|
||||||
for step, batch in enumerate(train_dataloader):
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
current_step.value = global_step
|
||||||
with accelerator.accumulate(text_encoder):
|
with accelerator.accumulate(text_encoder):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if "latents" in batch and batch["latents"] is not None:
|
if "latents" in batch and batch["latents"] is not None:
|
||||||
@@ -392,6 +399,9 @@ def train(args):
|
|||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean([1, 2, 3])
|
loss = loss.mean([1, 2, 3])
|
||||||
|
|
||||||
|
if args.min_snr_gamma:
|
||||||
|
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||||
|
|
||||||
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
loss_weights = batch["loss_weights"] # 各sampleごとのweight
|
||||||
loss = loss * loss_weights
|
loss = loss * loss_weights
|
||||||
|
|
||||||
@@ -416,10 +426,10 @@ def train(args):
|
|||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
# TODO: fix sample_images
|
||||||
train_util.sample_images(
|
# train_util.sample_images(
|
||||||
accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
# accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||||
)
|
# )
|
||||||
|
|
||||||
current_loss = loss.detach().item()
|
current_loss = loss.detach().item()
|
||||||
if args.logging_dir is not None:
|
if args.logging_dir is not None:
|
||||||
@@ -466,9 +476,10 @@ def train(args):
|
|||||||
if saving and args.save_state:
|
if saving and args.save_state:
|
||||||
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)
|
||||||
|
|
||||||
train_util.sample_images(
|
# TODO: fix sample_images
|
||||||
accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
# train_util.sample_images(
|
||||||
)
|
# accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement
|
||||||
|
# )
|
||||||
|
|
||||||
# end of epoch
|
# end of epoch
|
||||||
|
|
||||||
@@ -543,6 +554,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
train_util.add_training_arguments(parser, True)
|
train_util.add_training_arguments(parser, True)
|
||||||
train_util.add_optimizer_arguments(parser)
|
train_util.add_optimizer_arguments(parser)
|
||||||
config_util.add_config_arguments(parser)
|
config_util.add_config_arguments(parser)
|
||||||
|
custom_train_functions.add_custom_train_arguments(parser)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save_model_as",
|
"--save_model_as",
|
||||||
|
|||||||
Reference in New Issue
Block a user