fix: format

This commit is contained in:
Kohya S
2026-03-28 19:28:24 +09:00
parent 4ea6032c66
commit 61d705f0b9
2 changed files with 13 additions and 13 deletions

View File

@@ -320,10 +320,10 @@ def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
if isinstance(data, list):
for item in data:
if "target_class" in item:
append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
else:
append_prompt_item(item, default_prompt_values)
if "target_class" in item:
append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
else:
append_prompt_item(item, default_prompt_values)
elif isinstance(data, dict):
if "prompts" in data:
defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
@@ -406,11 +406,11 @@ def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_p
def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Stack unconditional embeddings above conditional ones, repeating each block per batch item.

    The two inputs are concatenated along dim 0 and every row of the result is
    repeated ``batch_size`` times (via ``repeat_interleave``), yielding the
    [uncond..., cond...] layout expected by classifier-free guidance.
    """
    stacked = torch.cat([unconditional, conditional], dim=0)
    return stacked.repeat_interleave(batch_size, dim=0)
def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
    """Concatenate SDXL prompt embeddings for classifier-free guidance.

    Both the ``text_embeds`` and ``pooled_embeds`` tensors of the two inputs are
    concatenated along dim 0 (unconditional first) and each row is repeated
    ``batch_size`` times, mirroring ``concat_embeddings`` for the XL dual-embedding case.

    Fix: the original body computed ``pooled_embeds`` twice with identical
    arguments (a leftover duplicate statement); the redundant first computation
    is removed — behavior is unchanged, one wasted tensor op dropped.
    """
    text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
    pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
    return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
@@ -499,7 +499,6 @@ def predict_noise_xl(
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
def diffusion_xl(
unet,
scheduler,

View File

@@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
return all(
[
not (
subset.caption_dropout_rate > 0 and not cache_supports_dropout
subset.caption_dropout_rate > 0
and not cache_supports_dropout
or subset.shuffle_caption
or subset.token_warmup_step > 0
or subset.caption_tag_dropout_rate > 0
@@ -2056,7 +2057,9 @@ class DreamBoothDataset(BaseDataset):
filtered_img_paths.append(img_path)
filtered_sizes.append(size)
if len(filtered_img_paths) < len(img_paths):
logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
logger.info(
f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
)
img_paths = filtered_img_paths
sizes = filtered_sizes
@@ -2542,9 +2545,7 @@ class ControlNetDataset(BaseDataset):
len(missing_imgs) == 0
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
if len(extra_imgs) > 0:
logger.warning(
f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
)
logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")
self.conditioning_image_transforms = IMAGE_TRANSFORMS