fix: remove unused code

This commit is contained in:
Kohya S
2026-02-11 17:23:58 +09:00
parent b67cc5a457
commit 9349c91c89
6 changed files with 105 additions and 958 deletions

View File

@@ -4,18 +4,16 @@ import argparse
import math
import os
import time
from typing import Dict, List, Optional, Tuple, Union
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from safetensors.torch import save_file
from accelerate import Accelerator, PartialState
from accelerate import Accelerator
from tqdm import tqdm
from PIL import Image
from library.device_utils import init_ipex, clean_memory_on_device
from library import anima_models, anima_utils, strategy_base, train_util, qwen_image_autoencoder_kl
from library import anima_models, anima_utils, train_util, qwen_image_autoencoder_kl
init_ipex()
@@ -125,73 +123,6 @@ def add_anima_training_arguments(parser: argparse.ArgumentParser):
)
# Noise & Timestep sampling (Rectified Flow)
def get_noisy_model_input_and_timesteps(
args,
latents: torch.Tensor,
noise: torch.Tensor,
device: torch.device,
dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate noisy model input and timesteps for rectified flow training.
Rectified flow: noisy_input = (1 - t) * latents + t * noise
Target: noise - latents
Args:
args: Training arguments with timestep_sample_method, sigmoid_scale, discrete_flow_shift
latents: Clean latent tensors
noise: Random noise tensors
device: Target device
dtype: Target dtype
Returns:
(noisy_model_input, timesteps, sigmas)
"""
bs = latents.shape[0]
timestep_sample_method = getattr(args, "timestep_sample_method", "logit_normal")
sigmoid_scale = getattr(args, "sigmoid_scale", 1.0)
shift = getattr(args, "discrete_flow_shift", 1.0)
if timestep_sample_method == "logit_normal":
dist = torch.distributions.normal.Normal(0, 1)
elif timestep_sample_method == "uniform":
dist = torch.distributions.uniform.Uniform(0, 1)
else:
raise NotImplementedError(f"Unknown timestep_sample_method: {timestep_sample_method}")
t = dist.sample((bs,)).to(device)
if timestep_sample_method == "logit_normal":
t = t * sigmoid_scale
t = torch.sigmoid(t)
# Apply shift
if shift is not None and shift != 1.0:
t = (t * shift) / (1 + (shift - 1) * t)
# Clamp to avoid exact 0 or 1
t = t.clamp(1e-5, 1.0 - 1e-5)
# Create noisy input: (1 - t) * latents + t * noise
t_expanded = t.view(-1, *([1] * (latents.ndim - 1)))
ip_noise_gamma = getattr(args, "ip_noise_gamma", None)
if ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if getattr(args, "ip_noise_gamma_random_strength", False):
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * ip_noise_gamma
noisy_model_input = (1 - t_expanded) * latents + t_expanded * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1 - t_expanded) * latents + t_expanded * noise
# Sigmas for potential loss weighting
sigmas = t.view(-1, 1)
return noisy_model_input.to(dtype), t.to(dtype), sigmas.to(dtype)
# Loss weighting