diff --git a/library/leco_train_util.py b/library/leco_train_util.py
index 987f1233..7a67c904 100644
--- a/library/leco_train_util.py
+++ b/library/leco_train_util.py
@@ -320,10 +320,10 @@ def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:

     if isinstance(data, list):
         for item in data:
-            if "target_class" in item:
-                append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
-            else:
-                append_prompt_item(item, default_prompt_values)
+            if "target_class" in item:
+                append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
+            else:
+                append_prompt_item(item, default_prompt_values)
     elif isinstance(data, dict):
         if "prompts" in data:
             defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
@@ -406,11 +406,11 @@ def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_p

 def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
     return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
+
+
 def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
     text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
-    pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(
-        batch_size, dim=0
-    )
+    pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
     return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)


@@ -499,7 +499,6 @@ def predict_noise_xl(
     return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


-
 def diffusion_xl(
     unet,
     scheduler,
diff --git a/library/train_util.py b/library/train_util.py
index 0df51852..d87a73a5 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
         return all(
             [
                 not (
-                    subset.caption_dropout_rate > 0 and not cache_supports_dropout
+                    subset.caption_dropout_rate > 0
+                    and not cache_supports_dropout
                     or subset.shuffle_caption
                     or subset.token_warmup_step > 0
                     or subset.caption_tag_dropout_rate > 0
@@ -2056,7 +2057,9 @@ class DreamBoothDataset(BaseDataset):
                     filtered_img_paths.append(img_path)
                     filtered_sizes.append(size)
             if len(filtered_img_paths) < len(img_paths):
-                logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
+                logger.info(
+                    f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
+                )
             img_paths = filtered_img_paths
             sizes = filtered_sizes

@@ -2542,9 +2545,7 @@ class ControlNetDataset(BaseDataset):
             len(missing_imgs) == 0
         ), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
         if len(extra_imgs) > 0:
-            logger.warning(
-                f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
-            )
+            logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")

         self.conditioning_image_transforms = IMAGE_TRANSFORMS

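
Note on the concat_embeddings helpers touched above: a minimal sketch of the classifier-free-guidance batching they perform, with assumed shapes (1, 77, 768) for illustration only; this sketch is not part of the patch.

import torch

# Stand-ins for text-encoder outputs; the shapes are assumptions for this sketch.
unconditional = torch.zeros(1, 77, 768)
conditional = torch.ones(1, 77, 768)
batch_size = 2

# torch.cat stacks [uncond, cond] on the batch axis; repeat_interleave then repeats
# each row batch_size times, giving [uncond, uncond, cond, cond] for batch_size=2,
# so the first half of the batch feeds the unconditional UNet pass.
embeddings = torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
assert embeddings.shape == (2 * batch_size, 77, 768)
assert torch.equal(embeddings[0], embeddings[1])  # both rows hold the unconditional embedding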