mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-15 08:36:41 +00:00
fix: format
This commit is contained in:
@@ -320,10 +320,10 @@ def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
|
||||
|
||||
if isinstance(data, list):
|
||||
for item in data:
|
||||
if "target_class" in item:
|
||||
append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
|
||||
else:
|
||||
append_prompt_item(item, default_prompt_values)
|
||||
if "target_class" in item:
|
||||
append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
|
||||
else:
|
||||
append_prompt_item(item, default_prompt_values)
|
||||
elif isinstance(data, dict):
|
||||
if "prompts" in data:
|
||||
defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
|
||||
@@ -406,11 +406,11 @@ def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_p
|
||||
|
||||
def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
|
||||
return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
|
||||
|
||||
def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
|
||||
text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(
|
||||
batch_size, dim=0
|
||||
)
|
||||
pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(batch_size, dim=0)
|
||||
return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)
|
||||
|
||||
|
||||
@@ -499,7 +499,6 @@ def predict_noise_xl(
|
||||
return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
|
||||
|
||||
def diffusion_xl(
|
||||
unet,
|
||||
scheduler,
|
||||
|
||||
@@ -1106,7 +1106,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
return all(
|
||||
[
|
||||
not (
|
||||
subset.caption_dropout_rate > 0 and not cache_supports_dropout
|
||||
subset.caption_dropout_rate > 0
|
||||
and not cache_supports_dropout
|
||||
or subset.shuffle_caption
|
||||
or subset.token_warmup_step > 0
|
||||
or subset.caption_tag_dropout_rate > 0
|
||||
@@ -2056,7 +2057,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
filtered_img_paths.append(img_path)
|
||||
filtered_sizes.append(size)
|
||||
if len(filtered_img_paths) < len(img_paths):
|
||||
logger.info(f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}")
|
||||
logger.info(
|
||||
f"filtered {len(img_paths) - len(filtered_img_paths)} images by original resolution from {subset.image_dir}"
|
||||
)
|
||||
img_paths = filtered_img_paths
|
||||
sizes = filtered_sizes
|
||||
|
||||
@@ -2542,9 +2545,7 @@ class ControlNetDataset(BaseDataset):
|
||||
len(missing_imgs) == 0
|
||||
), f"missing conditioning data for {len(missing_imgs)} images / 制御用画像が見つかりませんでした: {missing_imgs}"
|
||||
if len(extra_imgs) > 0:
|
||||
logger.warning(
|
||||
f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}"
|
||||
)
|
||||
logger.warning(f"extra conditioning data for {len(extra_imgs)} images / 余分な制御用画像があります: {extra_imgs}")
|
||||
|
||||
self.conditioning_image_transforms = IMAGE_TRANSFORMS
|
||||
|
||||
|
||||
Reference in New Issue
Block a user