diff --git a/README.md b/README.md index 2e80a697..4bc0c2b5 100644 --- a/README.md +++ b/README.md @@ -14,14 +14,10 @@ The command to install PyTorch is as follows: ### Recent Updates -Apr 6, 2025: -- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details. - - `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available. - Mar 30, 2025: - LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974). - Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details. -- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936). +- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1939](https://github.com/kohya-ss/sd-scripts/pull/1939). Mar 20, 2025: - `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985). diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index 0e73a01d..f7f06c5c 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -366,6 +366,8 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) return sigma @@ -408,34 +410,42 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None): def get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype + args, noise_scheduler, latents, noise, device, dtype ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, _, h, w = latents.shape - assert bsz > 0, "Batch size not large enough" - num_timesteps = noise_scheduler.config.num_train_timesteps + sigmas = None + if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid": - # Simple random sigma-based noise sampling + # Simple random t-based noise sampling if args.timestep_sampling == "sigmoid": # https://github.com/XLabs-AI/x-flux/tree/main - sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) + t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device)) else: - sigmas = torch.rand((bsz,), device=device) + t = torch.rand((bsz,), device=device) - timesteps = sigmas * num_timesteps + timesteps = t * 1000.0 + t = t.view(-1, 1, 1, 1) + noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "shift": shift = args.discrete_flow_shift - sigmas = torch.randn(bsz, device=device) - sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling - sigmas = sigmas.sigmoid() - sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas) - timesteps = sigmas * num_timesteps + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise elif args.timestep_sampling == "flux_shift": - sigmas = torch.randn(bsz, device=device) - sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling - sigmas = sigmas.sigmoid() - mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size - sigmas = time_shift(mu, 1.0, sigmas) - timesteps = sigmas * num_timesteps + logits_norm = torch.randn(bsz, device=device) + logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling + timesteps = logits_norm.sigmoid() + mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) + timesteps = time_shift(mu, 1.0, timesteps) + + t = timesteps.view(-1, 1, 1, 1) + timesteps = timesteps * 1000.0 + noisy_model_input = (1 - t) * latents + t * noise else: # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -446,24 +456,12 @@ def get_noisy_model_input_and_timesteps( logit_std=args.logit_std, mode_scale=args.mode_scale, ) - indices = (u * num_timesteps).long() + indices = (u * noise_scheduler.config.num_train_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=device) + + # Add noise according to flow matching. sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype) - - # Broadcast sigmas to latent shape - sigmas = sigmas.view(-1, 1, 1, 1) - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - if args.ip_noise_gamma: - xi = torch.randn_like(latents, device=latents.device, dtype=dtype) - if args.ip_noise_gamma_random_strength: - ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma) - else: - ip_noise_gamma = args.ip_noise_gamma - noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi) - else: - noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise + noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas diff --git a/library/train_util.py b/library/train_util.py index e8fd43a9..c9102f89 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2153,9 +2153,8 @@ class FineTuningDataset(BaseDataset): debug_dataset: bool, validation_seed: int, validation_split: float, - resize_interpolation: Optional[str], ) -> None: - super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation) + super().__init__(resolution, network_multiplier, debug_dataset) self.batch_size = batch_size diff --git a/library/utils.py b/library/utils.py index d0586b84..0f535a87 100644 --- a/library/utils.py +++ b/library/utils.py @@ -413,19 +413,9 @@ def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, Returns: image """ - - # Ensure all size parameters are actual integers - width = int(width) - height = int(height) - resized_width = int(resized_width) - resized_height = int(resized_height) - if resize_interpolation is None: - if width >= resized_width and height >= resized_height: - resize_interpolation = "area" - else: - resize_interpolation = "lanczos" - + resize_interpolation = "lanczos" if width > resized_width and height > resized_height else "area" + # we use PIL for lanczos (for backward compatibility) and box, cv2 for others use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"] diff --git a/tests/library/test_flux_train_utils.py b/tests/library/test_flux_train_utils.py deleted file mode 100644 index 2ad7ce4e..00000000 --- a/tests/library/test_flux_train_utils.py +++ /dev/null @@ -1,220 +0,0 @@ -import pytest -import torch -from unittest.mock import MagicMock, patch -from library.flux_train_utils import ( - get_noisy_model_input_and_timesteps, -) - -# Mock classes and functions -class MockNoiseScheduler: - def __init__(self, num_train_timesteps=1000): - self.config = MagicMock() - self.config.num_train_timesteps = num_train_timesteps - self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long) - - -# Create fixtures for commonly used objects -@pytest.fixture -def args(): - args = MagicMock() - args.timestep_sampling = "uniform" - args.weighting_scheme = "uniform" - args.logit_mean = 0.0 - args.logit_std = 1.0 - args.mode_scale = 1.0 - args.sigmoid_scale = 1.0 - args.discrete_flow_shift = 3.1582 - args.ip_noise_gamma = None - args.ip_noise_gamma_random_strength = False - return args - - -@pytest.fixture -def noise_scheduler(): - return MockNoiseScheduler(num_train_timesteps=1000) - - -@pytest.fixture -def latents(): - return torch.randn(2, 4, 8, 8) - - -@pytest.fixture -def noise(): - return torch.randn(2, 4, 8, 8) - - -@pytest.fixture -def device(): - # return "cuda" if torch.cuda.is_available() else "cpu" - return "cpu" - - -# Mock the required functions -@pytest.fixture(autouse=True) -def mock_functions(): - with ( - patch("torch.sigmoid", side_effect=torch.sigmoid), - patch("torch.rand", side_effect=torch.rand), - patch("torch.randn", side_effect=torch.randn), - ): - yield - - -# Test different timestep sampling methods -def test_uniform_sampling(args, noise_scheduler, latents, noise, device): - args.timestep_sampling = "uniform" - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) - assert noisy_input.dtype == dtype - assert timesteps.dtype == dtype - - -def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device): - args.timestep_sampling = "sigmoid" - args.sigmoid_scale = 1.0 - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) - - -def test_shift_sampling(args, noise_scheduler, latents, noise, device): - args.timestep_sampling = "shift" - args.sigmoid_scale = 1.0 - args.discrete_flow_shift = 3.1582 - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) - - -def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device): - args.timestep_sampling = "flux_shift" - args.sigmoid_scale = 1.0 - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) - - -def test_weighting_scheme(args, noise_scheduler, latents, noise, device): - # Mock the necessary functions for this specific test - with patch("library.flux_train_utils.compute_density_for_timestep_sampling", - return_value=torch.tensor([0.3, 0.7], device=device)), \ - patch("library.flux_train_utils.get_sigmas", - return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)): - - args.timestep_sampling = "other" # Will trigger the weighting scheme path - args.weighting_scheme = "uniform" - args.logit_mean = 0.0 - args.logit_std = 1.0 - args.mode_scale = 1.0 - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps( - args, noise_scheduler, latents, noise, device, dtype - ) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) - - -# Test IP noise options -def test_with_ip_noise(args, noise_scheduler, latents, noise, device): - args.ip_noise_gamma = 0.5 - args.ip_noise_gamma_random_strength = False - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) - - -def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device): - args.ip_noise_gamma = 0.1 - args.ip_noise_gamma_random_strength = True - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (latents.shape[0],) - assert sigmas.shape == (latents.shape[0], 1, 1, 1) - - -# Test different data types -def test_float16_dtype(args, noise_scheduler, latents, noise, device): - dtype = torch.float16 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.dtype == dtype - assert timesteps.dtype == dtype - - -# Test different batch sizes -def test_different_batch_size(args, noise_scheduler, device): - latents = torch.randn(5, 4, 8, 8) # batch size of 5 - noise = torch.randn(5, 4, 8, 8) - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (5,) - assert sigmas.shape == (5, 1, 1, 1) - - -# Test different image sizes -def test_different_image_size(args, noise_scheduler, device): - latents = torch.randn(2, 4, 16, 16) # larger image size - noise = torch.randn(2, 4, 16, 16) - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (2,) - assert sigmas.shape == (2, 1, 1, 1) - - -# Test edge cases -def test_zero_batch_size(args, noise_scheduler, device): - with pytest.raises(AssertionError): # expecting an error with zero batch size - latents = torch.randn(0, 4, 8, 8) - noise = torch.randn(0, 4, 8, 8) - dtype = torch.float32 - - get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - -def test_different_timestep_count(args, device): - noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count - latents = torch.randn(2, 4, 8, 8) - noise = torch.randn(2, 4, 8, 8) - dtype = torch.float32 - - noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) - - assert noisy_input.shape == latents.shape - assert timesteps.shape == (2,) - # Check that timesteps are within the proper range - assert torch.all(timesteps < 500) diff --git a/train_network.py b/train_network.py index d6bc66ed..f66cdeb4 100644 --- a/train_network.py +++ b/train_network.py @@ -389,18 +389,7 @@ class NetworkTrainer: latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device)) else: # latentに変換 - if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size: - latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) - else: - chunks = [ - batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size) - ] - list_latents = [] - for chunk in chunks: - with torch.no_grad(): - chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype)) - list_latents.append(chunk) - latents = torch.cat(list_latents, dim=0) + latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype)) # NaNが含まれていれば警告を表示し0に置き換える if torch.any(torch.isnan(latents)):