mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
format: format
This commit is contained in:
@@ -32,6 +32,7 @@ from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler, get_sigmas
|
||||
|
||||
# Anima-specific training arguments
|
||||
|
||||
|
||||
def add_anima_training_arguments(parser: argparse.ArgumentParser):
|
||||
"""Add Anima-specific training arguments to the parser."""
|
||||
parser.add_argument(
|
||||
@@ -169,20 +170,20 @@ def get_noisy_model_input_and_timesteps(
|
||||
"""
|
||||
bs = latents.shape[0]
|
||||
|
||||
timestep_sample_method = getattr(args, 'timestep_sample_method', 'logit_normal')
|
||||
sigmoid_scale = getattr(args, 'sigmoid_scale', 1.0)
|
||||
shift = getattr(args, 'discrete_flow_shift', 1.0)
|
||||
timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal")
|
||||
sigmoid_scale = getattr(args, "sigmoid_scale", 1.0)
|
||||
shift = getattr(args, "discrete_flow_shift", 1.0)
|
||||
|
||||
if timestep_sample_method == 'logit_normal':
|
||||
if timestep_sample_method == "logit_normal":
|
||||
dist = torch.distributions.normal.Normal(0, 1)
|
||||
elif timestep_sample_method == 'uniform':
|
||||
elif timestep_sample_method == "uniform":
|
||||
dist = torch.distributions.uniform.Uniform(0, 1)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
|
||||
|
||||
t = dist.sample((bs,)).to(device)
|
||||
|
||||
if timestep_sample_method == 'logit_normal':
|
||||
if timestep_sample_method == "logit_normal":
|
||||
t = t * sigmoid_scale
|
||||
t = torch.sigmoid(t)
|
||||
|
||||
@@ -196,10 +197,10 @@ def get_noisy_model_input_and_timesteps(
|
||||
# Create noisy input: (1 - t) * latents + t * noise
|
||||
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
|
||||
|
||||
ip_noise_gamma = getattr(args, 'ip_noise_gamma', None)
|
||||
ip_noise_gamma = getattr(args, "ip_noise_gamma", None)
|
||||
if ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if getattr(args, 'ip_noise_gamma_random_strength', False):
|
||||
if getattr(args, "ip_noise_gamma_random_strength", False):
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
|
||||
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
@@ -213,6 +214,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
|
||||
# Loss weighting
|
||||
|
||||
|
||||
def compute_loss_weighting_for_anima(weighting_scheme: str, sigmas: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute loss weighting for Anima training.
|
||||
|
||||
@@ -276,15 +278,15 @@ def get_anima_param_groups(
|
||||
# Store original name for debugging
|
||||
p.original_name = name
|
||||
|
||||
if 'llm_adapter' in name:
|
||||
if "llm_adapter" in name:
|
||||
llm_adapter_params.append(p)
|
||||
elif '.self_attn' in name:
|
||||
elif ".self_attn" in name:
|
||||
self_attn_params.append(p)
|
||||
elif '.cross_attn' in name:
|
||||
elif ".cross_attn" in name:
|
||||
cross_attn_params.append(p)
|
||||
elif '.mlp' in name:
|
||||
elif ".mlp" in name:
|
||||
mlp_params.append(p)
|
||||
elif '.adaln_modulation' in name:
|
||||
elif ".adaln_modulation" in name:
|
||||
mod_params.append(p)
|
||||
else:
|
||||
base_params.append(p)
|
||||
@@ -311,9 +313,9 @@ def get_anima_param_groups(
|
||||
p.requires_grad_(False)
|
||||
logger.info(f" Frozen {name} params ({len(params)} parameters)")
|
||||
elif len(params) > 0:
|
||||
param_groups.append({'params': params, 'lr': lr})
|
||||
param_groups.append({"params": params, "lr": lr})
|
||||
|
||||
total_trainable = sum(p.numel() for group in param_groups for p in group['params'] if p.requires_grad)
|
||||
total_trainable = sum(p.numel() for group in param_groups for p in group["params"] if p.requires_grad)
|
||||
logger.info(f"Total trainable parameters: {total_trainable:,}")
|
||||
|
||||
return param_groups
|
||||
@@ -328,10 +330,9 @@ def save_anima_model_on_train_end(
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
):
|
||||
"""Save Anima model at the end of training."""
|
||||
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True
|
||||
)
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
|
||||
dit_sd = dit.state_dict()
|
||||
# Save with 'net.' prefix for ComfyUI compatibility
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
|
||||
@@ -350,10 +351,9 @@ def save_anima_model_on_epoch_end_or_stepwise(
|
||||
dit: anima_models.MiniTrainDIT,
|
||||
):
|
||||
"""Save Anima model at epoch end or specific steps."""
|
||||
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
sai_metadata = train_util.get_sai_model_spec(
|
||||
None, args, False, False, False, is_stable_diffusion_ckpt=True
|
||||
)
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, False, False, False, is_stable_diffusion_ckpt=True)
|
||||
dit_sd = dit.state_dict()
|
||||
anima_utils.save_anima_model(ckpt_file, dit_sd, save_dtype)
|
||||
|
||||
@@ -410,9 +410,7 @@ def do_sample(
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = None
|
||||
noise = torch.randn(
|
||||
latent.size(), dtype=torch.float32, generator=generator, device="cpu"
|
||||
).to(dtype).to(device)
|
||||
noise = torch.randn(latent.size(), dtype=torch.float32, generator=generator, device="cpu").to(dtype).to(device)
|
||||
|
||||
# Timestep schedule: linear from 1.0 to 0.0
|
||||
sigmas = torch.linspace(1.0, 0.0, steps + 1, device=device, dtype=dtype)
|
||||
@@ -512,10 +510,20 @@ def sample_images(
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
for prompt_dict in prompts:
|
||||
_sample_image_inference(
|
||||
accelerator, args, dit, text_encoder, vae, vae_scale,
|
||||
tokenize_strategy, text_encoding_strategy,
|
||||
save_dir, prompt_dict, epoch, steps,
|
||||
sample_prompts_te_outputs, prompt_replacement,
|
||||
accelerator,
|
||||
args,
|
||||
dit,
|
||||
text_encoder,
|
||||
vae,
|
||||
vae_scale,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
)
|
||||
|
||||
# Restore RNG state
|
||||
@@ -527,10 +535,20 @@ def sample_images(
|
||||
|
||||
|
||||
def _sample_image_inference(
|
||||
accelerator, args, dit, text_encoder, vae, vae_scale,
|
||||
tokenize_strategy, text_encoding_strategy,
|
||||
save_dir, prompt_dict, epoch, steps,
|
||||
sample_prompts_te_outputs, prompt_replacement,
|
||||
accelerator,
|
||||
args,
|
||||
dit,
|
||||
text_encoder,
|
||||
vae,
|
||||
vae_scale,
|
||||
tokenize_strategy,
|
||||
text_encoding_strategy,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
epoch,
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
):
|
||||
"""Generate a single sample image."""
|
||||
prompt = prompt_dict.get("prompt", "")
|
||||
@@ -585,7 +603,7 @@ def _sample_image_inference(
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device)
|
||||
|
||||
# Process through LLM adapter if available
|
||||
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
|
||||
if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
|
||||
crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=prompt_embeds,
|
||||
target_input_ids=t5_input_ids,
|
||||
@@ -613,7 +631,7 @@ def _sample_image_inference(
|
||||
neg_t5_ids = neg_t5_ids.to(accelerator.device, dtype=torch.long)
|
||||
neg_t5_am = neg_t5_am.to(accelerator.device)
|
||||
|
||||
if dit.use_llm_adapter and hasattr(dit, 'llm_adapter'):
|
||||
if dit.use_llm_adapter and hasattr(dit, "llm_adapter"):
|
||||
neg_crossattn_emb = dit.llm_adapter(
|
||||
source_hidden_states=neg_pe,
|
||||
target_input_ids=neg_t5_ids,
|
||||
@@ -627,9 +645,16 @@ def _sample_image_inference(
|
||||
# Generate sample
|
||||
clean_memory_on_device(accelerator.device)
|
||||
latents = do_sample(
|
||||
height, width, seed, dit, crossattn_emb,
|
||||
sample_steps, dit.t_embedding_norm.weight.dtype,
|
||||
accelerator.device, scale, neg_crossattn_emb,
|
||||
height,
|
||||
width,
|
||||
seed,
|
||||
dit,
|
||||
crossattn_emb,
|
||||
sample_steps,
|
||||
dit.t_embedding_norm.weight.dtype,
|
||||
accelerator.device,
|
||||
scale,
|
||||
neg_crossattn_emb,
|
||||
)
|
||||
|
||||
# Decode latents
|
||||
@@ -662,4 +687,5 @@ def _sample_image_inference(
|
||||
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
import wandb
|
||||
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)
|
||||
|
||||
Reference in New Issue
Block a user