Merge branch 'sd3' into lumina

This commit is contained in:
rockerBOO
2025-06-09 18:13:06 -04:00
10 changed files with 482 additions and 98 deletions

3
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,3 @@
# These are supported funding model platforms
github: kohya-ss

View File

@@ -9,11 +9,26 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)
### Recent Updates
May 1, 2025:
- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
- If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
Apr 27, 2025:
- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
- See [here](#sample-image-generation-during-training) for details.
- If you have any issues with this, please let us know.
Apr 6, 2025:
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
Mar 30, 2025:
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
@@ -866,6 +881,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
(Single GPU with id `0` will be used.)
## DeepSpeed installation (experimental, Linux or WSL2 only)
To install DeepSpeed, run the following command in your activated virtual environment:
```bash
pip install deepspeed==0.16.7
```
## Upgrade
When a new release comes out you can upgrade your repo with the following command:
@@ -1340,11 +1363,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
* `--n` Negative prompt up to the next option.
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility.
* `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models.
* `--s` Specifies the number of steps in the generation.
The prompt weighting such as `( )` and `[ ]` are working.

View File

@@ -97,15 +97,19 @@ def main(args):
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
repo_id=args.repo_id,
filename=file,
subfolder=SUB_DIR,
cache_dir=os.path.join(model_location, SUB_DIR),
local_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
hf_hub_download(
repo_id=args.repo_id,
filename=file,
local_dir=model_location,
force_download=True,
)
else:
logger.info("using existing wd14 tagger model")

View File

@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
from .utils import setup_logging
from .device_utils import get_preferred_device
setup_logging()
import logging
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
if wrap_model_forward_with_torch_autocast:
model = self.__wrap_model_with_torch_autocast(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))
def __wrap_model_with_torch_autocast(self, model):
if isinstance(model, torch.nn.ModuleList):
model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
else:
model = self.__wrap_model_forward_with_torch_autocast(model)
return model
def __wrap_model_forward_with_torch_autocast(self, model):
assert hasattr(model, "forward"), f"model must have a forward method."
forward_fn = model.forward
def forward(*args, **kwargs):
try:
device_type = model.device.type
except AttributeError:
logger.warning(
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
"to determine the device_type for torch.autocast()."
)
device_type = get_preferred_device().type
with torch.autocast(device_type = device_type):
return forward_fn(*args, **kwargs)
model.forward = forward
return model
def get_models(self):
return self.models
ds_model = DeepSpeedWrapper(**models)
return ds_model

View File

@@ -40,7 +40,7 @@ def sample_images(
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
controlnet=None
controlnet=None,
):
if steps == 0:
if not args.sample_at_first:
@@ -101,7 +101,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
@@ -125,7 +125,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
)
torch.set_rng_state(rng_state)
@@ -147,14 +147,16 @@ def sample_image_inference(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
# TODO refactor variable names
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
emb_guidance_scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
@@ -162,8 +164,8 @@ def sample_image_inference(
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
@@ -173,16 +175,21 @@ def sample_image_inference(
torch.seed()
torch.cuda.seed()
# if negative_prompt is None:
# negative_prompt = ""
if negative_prompt is None:
negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}")
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
logger.info(f"embedded guidance scale: {emb_guidance_scale}")
if cfg_scale != 1.0:
logger.info(f"CFG scale: {cfg_scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
@@ -191,26 +198,37 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
# encode negative prompts
if cfg_scale != 1.0:
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
neg_t5_attn_mask = (
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
)
neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
else:
neg_cond = None
# sample image
weight_dtype = ae.dtype # TOFO give dtype as argument
@@ -235,7 +253,20 @@ def sample_image_inference(
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = denoise(
flux,
noise,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps=timesteps,
guidance=emb_guidance_scale,
t5_attn_mask=t5_attn_mask,
controlnet=controlnet,
controlnet_img=controlnet_image,
neg_cond=neg_cond,
)
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
@@ -305,21 +336,24 @@ def denoise(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt: torch.Tensor, # t5_out
txt_ids: torch.Tensor,
vec: torch.Tensor,
vec: torch.Tensor, # l_pooled
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
do_cfg = neg_cond is not None
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
if controlnet is not None:
block_samples, block_single_samples = controlnet(
img=img,
@@ -335,20 +369,48 @@ def denoise(
else:
block_samples = None
block_single_samples = None
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
if not do_cfg:
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
else:
cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
# TODO is it ok to use the same block samples for both cond and uncond?
block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
block_single_samples = (
None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
)
nc_c_pred = model(
img=torch.cat([img, img], dim=0),
img_ids=torch.cat([img_ids, img_ids], dim=0),
txt=torch.cat([neg_t5_out, txt], dim=0),
txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
y=torch.cat([neg_l_pooled, vec], dim=0),
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=nc_c_t5_attn_mask,
)
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
pred = neg_pred + (pred - neg_pred) * cfg_scale
img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
return img
@@ -365,8 +427,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
@@ -409,42 +469,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -455,12 +507,24 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
@@ -566,7 +630,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors"
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors",
)
parser.add_argument(
"--t5xxl_max_token_length",

View File

@@ -1070,8 +1070,11 @@ class BaseDataset(torch.utils.data.Dataset):
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
if len(img_ar_errors) == 0:
mean_img_ar_error = 0 # avoid NaN
else:
img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
@@ -5520,6 +5523,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
def patch_accelerator_for_fp16_training(accelerator):
from accelerate import DistributedType
if accelerator.distributed_type == DistributedType.DEEPSPEED:
return
org_unscale_grads = accelerator.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
@@ -6203,6 +6211,11 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["scale"] = float(m.group(1))
continue
m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
if m: # guidance scale
prompt_dict["guidance_scale"] = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
prompt_dict["negative_prompt"] = m.group(1)

View File

@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_grad_norms()
def grad_norms(self) -> Tensor:
def grad_norms(self) -> Tensor | None:
grad_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
grad_norms.append(lora.grad_norms.mean(dim=0))
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
def weight_norms(self) -> Tensor:
def weight_norms(self) -> Tensor | None:
weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
weight_norms.append(lora.weight_norms.mean(dim=0))
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
def combined_weight_norms(self) -> Tensor:
def combined_weight_norms(self) -> Tensor | None:
combined_weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
def load_weights(self, file):

View File

@@ -6,3 +6,4 @@ filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning
pythonpath = .

View File

@@ -0,0 +1,220 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
)
# Mock classes and functions
class MockNoiseScheduler:
def __init__(self, num_train_timesteps=1000):
self.config = MagicMock()
self.config.num_train_timesteps = num_train_timesteps
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
# Create fixtures for commonly used objects
@pytest.fixture
def args():
args = MagicMock()
args.timestep_sampling = "uniform"
args.weighting_scheme = "uniform"
args.logit_mean = 0.0
args.logit_std = 1.0
args.mode_scale = 1.0
args.sigmoid_scale = 1.0
args.discrete_flow_shift = 3.1582
args.ip_noise_gamma = None
args.ip_noise_gamma_random_strength = False
return args
@pytest.fixture
def noise_scheduler():
return MockNoiseScheduler(num_train_timesteps=1000)
@pytest.fixture
def latents():
return torch.randn(2, 4, 8, 8)
@pytest.fixture
def noise():
return torch.randn(2, 4, 8, 8)
@pytest.fixture
def device():
# return "cuda" if torch.cuda.is_available() else "cpu"
return "cpu"
# Mock the required functions
@pytest.fixture(autouse=True)
def mock_functions():
with (
patch("torch.sigmoid", side_effect=torch.sigmoid),
patch("torch.rand", side_effect=torch.rand),
patch("torch.randn", side_effect=torch.randn),
):
yield
# Test different timestep sampling methods
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "uniform"
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "sigmoid"
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "shift"
args.sigmoid_scale = 1.0
args.discrete_flow_shift = 3.1582
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "flux_shift"
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
# Mock the necessary functions for this specific test
with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
return_value=torch.tensor([0.3, 0.7], device=device)), \
patch("library.flux_train_utils.get_sigmas",
return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
args.timestep_sampling = "other" # Will trigger the weighting scheme path
args.weighting_scheme = "uniform"
args.logit_mean = 0.0
args.logit_std = 1.0
args.mode_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
# Test IP noise options
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma = 0.5
args.ip_noise_gamma_random_strength = False
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma = 0.1
args.ip_noise_gamma_random_strength = True
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
# Test different data types
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
dtype = torch.float16
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
# Test different batch sizes
def test_different_batch_size(args, noise_scheduler, device):
latents = torch.randn(5, 4, 8, 8) # batch size of 5
noise = torch.randn(5, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (5,)
assert sigmas.shape == (5, 1, 1, 1)
# Test different image sizes
def test_different_image_size(args, noise_scheduler, device):
latents = torch.randn(2, 4, 16, 16) # larger image size
noise = torch.randn(2, 4, 16, 16)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
assert sigmas.shape == (2, 1, 1, 1)
# Test edge cases
def test_zero_batch_size(args, noise_scheduler, device):
with pytest.raises(AssertionError): # expecting an error with zero batch size
latents = torch.randn(0, 4, 8, 8)
noise = torch.randn(0, 4, 8, 8)
dtype = torch.float32
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
def test_different_timestep_count(args, device):
noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count
latents = torch.randn(2, 4, 8, 8)
noise = torch.randn(2, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
# Check that timesteps are within the proper range
assert torch.all(timesteps < 500)

View File

@@ -389,7 +389,18 @@ class NetworkTrainer:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
# latentに変換
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -1433,11 +1444,13 @@ class NetworkTrainer:
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
if hasattr(network, "weight_norms"):
mean_norm = network.weight_norms().mean().item()
mean_grad_norm = network.grad_norms().mean().item()
mean_combined_norm = network.combined_weight_norms().mean().item()
weight_norms = network.weight_norms()
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
grad_norms = network.grad_norms()
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
combined_weight_norms = network.combined_weight_norms()
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
keys_scaled = None
max_mean_logs = {}
else: