diff --git a/library/train_util.py b/library/train_util.py index 2b051e1f..d2eb7cb2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1722,6 +1722,7 @@ class ControlNetDataset(BaseDataset): subset.caption_extension, subset.num_repeats, subset.shuffle_caption, + subset.caption_separator, subset.keep_tokens, subset.color_aug, subset.flip_aug, @@ -2979,9 +2980,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: parser.add_argument( "--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する" ) - parser.add_argument( - "--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する" - ) + parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する") parser.add_argument( "--sample_every_n_epochs", type=int, @@ -3115,12 +3114,8 @@ def add_dataset_arguments( ): # dataset common parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ") - parser.add_argument( - "--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする" - ) - parser.add_argument( - "--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字" - ) + parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする") + parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字") parser.add_argument( "--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子" ) @@ -4048,6 +4043,7 @@ def get_hidden_states_sdxl( text_encoder1: CLIPTextModel, text_encoder2: CLIPTextModelWithProjection, weight_dtype: Optional[str] = None, + accelerator: Optional[Accelerator] = None, ): # input_ids: b,n,77 -> b*n, 77 b_size = input_ids1.size()[0] @@ -4063,7 +4059,8 @@ def get_hidden_states_sdxl( hidden_states2 = enc_out["hidden_states"][-2] # penuultimate layer # pool2 = enc_out["text_embeds"] - pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) + unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2) + pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id) # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280 n_size = 1 if max_token_length is None else max_token_length // 75 @@ -4451,6 +4448,7 @@ SCHEDULER_LINEAR_END = 0.0120 SCHEDULER_TIMESTEPS = 1000 SCHEDLER_SCHEDULE = "scaled_linear" + def get_my_scheduler( *, sample_sampler: str, @@ -4495,10 +4493,7 @@ def get_my_scheduler( ) # clip_sample=Trueにする - if ( - hasattr(scheduler.config, "clip_sample") - and scheduler.config.clip_sample is False - ): + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False: # print("set clip_sample to True") scheduler.config.clip_sample = True @@ -4513,48 +4508,48 @@ def line_to_prompt_dict(line: str) -> dict: # subset of gen_img_diffusers prompt_args = line.split(" --") prompt_dict = {} - prompt_dict['prompt'] = prompt_args[0] + prompt_dict["prompt"] = prompt_args[0] for parg in prompt_args: try: m = re.match(r"w (\d+)", parg, re.IGNORECASE) if m: - prompt_dict['width'] = int(m.group(1)) + prompt_dict["width"] = int(m.group(1)) continue m = re.match(r"h (\d+)", parg, re.IGNORECASE) if m: - prompt_dict['height'] = int(m.group(1)) + prompt_dict["height"] = int(m.group(1)) continue m = re.match(r"d (\d+)", parg, re.IGNORECASE) if m: - prompt_dict['seed'] = int(m.group(1)) + prompt_dict["seed"] = int(m.group(1)) continue m = re.match(r"s (\d+)", parg, re.IGNORECASE) if m: # steps - prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1)))) + prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1)))) continue m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) if m: # scale - prompt_dict['scale'] = float(m.group(1)) + prompt_dict["scale"] = float(m.group(1)) continue m = re.match(r"n (.+)", parg, re.IGNORECASE) if m: # negative prompt - prompt_dict['negative_prompt'] = m.group(1) + prompt_dict["negative_prompt"] = m.group(1) continue m = re.match(r"ss (.+)", parg, re.IGNORECASE) if m: - prompt_dict['sample_sampler'] = m.group(1) + prompt_dict["sample_sampler"] = m.group(1) continue m = re.match(r"cn (.+)", parg, re.IGNORECASE) if m: - prompt_dict['controlnet_image'] = m.group(1) + prompt_dict["controlnet_image"] = m.group(1) continue except ValueError as ex: @@ -4563,6 +4558,7 @@ def line_to_prompt_dict(line: str) -> dict: return prompt_dict + def sample_images_common( pipe_class, accelerator, @@ -4663,7 +4659,7 @@ def sample_images_common( seed = prompt_dict.get("seed") controlnet_image = prompt_dict.get("controlnet_image") prompt: str = prompt_dict.get("prompt", "") - sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler) + sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler) if seed is not None: torch.manual_seed(seed) @@ -4671,7 +4667,10 @@ def sample_images_common( scheduler = schedulers.get(sampler_name) if scheduler is None: - scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,) + scheduler = get_my_scheduler( + sample_sampler=sampler_name, + v_parameterization=args.v_parameterization, + ) schedulers[sampler_name] = scheduler pipeline.scheduler = scheduler diff --git a/sdxl_train.py b/sdxl_train.py index 05ad0878..501eef65 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -505,6 +505,7 @@ def train(args): # else: input_ids1 = input_ids1.to(accelerator.device) input_ids2 = input_ids2.to(accelerator.device) + # unwrap_model is fine for models not wrapped by accelerator encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl( args.max_token_length, input_ids1, @@ -514,6 +515,7 @@ def train(args): text_encoder1, text_encoder2, None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, ) else: encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) diff --git a/sdxl_train_network.py b/sdxl_train_network.py index 199c4e03..a35779d0 100644 --- a/sdxl_train_network.py +++ b/sdxl_train_network.py @@ -1,9 +1,12 @@ import argparse import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -123,6 +126,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer): text_encoders[0], text_encoders[1], None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, ) else: encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype) diff --git a/sdxl_train_textual_inversion.py b/sdxl_train_textual_inversion.py index f5cca17b..f8a1d7bc 100644 --- a/sdxl_train_textual_inversion.py +++ b/sdxl_train_textual_inversion.py @@ -64,6 +64,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine text_encoders[0], text_encoders[1], None if not args.full_fp16 else weight_dtype, + accelerator=accelerator, ) return encoder_hidden_states1, encoder_hidden_states2, pool2