mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-18 01:30:02 +00:00
fix: remove unused code
@@ -4,18 +4,16 @@ import argparse
 import math
 import os
 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from safetensors.torch import save_file
-from accelerate import Accelerator, PartialState
+from accelerate import Accelerator
 from tqdm import tqdm
 from PIL import Image
 
 from library.device_utils import init_ipex, clean_memory_on_device
-from library import anima_models, anima_utils, strategy_base, train_util, qwen_image_autoencoder_kl
+from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
 
 init_ipex()
 
@@ -125,73 +123,6 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
     )
 
 
-# Noise & Timestep sampling (Rectified Flow)
-def get_noisy_model_input_and_timesteps(
-    args,
-    latents: torch.Tensor,
-    noise: torch.Tensor,
-    device: torch.device,
-    dtype: torch.dtype,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Generate noisy model input and timesteps for rectified flow training.
-
-    Rectified flow: noisy_input = (1 - t) * latents + t * noise
-    Target: noise - latents
-
-    Args:
-        args: Training arguments with timestep_sample_method, sigmoid_scale, discrete_flow_shift
-        latents: Clean latent tensors
-        noise: Random noise tensors
-        device: Target device
-        dtype: Target dtype
-
-    Returns:
-        (noisy_model_input, timesteps, sigmas)
-    """
-    bs = latents.shape[0]
-
-    timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal")
-    sigmoid_scale = getattr(args, "sigmoid_scale", 1.0)
-    shift = getattr(args, "discrete_flow_shift", 1.0)
-
-    if timestep_sample_method == "logit_normal":
-        dist = torch.distributions.normal.Normal(0, 1)
-    elif timestep_sample_method == "uniform":
-        dist = torch.distributions.uniform.Uniform(0, 1)
-    else:
-        raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
-
-    t = dist.sample((bs,)).to(device)
-
-    if timestep_sample_method == "logit_normal":
-        t = t * sigmoid_scale
-        t = torch.sigmoid(t)
-
-    # Apply shift
-    if shift is not None and shift != 1.0:
-        t = (t * shift) / (1 + (shift - 1) * t)
-
-    # Clamp to avoid exact 0 or 1
-    t = t.clamp(1e-5, 1.0 - 1e-5)
-
-    # Create noisy input: (1 - t) * latents + t * noise
-    t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
-
-    ip_noise_gamma = getattr(args, "ip_noise_gamma", None)
-    if ip_noise_gamma:
-        xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
-        if getattr(args, "ip_noise_gamma_random_strength", False):
-            ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
-        noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
-    else:
-        noisy_model_input = (1 - t_expanded) * latents + t_expanded * noise
-
-    # Sigmas for potential loss weighting
-    sigmas = t.view(-1, 1)
-
-    return noisy_model_input.to(dtype), t.to(dtype), sigmas.to(dtype)
-
-
 # Loss weighting
 
 
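For reference, the deleted helper implemented a standard rectified-flow sampling scheme: draw a timestep t (logit-normal by default), optionally remap it with the discrete flow shift, and interpolate linearly between clean latents and noise. Below is a minimal, self-contained sketch of that scheme; the function name is hypothetical, and the `args` plumbing and the `ip_noise_gamma` input-perturbation branch are omitted for brevity.

```python
import torch


def sample_rf_timesteps_and_noisy_input(
    latents: torch.Tensor,
    noise: torch.Tensor,
    sigmoid_scale: float = 1.0,
    shift: float = 1.0,
):
    """Hypothetical condensed version of the removed helper (logit-normal only)."""
    bs = latents.shape[0]

    # Logit-normal sampling: draw from N(0, 1), scale, then squash with sigmoid.
    t = torch.randn(bs, device=latents.device) * sigmoid_scale
    t = torch.sigmoid(t)

    # Discrete flow shift: t' = s*t / (1 + (s - 1)*t). For s > 1 this biases
    # samples toward t = 1, the noise end of the trajectory.
    if shift != 1.0:
        t = (t * shift) / (1 + (shift - 1) * t)

    # Clamp to avoid exact 0 or 1.
    t = t.clamp(1e-5, 1.0 - 1e-5)

    # Linear interpolation between data and noise: x_t = (1 - t)*x_0 + t*eps.
    t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
    noisy = (1 - t_expanded) * latents + t_expanded * noise
    return noisy, t
```

As a worked case of the shift map: with shift = 3 and t = 0.5, t' = (0.5 * 3) / (1 + 2 * 0.5) = 1.5 / 2.0 = 0.75, so shift values above 1 concentrate training on noisier timesteps.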
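And a sketch of how a training step would consume these outputs, following the removed docstring's `Target: noise - latents`. It reuses the sampler sketched above; the model here is a toy stand-in, not the actual network this script trains.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyVelocityNet(nn.Module):
    # Stand-in for the real diffusion transformer; predicts a velocity field.
    def __init__(self, channels: int):
        super().__init__()
        self.proj = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # A real model would condition on t; this toy ignores it.
        return self.proj(x)


model = ToyVelocityNet(channels=16)
latents = torch.randn(4, 16, 32, 32)   # clean latents (e.g. from the VAE)
noise = torch.randn_like(latents)

noisy, t = sample_rf_timesteps_and_noisy_input(latents, noise, shift=3.0)
model_pred = model(noisy, t)           # predicted velocity
target = noise - latents               # d/dt of (1 - t)*x_0 + t*eps
loss = F.mse_loss(model_pred.float(), target.float())
loss.backward()
```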