diff --git a/library/sdxl_lpw_stable_diffusion.py b/library/sdxl_lpw_stable_diffusion.py index d44b3cf8..99b0bc8d 100644 --- a/library/sdxl_lpw_stable_diffusion.py +++ b/library/sdxl_lpw_stable_diffusion.py @@ -922,7 +922,7 @@ class SdxlStableDiffusionLongPromptWeightingPipeline: if up1 is not None: uncond_pool = up1 - dtype = text_embeddings_list[0].dtype + dtype = self.unet.dtype # 4. Preprocess image and mask if isinstance(image, PIL.Image.Image): diff --git a/library/train_util.py b/library/train_util.py index 809f0af0..9438a189 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3874,127 +3874,127 @@ def sample_images_common( cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None with torch.no_grad(): - with accelerator.autocast(): - for i, prompt in enumerate(prompts): - if not accelerator.is_main_process: - continue + # with accelerator.autocast(): + for i, prompt in enumerate(prompts): + if not accelerator.is_main_process: + continue - if isinstance(prompt, dict): - negative_prompt = prompt.get("negative_prompt") - sample_steps = prompt.get("sample_steps", 30) - width = prompt.get("width", 512) - height = prompt.get("height", 512) - scale = prompt.get("scale", 7.5) - seed = prompt.get("seed") - controlnet_image = prompt.get("controlnet_image") - prompt = prompt.get("prompt") - else: - # prompt = prompt.strip() - # if len(prompt) == 0 or prompt[0] == "#": - # continue + if isinstance(prompt, dict): + negative_prompt = prompt.get("negative_prompt") + sample_steps = prompt.get("sample_steps", 30) + width = prompt.get("width", 512) + height = prompt.get("height", 512) + scale = prompt.get("scale", 7.5) + seed = prompt.get("seed") + controlnet_image = prompt.get("controlnet_image") + prompt = prompt.get("prompt") + else: + # prompt = prompt.strip() + # if len(prompt) == 0 or prompt[0] == "#": + # continue - # subset of gen_img_diffusers - prompt_args = prompt.split(" --") - prompt = prompt_args[0] - negative_prompt = None - sample_steps = 30 - width = height = 512 - scale = 7.5 - seed = None - controlnet_image = None - for parg in prompt_args: - try: - m = re.match(r"w (\d+)", parg, re.IGNORECASE) - if m: - width = int(m.group(1)) - continue - - m = re.match(r"h (\d+)", parg, re.IGNORECASE) - if m: - height = int(m.group(1)) - continue - - m = re.match(r"d (\d+)", parg, re.IGNORECASE) - if m: - seed = int(m.group(1)) - continue - - m = re.match(r"s (\d+)", parg, re.IGNORECASE) - if m: # steps - sample_steps = max(1, min(1000, int(m.group(1)))) - continue - - m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) - if m: # scale - scale = float(m.group(1)) - continue - - m = re.match(r"n (.+)", parg, re.IGNORECASE) - if m: # negative prompt - negative_prompt = m.group(1) - continue - - m = re.match(r"cn (.+)", parg, re.IGNORECASE) - if m: # negative prompt - controlnet_image = m.group(1) - continue - - except ValueError as ex: - print(f"Exception in parsing / 解析エラー: {parg}") - print(ex) - - if seed is not None: - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - - if prompt_replacement is not None: - prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) - if negative_prompt is not None: - negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) - - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - - height = max(64, height - height % 8) # round to divisible by 8 - width = max(64, width - width % 8) # round to divisible by 8 - print(f"prompt: {prompt}") - print(f"negative_prompt: {negative_prompt}") - print(f"height: {height}") - print(f"width: {width}") - print(f"sample_steps: {sample_steps}") - print(f"scale: {scale}") - image = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=sample_steps, - guidance_scale=scale, - negative_prompt=negative_prompt, - controlnet=controlnet, - controlnet_image=controlnet_image, - ).images[0] - - ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" - seed_suffix = "" if seed is None else f"_{seed}" - img_filename = ( - f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" - ) - - image.save(os.path.join(save_dir, img_filename)) - - # wandb有効時のみログを送信 - try: - wandb_tracker = accelerator.get_tracker("wandb") + # subset of gen_img_diffusers + prompt_args = prompt.split(" --") + prompt = prompt_args[0] + negative_prompt = None + sample_steps = 30 + width = height = 512 + scale = 7.5 + seed = None + controlnet_image = None + for parg in prompt_args: try: - import wandb - except ImportError: # 事前に一度確認するのでここはエラー出ないはず - raise ImportError("No wandb / wandb がインストールされていないようです") + m = re.match(r"w (\d+)", parg, re.IGNORECASE) + if m: + width = int(m.group(1)) + continue - wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) - except: # wandb 無効時 - pass + m = re.match(r"h (\d+)", parg, re.IGNORECASE) + if m: + height = int(m.group(1)) + continue + + m = re.match(r"d (\d+)", parg, re.IGNORECASE) + if m: + seed = int(m.group(1)) + continue + + m = re.match(r"s (\d+)", parg, re.IGNORECASE) + if m: # steps + sample_steps = max(1, min(1000, int(m.group(1)))) + continue + + m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE) + if m: # scale + scale = float(m.group(1)) + continue + + m = re.match(r"n (.+)", parg, re.IGNORECASE) + if m: # negative prompt + negative_prompt = m.group(1) + continue + + m = re.match(r"cn (.+)", parg, re.IGNORECASE) + if m: # negative prompt + controlnet_image = m.group(1) + continue + + except ValueError as ex: + print(f"Exception in parsing / 解析エラー: {parg}") + print(ex) + + if seed is not None: + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if prompt_replacement is not None: + prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1]) + if negative_prompt is not None: + negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1]) + + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + + height = max(64, height - height % 8) # round to divisible by 8 + width = max(64, width - width % 8) # round to divisible by 8 + print(f"prompt: {prompt}") + print(f"negative_prompt: {negative_prompt}") + print(f"height: {height}") + print(f"width: {width}") + print(f"sample_steps: {sample_steps}") + print(f"scale: {scale}") + image = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=sample_steps, + guidance_scale=scale, + negative_prompt=negative_prompt, + controlnet=controlnet, + controlnet_image=controlnet_image, + ).images[0] + + ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) + num_suffix = f"e{epoch:06d}" if epoch is not None else f"{steps:06d}" + seed_suffix = "" if seed is None else f"_{seed}" + img_filename = ( + f"{'' if args.output_name is None else args.output_name + '_'}{ts_str}_{num_suffix}_{i:02d}{seed_suffix}.png" + ) + + image.save(os.path.join(save_dir, img_filename)) + + # wandb有効時のみログを送信 + try: + wandb_tracker = accelerator.get_tracker("wandb") + try: + import wandb + except ImportError: # 事前に一度確認するのでここはエラー出ないはず + raise ImportError("No wandb / wandb がインストールされていないようです") + + wandb_tracker.log({f"sample_{i}": wandb.Image(image)}) + except: # wandb 無効時 + pass # clear pipeline and cache to reduce vram usage del pipeline