From 24e3d4b4642673cebd552f1a9cbd3d99eac969a2 Mon Sep 17 00:00:00 2001 From: Jakaline-dev Date: Thu, 30 Mar 2023 02:20:04 +0900 Subject: [PATCH] disabled sampling (for now) --- gen_img_diffusers.py | 17 ++++++------- train_textual_inversion_XTI.py | 46 +++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index b562d097..cd0be71b 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -781,21 +781,19 @@ class PipelineLike: text_embeddings_concat = [] for layer in ['IN01', 'IN02', 'IN04', 'IN05', 'IN07', 'IN08', 'MID', 'OUT03', 'OUT04', 'OUT05', 'OUT06', 'OUT07', 'OUT08', 'OUT09', 'OUT10', 'OUT11']: text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - clip_skip=self.clip_skip, - layer=layer, - **kwargs, + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + clip_skip=self.clip_skip, + layer=layer, + **kwargs, ) if do_classifier_free_guidance: if negative_scale is None: text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings])) else: text_embeddings_concat.append(torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings])) - - text_embeddings = torch.stack(text_embeddings_concat) else: if do_classifier_free_guidance: @@ -803,7 +801,6 @@ class PipelineLike: text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) else: text_embeddings = torch.cat([uncond_embeddings, text_embeddings, real_uncond_embeddings]) - text_embeddings, uncond_embeddings, prompt_tokens = get_weighted_text_embeddings( pipe=self, prompt=prompt, diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 60bf4f7e..8d6ff430 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -4,6 +4,7 @@ import gc import math import os import toml +from multiprocessing import Value from tqdm import tqdm import torch @@ -17,7 +18,8 @@ from library.config_util import ( ConfigSanitizer, BlueprintGenerator, ) - +import library.custom_train_functions as custom_train_functions +from library.custom_train_functions import apply_snr_weight from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI imagenet_templates_small = [ @@ -73,10 +75,6 @@ imagenet_style_templates_small = [ ] -def collate_fn(examples): - return examples[0] - - def train(args): if args.output_name is None: args.output_name = args.token_string @@ -195,6 +193,10 @@ def train(args): blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer) train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings) + current_epoch = Value('i',0) + current_step = Value('i',0) + ds_for_collater = train_dataset_group if args.max_data_loader_n_workers == 0 else None + collater = train_util.collater_class(current_epoch,current_step, ds_for_collater) # make captions: tokenstring tokenstring1 tokenstring2 ...tokenstringn という文字列に書き換える超乱暴な実装 if use_template: @@ -207,14 +209,14 @@ def train(args): train_dataset_group.add_replacement("", captions) if args.num_vectors_per_token > 1: - prompt_replacement = [args.token_string, replace_to] + prompt_replacement = (args.token_string, replace_to) else: prompt_replacement = None else: if args.num_vectors_per_token > 1: replace_to = " ".join(token_strings) train_dataset_group.add_replacement(args.token_string, replace_to) - prompt_replacement = [args.token_string, replace_to] + prompt_replacement = (args.token_string, replace_to) else: prompt_replacement = None @@ -264,16 +266,19 @@ def train(args): train_dataset_group, batch_size=1, shuffle=True, - collate_fn=collate_fn, + collate_fn=collater, num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) # 学習ステップ数を計算する if args.max_train_epochs is not None: - args.max_train_steps = args.max_train_epochs * len(train_dataloader) + args.max_train_steps = args.max_train_epochs * math.ceil(len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps) print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") + # データセット側にも学習ステップを送信 + train_dataset_group.set_max_train_steps(args.max_train_steps) + # lr schedulerを用意する lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) @@ -345,12 +350,14 @@ def train(args): for epoch in range(num_train_epochs): print(f"epoch {epoch+1}/{num_train_epochs}") - train_dataset_group.set_current_epoch(epoch + 1) + current_epoch.value = epoch+1 text_encoder.train() loss_total = 0 + for step, batch in enumerate(train_dataloader): + current_step.value = global_step with accelerator.accumulate(text_encoder): with torch.no_grad(): if "latents" in batch and batch["latents"] is not None: @@ -391,6 +398,9 @@ def train(args): loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") loss = loss.mean([1, 2, 3]) + + if args.min_snr_gamma: + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) loss_weights = batch["loss_weights"] # 各sampleごとのweight loss = loss * loss_weights @@ -416,10 +426,10 @@ def train(args): if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 - - train_util.sample_images( - accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) + # TODO: fix sample_images + # train_util.sample_images( + # accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # ) current_loss = loss.detach().item() if args.logging_dir is not None: @@ -466,9 +476,10 @@ def train(args): if saving and args.save_state: train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1) - train_util.sample_images( - accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement - ) + # TODO: fix sample_images + # train_util.sample_images( + # accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet, prompt_replacement + # ) # end of epoch @@ -543,6 +554,7 @@ def setup_parser() -> argparse.ArgumentParser: train_util.add_training_arguments(parser, True) train_util.add_optimizer_arguments(parser) config_util.add_config_arguments(parser) + custom_train_functions.add_custom_train_arguments(parser) parser.add_argument( "--save_model_as",