Revert "Merge branch 'sd3' into flux-lora-init"

This reverts commit 67f8e17a46, reversing
changes made to 9d7e2dd7c9.
This commit is contained in:
rockerBOO
2025-04-10 21:38:05 -04:00
parent 67f8e17a46
commit 7dd00204eb
6 changed files with 37 additions and 285 deletions

View File

@@ -14,14 +14,10 @@ The command to install PyTorch is as follows:
### Recent Updates
Apr 6, 2025:
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
Mar 30, 2025:
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1939](https://github.com/kohya-ss/sd-scripts/pull/1939).
Mar 20, 2025:
- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).

View File

@@ -366,6 +366,8 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
@@ -408,34 +410,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
sigmas = None
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random sigma-based noise sampling
# Simple random t-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
sigmas = torch.rand((bsz,), device=device)
t = torch.rand((bsz,), device=device)
timesteps = sigmas * num_timesteps
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
elif args.timestep_sampling == "flux_shift":
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -446,24 +456,12 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * num_timesteps).long()
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas

View File

@@ -2153,9 +2153,8 @@ class FineTuningDataset(BaseDataset):
debug_dataset: bool,
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
super().__init__(resolution, network_multiplier, debug_dataset)
self.batch_size = batch_size

View File

@@ -413,19 +413,9 @@ def resize_image(image: np.ndarray, width: int, height: int, resized_width: int,
Returns:
image
"""
# Ensure all size parameters are actual integers
width = int(width)
height = int(height)
resized_width = int(resized_width)
resized_height = int(resized_height)
if resize_interpolation is None:
if width >= resized_width and height >= resized_height:
resize_interpolation = "area"
else:
resize_interpolation = "lanczos"
resize_interpolation = "lanczos" if width > resized_width and height > resized_height else "area"
# we use PIL for lanczos (for backward compatibility) and box, cv2 for others
use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]

View File

@@ -1,220 +0,0 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
)
# Mock classes and functions
class MockNoiseScheduler:
    """Minimal stand-in for a diffusers-style noise scheduler.

    Exposes only what get_noisy_model_input_and_timesteps reads:
    ``config.num_train_timesteps`` and a ``timesteps`` tensor of integer
    timesteps ``0 .. num_train_timesteps - 1``.
    """

    def __init__(self, num_train_timesteps=1000):
        # config is a MagicMock so any other attribute access also succeeds.
        self.config = MagicMock()
        self.config.num_train_timesteps = num_train_timesteps
        self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
# Create fixtures for commonly used objects
@pytest.fixture
def args():
    """Default training-args namespace (as a MagicMock) used by all tests.

    Individual tests override ``timestep_sampling`` and the IP-noise fields
    as needed; everything else keeps these defaults.
    """
    args = MagicMock()
    args.timestep_sampling = "uniform"
    args.weighting_scheme = "uniform"
    args.logit_mean = 0.0
    args.logit_std = 1.0
    args.mode_scale = 1.0
    args.sigmoid_scale = 1.0
    args.discrete_flow_shift = 3.1582
    args.ip_noise_gamma = None
    args.ip_noise_gamma_random_strength = False
    return args
@pytest.fixture
def noise_scheduler():
    """Mock scheduler with the default 1000 training timesteps."""
    return MockNoiseScheduler(num_train_timesteps=1000)
@pytest.fixture
def latents():
    """Random latents tensor: batch 2, 4 channels, 8x8 spatial."""
    return torch.randn(2, 4, 8, 8)
@pytest.fixture
def noise():
    """Random noise tensor shaped like the ``latents`` fixture."""
    return torch.randn(2, 4, 8, 8)
@pytest.fixture
def device():
    """Device for all tensors; pinned to CPU so tests run anywhere."""
    # return "cuda" if torch.cuda.is_available() else "cpu"
    return "cpu"
# Mock the required functions
@pytest.fixture(autouse=True)
def mock_functions():
    """Autouse: wrap torch RNG entry points with pass-through patches.

    ``side_effect`` delegates to the real functions, so behavior is
    unchanged while calls remain observable on the patch objects.
    """
    with (
        patch("torch.sigmoid", side_effect=torch.sigmoid),
        patch("torch.rand", side_effect=torch.rand),
        patch("torch.randn", side_effect=torch.randn),
    ):
        yield
# Test different timestep sampling methods
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
    """'uniform' sampling: output shapes plus dtype round-trip check."""
    args.timestep_sampling = "uniform"
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)
    assert noisy_input.dtype == dtype
    assert timesteps.dtype == dtype
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
    """'sigmoid' sampling: output shapes for all three return values."""
    args.timestep_sampling = "sigmoid"
    args.sigmoid_scale = 1.0
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
    """'shift' sampling (discrete flow shift): output shape checks."""
    args.timestep_sampling = "shift"
    args.sigmoid_scale = 1.0
    args.discrete_flow_shift = 3.1582
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
    """'flux_shift' sampling (resolution-dependent shift): shape checks."""
    args.timestep_sampling = "flux_shift"
    args.sigmoid_scale = 1.0
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
    """Fallback path: density-weighted timestep sampling (non-uniform schemes).

    The density and sigma helpers are patched with fixed tensors so the
    test is deterministic and only the plumbing/shapes are exercised.
    """
    # Mock the necessary functions for this specific test
    with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
               return_value=torch.tensor([0.3, 0.7], device=device)), \
         patch("library.flux_train_utils.get_sigmas",
               return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
        args.timestep_sampling = "other"  # Will trigger the weighting scheme path
        args.weighting_scheme = "uniform"
        args.logit_mean = 0.0
        args.logit_std = 1.0
        args.mode_scale = 1.0
        dtype = torch.float32
        noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, device, dtype
        )
        assert noisy_input.shape == latents.shape
        assert timesteps.shape == (latents.shape[0],)
        assert sigmas.shape == (latents.shape[0], 1, 1, 1)
# Test IP noise options
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
    """Fixed-strength input-perturbation noise enabled: shape checks only."""
    args.ip_noise_gamma = 0.5
    args.ip_noise_gamma_random_strength = False
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
    """Random-strength input-perturbation noise enabled: shape checks only."""
    args.ip_noise_gamma = 0.1
    args.ip_noise_gamma_random_strength = True
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (latents.shape[0],)
    assert sigmas.shape == (latents.shape[0], 1, 1, 1)
# Test different data types
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
    """Requested half precision is propagated to both returned tensors."""
    dtype = torch.float16
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.dtype == dtype
    assert timesteps.dtype == dtype
# Test different batch sizes
def test_different_batch_size(args, noise_scheduler, device):
    """A batch size other than the fixture default (5) keeps shapes consistent."""
    latents = torch.randn(5, 4, 8, 8)  # batch size of 5
    noise = torch.randn(5, 4, 8, 8)
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (5,)
    assert sigmas.shape == (5, 1, 1, 1)
# Test different image sizes
def test_different_image_size(args, noise_scheduler, device):
    """A larger spatial size (16x16) keeps output shapes consistent."""
    latents = torch.randn(2, 4, 16, 16)  # larger image size
    noise = torch.randn(2, 4, 16, 16)
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (2,)
    assert sigmas.shape == (2, 1, 1, 1)
# Test edge cases
def test_zero_batch_size(args, noise_scheduler, device):
    """An empty batch must trip the function's ``bsz > 0`` assertion."""
    with pytest.raises(AssertionError):  # expecting an error with zero batch size
        latents = torch.randn(0, 4, 8, 8)
        noise = torch.randn(0, 4, 8, 8)
        dtype = torch.float32
        get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
def test_different_timestep_count(args, device):
    """With a 500-step scheduler, returned timesteps stay below 500."""
    noise_scheduler = MockNoiseScheduler(num_train_timesteps=500)  # different timestep count
    latents = torch.randn(2, 4, 8, 8)
    noise = torch.randn(2, 4, 8, 8)
    dtype = torch.float32
    noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
    assert noisy_input.shape == latents.shape
    assert timesteps.shape == (2,)
    # Check that timesteps are within the proper range
    assert torch.all(timesteps < 500)

View File

@@ -389,18 +389,7 @@ class NetworkTrainer:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
# latentに変換
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):