import math
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import yaml

from library import sdxl_train_util

# A resolution is either a single square size or an explicit (height, width) pair.
ResolutionValue = Union[int, Tuple[int, int]]


@dataclass
class PromptEmbedsXL:
    """SDXL prompt embeddings: per-token hidden states plus the pooled vector.

    ``text_embeds`` holds the two text encoders' hidden states concatenated on
    the feature axis (see :func:`encode_prompt_sdxl`); ``pooled_embeds`` is the
    second encoder's pooled output.
    """

    text_embeds: torch.Tensor
    pooled_embeds: torch.Tensor


class PromptEmbedsCache:
    """Tiny dict-backed cache mapping a prompt string to its embeddings."""

    def __init__(self) -> None:
        self.prompts: dict[str, Any] = {}

    def __setitem__(self, name: str, value: Any) -> None:
        self.prompts[name] = value

    def __getitem__(self, name: str) -> Optional[Any]:
        # Missing entries yield None rather than raising KeyError, so callers
        # can test `cache[prompt] is None` to decide whether to encode.
        return self.prompts.get(name)


@dataclass
class PromptSettings:
    """One training prompt with its sampling options.

    ``action`` must be ``"erase"`` or ``"enhance"``; it selects the sign used
    by :meth:`build_target`. ``positive`` defaults to ``target`` and
    ``neutral`` defaults to ``unconditional`` when omitted.
    """

    target: str
    positive: Optional[str] = None
    unconditional: str = ""
    neutral: Optional[str] = None
    action: str = "erase"
    guidance_scale: float = 1.0
    resolution: ResolutionValue = 512
    dynamic_resolution: bool = False
    batch_size: int = 1
    dynamic_crops: bool = False
    multiplier: float = 1.0
    weight: float = 1.0

    def __post_init__(self):
        """Fill optional prompts, validate ``action``, and coerce field types.

        Raises:
            ValueError: if ``action`` is not "erase"/"enhance", or if
                ``resolution`` cannot be normalized.
        """
        if self.positive is None:
            self.positive = self.target
        if self.neutral is None:
            self.neutral = self.unconditional
        if self.action not in ("erase", "enhance"):
            raise ValueError(f"Invalid action: {self.action}")
        # YAML values may arrive as strings/ints; coerce to canonical types.
        self.guidance_scale = float(self.guidance_scale)
        self.batch_size = int(self.batch_size)
        self.multiplier = float(self.multiplier)
        self.weight = float(self.weight)
        self.dynamic_resolution = bool(self.dynamic_resolution)
        self.dynamic_crops = bool(self.dynamic_crops)
        self.resolution = normalize_resolution(self.resolution)

    def get_resolution(self) -> Tuple[int, int]:
        """Return the resolution as a (height, width) pair."""
        if isinstance(self.resolution, tuple):
            return self.resolution
        return (self.resolution, self.resolution)

    def build_target(self, positive_latents, neutral_latents, unconditional_latents):
        """Combine latents into the training target for this prompt.

        The guided offset is ``guidance_scale * (positive - unconditional)``;
        "erase" subtracts it from the neutral latents, "enhance" adds it.
        """
        offset = self.guidance_scale * (positive_latents - unconditional_latents)
        if self.action == "erase":
            return neutral_latents - offset
        return neutral_latents + offset


def normalize_resolution(value: Any) -> ResolutionValue:
    """Coerce a configuration resolution value to an int or (int, int) tuple.

    Args:
        value: an int/float (square size), or a 2-item tuple/list of numbers.

    Raises:
        ValueError: for tuples not of length 2, or lists that are not exactly
            two numeric items.
    """
    if isinstance(value, tuple):
        if len(value) != 2:
            raise ValueError(f"resolution tuple must have 2 items: {value}")
        return (int(value[0]), int(value[1]))
    if isinstance(value, list):
        if len(value) == 2 and all(isinstance(v, (int, float)) for v in value):
            return (int(value[0]), int(value[1]))
        raise ValueError(f"resolution list must have 2 numeric items: {value}")
    return int(value)


def _read_non_empty_lines(path: Union[str, Path]) -> List[str]:
    """Read a UTF-8 text file and return its stripped, non-blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        # Iterate the file object directly; no need to materialize readlines().
        return [stripped for line in f if (stripped := line.strip())]


def _recognized_prompt_keys() -> set[str]:
    """Config keys accepted for plain :class:`PromptSettings` entries."""
    return {
        "target",
        "positive",
        "unconditional",
        "neutral",
        "action",
        "guidance_scale",
        "resolution",
        "dynamic_resolution",
        "batch_size",
        "dynamic_crops",
        "multiplier",
        "weight",
    }


def _recognized_slider_keys() -> set[str]:
    """Config keys accepted for slider-target entries."""
    return {
        "target_class",
        "positive",
        "negative",
        "neutral",
        "guidance_scale",
        "resolution",
        "resolutions",
        "dynamic_resolution",
        "batch_size",
        "dynamic_crops",
        "multiplier",
        "weight",
    }


def _merge_known_defaults(defaults: dict[str, Any], item: dict[str, Any], known_keys: Iterable[str]) -> dict[str, Any]:
    """Overlay ``item`` on top of the subset of ``defaults`` in ``known_keys``."""
    merged = {k: v for k, v in defaults.items() if k in known_keys}
    merged.update(item)
    return merged


def _normalize_resolution_values(value: Any) -> List[ResolutionValue]:
    """Normalize a resolution (or list of resolutions) to a list.

    ``None`` falls back to ``[512]``; a list of pairs is normalized item by
    item; any other value is treated as a single resolution.
    """
    if value is None:
        return [512]
    if isinstance(value, list) and value and isinstance(value[0], (list, tuple)):
        return [normalize_resolution(v) for v in value]
    return [normalize_resolution(value)]


def _expand_slider_target(target: dict[str, Any], neutral: str) -> List[PromptSettings]:
    """Expand one slider target dict into its paired :class:`PromptSettings`.

    Each resolution yields paired +multiplier/-multiplier settings so the
    slider is trained symmetrically in both directions:

    * both prompts present  -> 4 settings (erase/enhance in each direction),
    * negative only         -> erase(+m) / enhance(-m) against empty uncond,
    * positive only         -> enhance(+m) / erase(-m) against empty uncond.

    Raises:
        ValueError: when neither a positive nor a negative prompt is given.
    """
    target_class = str(target.get("target_class", ""))
    positive = str(target.get("positive", "") or "")
    negative = str(target.get("negative", "") or "")
    guidance_scale = target.get("guidance_scale", 1.0)
    dynamic_resolution = target.get("dynamic_resolution", False)
    batch_size = target.get("batch_size", 1)
    dynamic_crops = target.get("dynamic_crops", False)
    multiplier = target.get("multiplier", 1.0)
    weight = target.get("weight", 1.0)
    resolutions = _normalize_resolution_values(target.get("resolutions", target.get("resolution", 512)))

    if not positive.strip() and not negative.strip():
        raise ValueError("slider target requires either positive or negative prompt")

    # (positive, unconditional, action, multiplier) per generated setting,
    # in the same order the settings were originally emitted.
    if positive.strip() and negative.strip():
        variants = [
            (negative, positive, "erase", multiplier),
            (positive, negative, "enhance", multiplier),
            (positive, negative, "erase", -multiplier),
            (negative, positive, "enhance", -multiplier),
        ]
    elif negative.strip():
        variants = [
            (negative, "", "erase", multiplier),
            (negative, "", "enhance", -multiplier),
        ]
    else:
        variants = [
            (positive, "", "enhance", multiplier),
            (positive, "", "erase", -multiplier),
        ]

    prompt_settings: List[PromptSettings] = []
    for resolution in resolutions:
        for pos, uncond, action, mult in variants:
            prompt_settings.append(
                PromptSettings(
                    target=target_class,
                    positive=pos,
                    unconditional=uncond,
                    neutral=neutral,
                    action=action,
                    guidance_scale=guidance_scale,
                    resolution=resolution,
                    dynamic_resolution=dynamic_resolution,
                    batch_size=batch_size,
                    dynamic_crops=dynamic_crops,
                    multiplier=mult,
                    weight=weight,
                )
            )
    return prompt_settings


def load_prompt_settings(path: Union[str, Path]) -> List[PromptSettings]:
    """Load a YAML prompt file into a list of :class:`PromptSettings`.

    Supported layouts:

    * a list of prompt dicts (slider targets are detected via "target_class"),
    * a mapping with a "prompts" list plus shared defaults,
    * a slider config (optionally under a "slider" key) with "targets",
      or a single inline target.

    Neutral prompts for sliders may come from "neutrals", a
    "neutral_prompt_file"/"prompt_file" (resolved relative to the prompt
    file's directory), or a single "neutral" value; per-target values
    override the config-level ones.

    Raises:
        ValueError: for empty files, empty target lists, unknown layouts,
            or when no settings could be produced.
    """
    path = Path(path)
    with open(path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    if data is None:
        raise ValueError("prompt file is empty")

    default_prompt_values = {
        "guidance_scale": 1.0,
        "resolution": 512,
        "dynamic_resolution": False,
        "batch_size": 1,
        "dynamic_crops": False,
        "multiplier": 1.0,
        "weight": 1.0,
    }
    prompt_settings: List[PromptSettings] = []

    def append_prompt_item(item: dict[str, Any], defaults: dict[str, Any]) -> None:
        # Plain prompt entry: defaults filtered to recognized keys, then item wins.
        merged = _merge_known_defaults(defaults, item, _recognized_prompt_keys())
        prompt_settings.append(PromptSettings(**merged))

    def append_slider_item(item: dict[str, Any], defaults: dict[str, Any], neutral_values: Sequence[str]) -> None:
        # Slider entry: expand once per neutral prompt.
        merged = _merge_known_defaults(defaults, item, _recognized_slider_keys())
        if not neutral_values:
            neutral_values = [str(merged.get("neutral", "") or "")]
        for neutral in neutral_values:
            prompt_settings.extend(_expand_slider_target(merged, neutral))

    if isinstance(data, list):
        for item in data:
            if "target_class" in item:
                append_slider_item(item, default_prompt_values, [str(item.get("neutral", "") or "")])
            else:
                append_prompt_item(item, default_prompt_values)
    elif isinstance(data, dict):
        if "prompts" in data:
            # Top-level keys act as defaults shared by every prompt entry.
            defaults = {**default_prompt_values, **{k: v for k, v in data.items() if k in _recognized_prompt_keys()}}
            for item in data["prompts"]:
                append_prompt_item(item, defaults)
        else:
            slider_config = data.get("slider", data)
            targets = slider_config.get("targets")
            if targets is None:
                # A single inline target (either slider-style or prompt-style).
                if "target_class" in slider_config or "target" in slider_config:
                    targets = [slider_config]
                else:
                    raise ValueError("prompt file does not contain prompts or slider targets")
            if not targets:
                raise ValueError("prompt file contains an empty targets list")
            if "target" in targets[0]:
                # Prompt-style targets share the config-level prompt defaults.
                defaults = {
                    **default_prompt_values,
                    **{k: v for k, v in slider_config.items() if k in _recognized_prompt_keys()},
                }
                for item in targets:
                    append_prompt_item(item, defaults)
            else:
                defaults = {
                    **default_prompt_values,
                    **{k: v for k, v in slider_config.items() if k in _recognized_slider_keys()},
                }
                # Config-level neutral prompts; files are relative to the prompt file.
                neutral_values: List[str] = []
                if "neutrals" in slider_config:
                    neutral_values.extend(str(v) for v in slider_config["neutrals"])
                if "neutral_prompt_file" in slider_config:
                    neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["neutral_prompt_file"]))
                if "prompt_file" in slider_config:
                    neutral_values.extend(_read_non_empty_lines(path.parent / slider_config["prompt_file"]))
                if not neutral_values:
                    neutral_values = [str(slider_config.get("neutral", "") or "")]
                for item in targets:
                    # Per-target neutral sources take precedence over config-level.
                    item_neutrals = neutral_values
                    if "neutrals" in item:
                        item_neutrals = [str(v) for v in item["neutrals"]]
                    elif "neutral_prompt_file" in item:
                        item_neutrals = _read_non_empty_lines(path.parent / item["neutral_prompt_file"])
                    elif "prompt_file" in item:
                        item_neutrals = _read_non_empty_lines(path.parent / item["prompt_file"])
                    elif "neutral" in item:
                        item_neutrals = [str(item["neutral"] or "")]
                    append_slider_item(item, defaults, item_neutrals)
    else:
        raise ValueError("prompt file must be a list or mapping")

    if not prompt_settings:
        raise ValueError("no prompt settings found")
    return prompt_settings


def encode_prompt_sd(tokenize_strategy, text_encoding_strategy, text_encoder, prompt: str) -> torch.Tensor:
    """Encode a prompt with a single (SD v1/v2-style) text encoder.

    ``tokenize_strategy``/``text_encoding_strategy`` are project strategy
    objects; returns the first (only) encoder's hidden states.
    """
    tokens = tokenize_strategy.tokenize(prompt)
    return text_encoding_strategy.encode_tokens(tokenize_strategy, [text_encoder], tokens)[0]


def encode_prompt_sdxl(tokenize_strategy, text_encoding_strategy, text_encoders, prompt: str) -> PromptEmbedsXL:
    """Encode a prompt with both SDXL text encoders.

    Hidden states of the two encoders are concatenated on the feature axis
    (dim=2); the second encoder's pooled output is kept separately.
    """
    tokens = tokenize_strategy.tokenize(prompt)
    hidden1, hidden2, pool2 = text_encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens)
    return PromptEmbedsXL(torch.cat([hidden1, hidden2], dim=2), pool2)


def apply_noise_offset(latents: torch.Tensor, noise_offset: Optional[float]) -> torch.Tensor:
    """Add per-(sample, channel) random offset noise; no-op when offset is None."""
    if noise_offset is None:
        return latents
    return latents + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)


def get_initial_latents(scheduler, batch_size: int, height: int, width: int, n_prompts: int = 1) -> torch.Tensor:
    """Sample CPU starting latents scaled by the scheduler's initial sigma.

    Shape is (batch_size * n_prompts, 4, height // 8, width // 8); the base
    noise is tiled ``n_prompts`` times along the batch axis.
    """
    noise = torch.randn(
        (batch_size, 4, height // 8, width // 8),
        device="cpu",
    ).repeat(n_prompts, 1, 1, 1)
    return noise * scheduler.init_noise_sigma


def concat_embeddings(unconditional: torch.Tensor, conditional: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Stack [uncond; cond] embeddings and repeat each row ``batch_size`` times."""
    return torch.cat([unconditional, conditional], dim=0).repeat_interleave(batch_size, dim=0)


def concat_embeddings_xl(unconditional: PromptEmbedsXL, conditional: PromptEmbedsXL, batch_size: int) -> PromptEmbedsXL:
    """SDXL variant of :func:`concat_embeddings` covering both embedding tensors."""
    text_embeds = torch.cat([unconditional.text_embeds, conditional.text_embeds], dim=0).repeat_interleave(batch_size, dim=0)
    pooled_embeds = torch.cat([unconditional.pooled_embeds, conditional.pooled_embeds], dim=0).repeat_interleave(
        batch_size, dim=0
    )
    return PromptEmbedsXL(text_embeds=text_embeds, pooled_embeds=pooled_embeds)


def predict_noise(unet, scheduler, timestep, latents: torch.Tensor, text_embeddings: torch.Tensor, guidance_scale: float = 1.0):
    """One classifier-free-guidance noise prediction for SD.

    ``text_embeddings`` must hold the unconditional half first, then the
    conditional half (as produced by :func:`concat_embeddings`).
    """
    latent_model_input = torch.cat([latents] * 2)  # duplicate for uncond/cond passes
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
    noise_pred = unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


def diffusion(
    unet,
    scheduler,
    latents: torch.Tensor,
    text_embeddings: torch.Tensor,
    total_timesteps: int,
    start_timesteps: int = 0,
    guidance_scale: float = 3.0,
):
    """Run the SD denoising loop over scheduler timesteps [start, total)."""
    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        noise_pred = predict_noise(unet, scheduler, timestep, latents, text_embeddings, guidance_scale=guidance_scale)
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents


def get_add_time_ids(
    height: int,
    width: int,
    dynamic_crops: bool = False,
    dtype: torch.dtype = torch.float32,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build the SDXL additional time-ids tensor (original size, crop, target size).

    With ``dynamic_crops``, the original size is scaled by a random factor in
    [1, 3) and a random crop offset inside it is chosen; otherwise the image
    is treated as uncropped.
    """
    if dynamic_crops:
        random_scale = torch.rand(1).item() * 2 + 1
        original_size = (int(height * random_scale), int(width * random_scale))
        crops_coords_top_left = (
            torch.randint(0, max(original_size[0] - height, 1), (1,)).item(),
            torch.randint(0, max(original_size[1] - width, 1), (1,)).item(),
        )
        target_size = (height, width)
    else:
        original_size = (height, width)
        crops_coords_top_left = (0, 0)
        target_size = (height, width)
    add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=dtype)
    if device is not None:
        add_time_ids = add_time_ids.to(device)
    return add_time_ids


def predict_noise_xl(
    unet,
    scheduler,
    timestep,
    latents: torch.Tensor,
    prompt_embeds: PromptEmbedsXL,
    add_time_ids: torch.Tensor,
    guidance_scale: float = 1.0,
):
    """One classifier-free-guidance noise prediction for SDXL.

    ``add_time_ids`` columns are (orig_h, orig_w, crop_top, crop_left,
    target_h, target_w) as produced by :func:`get_add_time_ids`; size
    embeddings come from the project's ``sdxl_train_util``.
    """
    latent_model_input = torch.cat([latents] * 2)  # duplicate for uncond/cond passes
    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
    orig_size = add_time_ids[:, :2]
    crop_size = add_time_ids[:, 2:4]
    target_size = add_time_ids[:, 4:6]
    size_embeddings = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, latent_model_input.device)
    vector_embedding = torch.cat([prompt_embeds.pooled_embeds, size_embeddings.to(prompt_embeds.pooled_embeds.dtype)], dim=1)
    noise_pred = unet(latent_model_input, timestep, prompt_embeds.text_embeds, vector_embedding)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


def diffusion_xl(
    unet,
    scheduler,
    latents: torch.Tensor,
    prompt_embeds: PromptEmbedsXL,
    add_time_ids: torch.Tensor,
    total_timesteps: int,
    start_timesteps: int = 0,
    guidance_scale: float = 3.0,
):
    """Run the SDXL denoising loop over scheduler timesteps [start, total)."""
    for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
        noise_pred = predict_noise_xl(
            unet,
            scheduler,
            timestep,
            latents,
            prompt_embeds=prompt_embeds,
            add_time_ids=add_time_ids,
            guidance_scale=guidance_scale,
        )
        latents = scheduler.step(noise_pred, timestep, latents).prev_sample
    return latents


def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> Tuple[int, int]:
    """Pick a random (height, width) in 64-pixel steps.

    Each side is drawn independently from [bucket_resolution // 2,
    bucket_resolution], rounded to multiples of 64.
    """
    max_resolution = bucket_resolution
    min_resolution = bucket_resolution // 2
    step = 64
    min_step = min_resolution // step
    max_step = max_resolution // step
    height = torch.randint(min_step, max_step + 1, (1,)).item() * step
    width = torch.randint(min_step, max_step + 1, (1,)).item() * step
    return height, width


def get_random_resolution(prompt: PromptSettings) -> Tuple[int, int]:
    """Return the prompt's resolution, randomized when dynamic and square."""
    height, width = prompt.get_resolution()
    if prompt.dynamic_resolution and height == width:
        return get_random_resolution_in_bucket(height)
    return height, width