mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 15:00:23 +00:00
Merge pull request #26 from rockerBOO/lumina-test-fix-mask
Lumina test fix mask
This commit is contained in:
3
.github/FUNDING.yml
vendored
Normal file
3
.github/FUNDING.yml
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: kohya-ss
|
||||
27
README.md
27
README.md
@@ -9,11 +9,26 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
- [FLUX.1 training](#flux1-training)
|
||||
- [SD3 training](#sd3-training)
|
||||
|
||||
### Recent Updates
|
||||
|
||||
May 1, 2025:
|
||||
- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
|
||||
- If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
|
||||
|
||||
Apr 27, 2025:
|
||||
- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
|
||||
- See [here](#sample-image-generation-during-training) for details.
|
||||
- If you have any issues with this, please let us know.
|
||||
|
||||
Apr 6, 2025:
|
||||
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
|
||||
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
|
||||
|
||||
Mar 30, 2025:
|
||||
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
|
||||
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
|
||||
@@ -866,6 +881,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
|
||||
|
||||
(Single GPU with id `0` will be used.)
|
||||
|
||||
## DeepSpeed installation (experimental, Linux or WSL2 only)
|
||||
|
||||
To install DeepSpeed, run the following command in your activated virtual environment:
|
||||
|
||||
```bash
|
||||
pip install deepspeed==0.16.7
|
||||
```
|
||||
|
||||
## Upgrade
|
||||
|
||||
When a new release comes out you can upgrade your repo with the following command:
|
||||
@@ -1340,11 +1363,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
|
||||
|
||||
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
|
||||
|
||||
* `--n` Negative prompt up to the next option.
|
||||
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
|
||||
* `--w` Specifies the width of the generated image.
|
||||
* `--h` Specifies the height of the generated image.
|
||||
* `--d` Specifies the seed of the generated image.
|
||||
* `--l` Specifies the CFG scale of the generated image.
|
||||
* In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility.
|
||||
* `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models.
|
||||
* `--s` Specifies the number of steps in the generation.
|
||||
|
||||
The prompt weighting such as `( )` and `[ ]` are working.
|
||||
|
||||
@@ -97,15 +97,19 @@ def main(args):
|
||||
else:
|
||||
for file in SUB_DIR_FILES:
|
||||
hf_hub_download(
|
||||
args.repo_id,
|
||||
file,
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
subfolder=SUB_DIR,
|
||||
cache_dir=os.path.join(model_location, SUB_DIR),
|
||||
local_dir=os.path.join(model_location, SUB_DIR),
|
||||
force_download=True,
|
||||
force_filename=file,
|
||||
)
|
||||
for file in files:
|
||||
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
|
||||
hf_hub_download(
|
||||
repo_id=args.repo_id,
|
||||
filename=file,
|
||||
local_dir=model_location,
|
||||
force_download=True,
|
||||
)
|
||||
else:
|
||||
logger.info("using existing wd14 tagger model")
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
|
||||
|
||||
from .utils import setup_logging
|
||||
|
||||
from .device_utils import get_preferred_device
|
||||
|
||||
setup_logging()
|
||||
import logging
|
||||
|
||||
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
|
||||
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
|
||||
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
|
||||
)
|
||||
|
||||
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
|
||||
if args.mixed_precision.lower() == "fp16":
|
||||
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
|
||||
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
|
||||
class DeepSpeedWrapper(torch.nn.Module):
|
||||
def __init__(self, **kw_models) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.models = torch.nn.ModuleDict()
|
||||
|
||||
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
|
||||
|
||||
for key, model in kw_models.items():
|
||||
if isinstance(model, list):
|
||||
model = torch.nn.ModuleList(model)
|
||||
|
||||
if wrap_model_forward_with_torch_autocast:
|
||||
model = self.__wrap_model_with_torch_autocast(model)
|
||||
|
||||
assert isinstance(
|
||||
model, torch.nn.Module
|
||||
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
|
||||
|
||||
self.models.update(torch.nn.ModuleDict({key: model}))
|
||||
|
||||
def __wrap_model_with_torch_autocast(self, model):
|
||||
if isinstance(model, torch.nn.ModuleList):
|
||||
model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
|
||||
else:
|
||||
model = self.__wrap_model_forward_with_torch_autocast(model)
|
||||
return model
|
||||
|
||||
def __wrap_model_forward_with_torch_autocast(self, model):
|
||||
|
||||
assert hasattr(model, "forward"), f"model must have a forward method."
|
||||
|
||||
forward_fn = model.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
try:
|
||||
device_type = model.device.type
|
||||
except AttributeError:
|
||||
logger.warning(
|
||||
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
|
||||
"to determine the device_type for torch.autocast()."
|
||||
)
|
||||
device_type = get_preferred_device().type
|
||||
|
||||
with torch.autocast(device_type = device_type):
|
||||
return forward_fn(*args, **kwargs)
|
||||
|
||||
model.forward = forward
|
||||
return model
|
||||
|
||||
def get_models(self):
|
||||
return self.models
|
||||
|
||||
|
||||
ds_model = DeepSpeedWrapper(**models)
|
||||
return ds_model
|
||||
|
||||
@@ -40,7 +40,7 @@ def sample_images(
|
||||
text_encoders,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement=None,
|
||||
controlnet=None
|
||||
controlnet=None,
|
||||
):
|
||||
if steps == 0:
|
||||
if not args.sample_at_first:
|
||||
@@ -101,7 +101,7 @@ def sample_images(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
)
|
||||
else:
|
||||
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
|
||||
@@ -125,7 +125,7 @@ def sample_images(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
)
|
||||
|
||||
torch.set_rng_state(rng_state)
|
||||
@@ -147,14 +147,16 @@ def sample_image_inference(
|
||||
steps,
|
||||
sample_prompts_te_outputs,
|
||||
prompt_replacement,
|
||||
controlnet
|
||||
controlnet,
|
||||
):
|
||||
assert isinstance(prompt_dict, dict)
|
||||
# negative_prompt = prompt_dict.get("negative_prompt")
|
||||
negative_prompt = prompt_dict.get("negative_prompt")
|
||||
sample_steps = prompt_dict.get("sample_steps", 20)
|
||||
width = prompt_dict.get("width", 512)
|
||||
height = prompt_dict.get("height", 512)
|
||||
scale = prompt_dict.get("scale", 3.5)
|
||||
# TODO refactor variable names
|
||||
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
|
||||
emb_guidance_scale = prompt_dict.get("scale", 3.5)
|
||||
seed = prompt_dict.get("seed")
|
||||
controlnet_image = prompt_dict.get("controlnet_image")
|
||||
prompt: str = prompt_dict.get("prompt", "")
|
||||
@@ -162,8 +164,8 @@ def sample_image_inference(
|
||||
|
||||
if prompt_replacement is not None:
|
||||
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
# if negative_prompt is not None:
|
||||
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
if negative_prompt is not None:
|
||||
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
|
||||
|
||||
if seed is not None:
|
||||
torch.manual_seed(seed)
|
||||
@@ -173,16 +175,21 @@ def sample_image_inference(
|
||||
torch.seed()
|
||||
torch.cuda.seed()
|
||||
|
||||
# if negative_prompt is None:
|
||||
# negative_prompt = ""
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ""
|
||||
height = max(64, height - height % 16) # round to divisible by 16
|
||||
width = max(64, width - width % 16) # round to divisible by 16
|
||||
logger.info(f"prompt: {prompt}")
|
||||
# logger.info(f"negative_prompt: {negative_prompt}")
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"negative_prompt: {negative_prompt}")
|
||||
elif negative_prompt != "":
|
||||
logger.info(f"negative prompt is ignored because scale is 1.0")
|
||||
logger.info(f"height: {height}")
|
||||
logger.info(f"width: {width}")
|
||||
logger.info(f"sample_steps: {sample_steps}")
|
||||
logger.info(f"scale: {scale}")
|
||||
logger.info(f"embedded guidance scale: {emb_guidance_scale}")
|
||||
if cfg_scale != 1.0:
|
||||
logger.info(f"CFG scale: {cfg_scale}")
|
||||
# logger.info(f"sample_sampler: {sampler_name}")
|
||||
if seed is not None:
|
||||
logger.info(f"seed: {seed}")
|
||||
@@ -191,26 +198,37 @@ def sample_image_inference(
|
||||
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
|
||||
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prompt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prompt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prompt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
def encode_prompt(prpt):
|
||||
text_encoder_conds = []
|
||||
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
|
||||
text_encoder_conds = sample_prompts_te_outputs[prpt]
|
||||
print(f"Using cached text encoder outputs for prompt: {prpt}")
|
||||
if text_encoders is not None:
|
||||
print(f"Encoding prompt: {prpt}")
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prpt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
text_encoder_conds = encoded_text_encoder_conds
|
||||
else:
|
||||
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
|
||||
for i in range(len(encoded_text_encoder_conds)):
|
||||
if encoded_text_encoder_conds[i] is not None:
|
||||
text_encoder_conds[i] = encoded_text_encoder_conds[i]
|
||||
return text_encoder_conds
|
||||
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
|
||||
# encode negative prompts
|
||||
if cfg_scale != 1.0:
|
||||
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
|
||||
neg_t5_attn_mask = (
|
||||
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
|
||||
)
|
||||
neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
|
||||
else:
|
||||
neg_cond = None
|
||||
|
||||
# sample image
|
||||
weight_dtype = ae.dtype # TOFO give dtype as argument
|
||||
@@ -235,7 +253,20 @@ def sample_image_inference(
|
||||
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||
x = denoise(
|
||||
flux,
|
||||
noise,
|
||||
img_ids,
|
||||
t5_out,
|
||||
txt_ids,
|
||||
l_pooled,
|
||||
timesteps=timesteps,
|
||||
guidance=emb_guidance_scale,
|
||||
t5_attn_mask=t5_attn_mask,
|
||||
controlnet=controlnet,
|
||||
controlnet_img=controlnet_image,
|
||||
neg_cond=neg_cond,
|
||||
)
|
||||
|
||||
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
|
||||
|
||||
@@ -305,21 +336,24 @@ def denoise(
|
||||
model: flux_models.Flux,
|
||||
img: torch.Tensor,
|
||||
img_ids: torch.Tensor,
|
||||
txt: torch.Tensor,
|
||||
txt: torch.Tensor, # t5_out
|
||||
txt_ids: torch.Tensor,
|
||||
vec: torch.Tensor,
|
||||
vec: torch.Tensor, # l_pooled
|
||||
timesteps: list[float],
|
||||
guidance: float = 4.0,
|
||||
t5_attn_mask: Optional[torch.Tensor] = None,
|
||||
controlnet: Optional[flux_models.ControlNetFlux] = None,
|
||||
controlnet_img: Optional[torch.Tensor] = None,
|
||||
neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
do_cfg = neg_cond is not None
|
||||
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
|
||||
if controlnet is not None:
|
||||
block_samples, block_single_samples = controlnet(
|
||||
img=img,
|
||||
@@ -335,20 +369,48 @@ def denoise(
|
||||
else:
|
||||
block_samples = None
|
||||
block_single_samples = None
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
if not do_cfg:
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
else:
|
||||
cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
|
||||
nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
|
||||
|
||||
# TODO is it ok to use the same block samples for both cond and uncond?
|
||||
block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
|
||||
block_single_samples = (
|
||||
None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
|
||||
)
|
||||
|
||||
nc_c_pred = model(
|
||||
img=torch.cat([img, img], dim=0),
|
||||
img_ids=torch.cat([img_ids, img_ids], dim=0),
|
||||
txt=torch.cat([neg_t5_out, txt], dim=0),
|
||||
txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
|
||||
y=torch.cat([neg_l_pooled, vec], dim=0),
|
||||
block_controlnet_hidden_states=block_samples,
|
||||
block_controlnet_single_hidden_states=block_single_samples,
|
||||
timesteps=t_vec,
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=nc_c_t5_attn_mask,
|
||||
)
|
||||
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
|
||||
pred = neg_pred + (pred - neg_pred) * cfg_scale
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
@@ -365,8 +427,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
|
||||
@@ -409,42 +469,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz, _, h, w = latents.shape
|
||||
sigmas = None
|
||||
|
||||
assert bsz > 0, "Batch size not large enough"
|
||||
num_timesteps = noise_scheduler.config.num_train_timesteps
|
||||
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
|
||||
# Simple random t-based noise sampling
|
||||
# Simple random sigma-based noise sampling
|
||||
if args.timestep_sampling == "sigmoid":
|
||||
# https://github.com/XLabs-AI/x-flux/tree/main
|
||||
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
|
||||
else:
|
||||
t = torch.rand((bsz,), device=device)
|
||||
sigmas = torch.rand((bsz,), device=device)
|
||||
|
||||
timesteps = t * 1000.0
|
||||
t = t.view(-1, 1, 1, 1)
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "shift":
|
||||
shift = args.discrete_flow_shift
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
elif args.timestep_sampling == "flux_shift":
|
||||
logits_norm = torch.randn(bsz, device=device)
|
||||
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
timesteps = logits_norm.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
|
||||
timesteps = time_shift(mu, 1.0, timesteps)
|
||||
|
||||
t = timesteps.view(-1, 1, 1, 1)
|
||||
timesteps = timesteps * 1000.0
|
||||
noisy_model_input = (1 - t) * latents + t * noise
|
||||
sigmas = torch.randn(bsz, device=device)
|
||||
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
|
||||
sigmas = sigmas.sigmoid()
|
||||
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
|
||||
sigmas = time_shift(mu, 1.0, sigmas)
|
||||
timesteps = sigmas * num_timesteps
|
||||
else:
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
@@ -455,12 +507,24 @@ def get_noisy_model_input_and_timesteps(
|
||||
logit_std=args.logit_std,
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
||||
indices = (u * num_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
|
||||
# Broadcast sigmas to latent shape
|
||||
sigmas = sigmas.view(-1, 1, 1, 1)
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
if args.ip_noise_gamma:
|
||||
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
|
||||
if args.ip_noise_gamma_random_strength:
|
||||
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
|
||||
else:
|
||||
ip_noise_gamma = args.ip_noise_gamma
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
|
||||
else:
|
||||
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
|
||||
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
@@ -566,7 +630,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
|
||||
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_max_token_length",
|
||||
|
||||
@@ -868,6 +868,8 @@ class NextDiT(nn.Module):
|
||||
cap_feat_dim (int): Dimension of the caption features.
|
||||
axes_dims (List[int]): List of dimensions for the axes.
|
||||
axes_lens (List[int]): List of lengths for the axes.
|
||||
use_flash_attn (bool): Whether to use Flash Attention.
|
||||
use_sage_attn (bool): Whether to use Sage Attention. Sage Attention only supports inference.
|
||||
|
||||
Returns:
|
||||
None
|
||||
@@ -1110,7 +1112,11 @@ class NextDiT(nn.Module):
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
|
||||
|
||||
x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)
|
||||
|
||||
x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
|
||||
for i in range(bsz):
|
||||
x[i, :image_seq_len] = x[i]
|
||||
x_mask[i, :image_seq_len] = True
|
||||
|
||||
x = self.x_embedder(x)
|
||||
|
||||
|
||||
@@ -173,62 +173,61 @@ def pack_latents(x: torch.Tensor) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
DIFFUSERS_TO_ALPHA_VLLM_MAP = {
|
||||
|
||||
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
|
||||
# Embedding layers
|
||||
"cap_embedder.0.weight": ["time_caption_embed.caption_embedder.0.weight"],
|
||||
"cap_embedder.1.weight": "time_caption_embed.caption_embedder.1.weight",
|
||||
"cap_embedder.1.bias": "text_embedder.1.bias",
|
||||
"x_embedder.weight": "patch_embedder.proj.weight",
|
||||
"x_embedder.bias": "patch_embedder.proj.bias",
|
||||
"time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
|
||||
"time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight",
|
||||
"text_embedder.1.bias": "cap_embedder.1.bias",
|
||||
"patch_embedder.proj.weight": "x_embedder.weight",
|
||||
"patch_embedder.proj.bias": "x_embedder.bias",
|
||||
# Attention modulation
|
||||
"layers.().adaLN_modulation.1.weight": "transformer_blocks.().adaln_modulation.1.weight",
|
||||
"layers.().adaLN_modulation.1.bias": "transformer_blocks.().adaln_modulation.1.bias",
|
||||
"transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight",
|
||||
"transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias",
|
||||
# Final layers
|
||||
"final_layer.adaLN_modulation.1.weight": "final_adaln_modulation.1.weight",
|
||||
"final_layer.adaLN_modulation.1.bias": "final_adaln_modulation.1.bias",
|
||||
"final_layer.linear.weight": "final_linear.weight",
|
||||
"final_layer.linear.bias": "final_linear.bias",
|
||||
"final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight",
|
||||
"final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias",
|
||||
"final_linear.weight": "final_layer.linear.weight",
|
||||
"final_linear.bias": "final_layer.linear.bias",
|
||||
# Noise refiner
|
||||
"noise_refiner.().adaLN_modulation.1.weight": "single_transformer_blocks.().adaln_modulation.1.weight",
|
||||
"noise_refiner.().adaLN_modulation.1.bias": "single_transformer_blocks.().adaln_modulation.1.bias",
|
||||
"noise_refiner.().attention.qkv.weight": "single_transformer_blocks.().attn.to_qkv.weight",
|
||||
"noise_refiner.().attention.out.weight": "single_transformer_blocks.().attn.to_out.0.weight",
|
||||
# Time embedding
|
||||
"t_embedder.mlp.0.weight": "time_embedder.0.weight",
|
||||
"t_embedder.mlp.0.bias": "time_embedder.0.bias",
|
||||
"t_embedder.mlp.2.weight": "time_embedder.2.weight",
|
||||
"t_embedder.mlp.2.bias": "time_embedder.2.bias",
|
||||
# Context attention
|
||||
"context_refiner.().attention.qkv.weight": "transformer_blocks.().attn2.to_qkv.weight",
|
||||
"context_refiner.().attention.out.weight": "transformer_blocks.().attn2.to_out.0.weight",
|
||||
"single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight",
|
||||
"single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias",
|
||||
"single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight",
|
||||
"single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight",
|
||||
# Normalization
|
||||
"layers.().attention_norm1.weight": "transformer_blocks.().norm1.weight",
|
||||
"layers.().attention_norm2.weight": "transformer_blocks.().norm2.weight",
|
||||
"transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight",
|
||||
"transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight",
|
||||
# FFN
|
||||
"layers.().feed_forward.w1.weight": "transformer_blocks.().ff.net.0.proj.weight",
|
||||
"layers.().feed_forward.w2.weight": "transformer_blocks.().ff.net.2.weight",
|
||||
"layers.().feed_forward.w3.weight": "transformer_blocks.().ff.net.4.weight",
|
||||
"transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight",
|
||||
"transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight",
|
||||
"transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight",
|
||||
}
|
||||
|
||||
|
||||
def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict:
|
||||
"""Convert Diffusers checkpoint to Alpha-VLLM format"""
|
||||
logger.info("Converting Diffusers checkpoint to Alpha-VLLM format")
|
||||
new_sd = {}
|
||||
new_sd = sd.copy() # Preserve original keys
|
||||
|
||||
for key, value in sd.items():
|
||||
new_key = key
|
||||
for pattern, replacement in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
|
||||
if "()." in pattern:
|
||||
for block_idx in range(num_double_blocks):
|
||||
if str(block_idx) in key:
|
||||
converted = pattern.replace("()", str(block_idx))
|
||||
new_key = key.replace(converted, replacement.replace("()", str(block_idx)))
|
||||
break
|
||||
for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
|
||||
# Handle block-specific patterns
|
||||
if '().' in diff_key:
|
||||
for block_idx in range(num_double_blocks):
|
||||
block_alpha_key = alpha_key.replace('().', f'{block_idx}.')
|
||||
block_diff_key = diff_key.replace('().', f'{block_idx}.')
|
||||
|
||||
# Search for and convert block-specific keys
|
||||
for input_key, value in list(sd.items()):
|
||||
if input_key == block_diff_key:
|
||||
new_sd[block_alpha_key] = value
|
||||
else:
|
||||
# Handle static keys
|
||||
if diff_key in sd:
|
||||
print(f"Replacing {diff_key} with {alpha_key}")
|
||||
new_sd[alpha_key] = sd[diff_key]
|
||||
else:
|
||||
print(f"Not found: {diff_key}")
|
||||
|
||||
if new_key == key:
|
||||
logger.debug(f"Unmatched key in conversion: {key}")
|
||||
new_sd[new_key] = value
|
||||
|
||||
logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
|
||||
return new_sd
|
||||
|
||||
@@ -610,21 +610,6 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
@@ -664,49 +649,22 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
shift: float = 1.0,
|
||||
use_dynamic_shifting=False,
|
||||
base_shift: Optional[float] = 0.5,
|
||||
max_shift: Optional[float] = 1.15,
|
||||
base_image_seq_len: Optional[int] = 256,
|
||||
max_image_seq_len: Optional[int] = 4096,
|
||||
invert_sigmas: bool = False,
|
||||
shift_terminal: Optional[float] = None,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
use_exponential_sigmas: Optional[bool] = False,
|
||||
use_beta_sigmas: Optional[bool] = False,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
||||
|
||||
sigmas = timesteps / num_train_timesteps
|
||||
if not use_dynamic_shifting:
|
||||
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
self.timesteps = sigmas * num_train_timesteps
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
self._shift = shift
|
||||
|
||||
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigma_min = self.sigmas[-1].item()
|
||||
self.sigma_max = self.sigmas[0].item()
|
||||
|
||||
@property
|
||||
def shift(self):
|
||||
"""
|
||||
The value used for shifting.
|
||||
"""
|
||||
return self._shift
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
@@ -732,9 +690,6 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_shift(self, shift: float):
|
||||
self._shift = shift
|
||||
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -754,31 +709,10 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
|
||||
|
||||
if sample.device.type == "mps" and torch.is_floating_point(timestep):
|
||||
# mps does not support float64
|
||||
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
|
||||
timestep = timestep.to(sample.device, dtype=torch.float32)
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(sample.device)
|
||||
timestep = timestep.to(sample.device)
|
||||
|
||||
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
|
||||
if self.begin_index is None:
|
||||
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
|
||||
elif self.step_index is not None:
|
||||
# add_noise is called after first denoising step (for inpainting)
|
||||
step_indices = [self.step_index] * timestep.shape[0]
|
||||
else:
|
||||
# add noise is called before first denoising step to create initial latent(img2img)
|
||||
step_indices = [self.begin_index] * timestep.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(sample.shape):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sample = sigma * noise + (1.0 - sigma) * sample
|
||||
|
||||
return sample
|
||||
@@ -786,37 +720,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
|
||||
value.
|
||||
|
||||
Reference:
|
||||
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor`):
|
||||
A tensor of timesteps to be stretched and shifted.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
|
||||
"""
|
||||
one_minus_z = 1 - t
|
||||
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
|
||||
stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
return stretched_t
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
mu: Optional[float] = None,
|
||||
):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -826,49 +730,18 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
if self.config.use_dynamic_shifting and mu is None:
|
||||
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
|
||||
|
||||
if sigmas is None:
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
else:
|
||||
sigmas = np.array(sigmas).astype(np.float32)
|
||||
num_inference_steps = len(sigmas)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
if self.config.use_dynamic_shifting:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
|
||||
|
||||
if self.config.shift_terminal:
|
||||
sigmas = self.stretch_shift_to_terminal(sigmas)
|
||||
|
||||
if self.config.use_karras_sigmas:
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
elif self.config.use_exponential_sigmas:
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
|
||||
elif self.config.use_beta_sigmas:
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
timesteps = np.linspace(self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
|
||||
if self.config.invert_sigmas:
|
||||
sigmas = 1.0 - sigmas
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
|
||||
else:
|
||||
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = sigmas
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
@@ -934,11 +807,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
@@ -954,10 +823,30 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample.to(torch.float32)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sigma_next = self.sigmas[self.step_index + 1]
|
||||
|
||||
prev_sample = sample + (sigma_next - sigma) * model_output
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
|
||||
noise = randn_tensor(model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator)
|
||||
|
||||
eps = noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
if gamma > 0:
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
||||
# backwards compatibility
|
||||
|
||||
# if self.config.prediction_type == "vector_field":
|
||||
|
||||
denoised = sample - model_output * sigma
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - denoised) / sigma_hat
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
@@ -969,86 +858,6 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
rho = 7.0 # 7.0 is the value used in the paper
|
||||
ramp = np.linspace(0, 1, num_inference_steps)
|
||||
min_inv_rho = sigma_min ** (1 / rho)
|
||||
max_inv_rho = sigma_max ** (1 / rho)
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_exponential
|
||||
def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
"""Constructs an exponential noise schedule."""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = np.exp(np.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps))
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
|
||||
def _convert_to_beta(
|
||||
self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
|
||||
) -> torch.Tensor:
|
||||
"""From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et. al, 2024)"""
|
||||
|
||||
# Hack to make sure that other schedulers which copy this function don't break
|
||||
# TODO: Add this logic to the other schedulers
|
||||
if hasattr(self.config, "sigma_min"):
|
||||
sigma_min = self.config.sigma_min
|
||||
else:
|
||||
sigma_min = None
|
||||
|
||||
if hasattr(self.config, "sigma_max"):
|
||||
sigma_max = self.config.sigma_max
|
||||
else:
|
||||
sigma_max = None
|
||||
|
||||
sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
|
||||
sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
|
||||
|
||||
sigmas = np.array(
|
||||
[
|
||||
sigma_min + (ppf * (sigma_max - sigma_min))
|
||||
for ppf in [
|
||||
scipy.stats.beta.ppf(timestep, alpha, beta)
|
||||
for timestep in 1 - np.linspace(0, 1, num_inference_steps)
|
||||
]
|
||||
]
|
||||
)
|
||||
return sigmas
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
|
||||
@@ -1070,8 +1070,11 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
|
||||
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
|
||||
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
if len(img_ar_errors) == 0:
|
||||
mean_img_ar_error = 0 # avoid NaN
|
||||
else:
|
||||
img_ar_errors = np.array(img_ar_errors)
|
||||
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
|
||||
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
|
||||
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
|
||||
|
||||
@@ -5520,6 +5523,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
|
||||
|
||||
|
||||
def patch_accelerator_for_fp16_training(accelerator):
|
||||
|
||||
from accelerate import DistributedType
|
||||
if accelerator.distributed_type == DistributedType.DEEPSPEED:
|
||||
return
|
||||
|
||||
org_unscale_grads = accelerator.scaler._unscale_grads_
|
||||
|
||||
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
|
||||
@@ -6203,6 +6211,11 @@ def line_to_prompt_dict(line: str) -> dict:
|
||||
prompt_dict["scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
|
||||
if m: # guidance scale
|
||||
prompt_dict["guidance_scale"] = float(m.group(1))
|
||||
continue
|
||||
|
||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||
if m: # negative prompt
|
||||
prompt_dict["negative_prompt"] = m.group(1)
|
||||
|
||||
@@ -955,26 +955,26 @@ class LoRANetwork(torch.nn.Module):
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
lora.update_grad_norms()
|
||||
|
||||
def grad_norms(self) -> Tensor:
|
||||
def grad_norms(self) -> Tensor | None:
|
||||
grad_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
|
||||
grad_norms.append(lora.grad_norms.mean(dim=0))
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
|
||||
|
||||
def weight_norms(self) -> Tensor:
|
||||
def weight_norms(self) -> Tensor | None:
|
||||
weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
|
||||
weight_norms.append(lora.weight_norms.mean(dim=0))
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
|
||||
|
||||
def combined_weight_norms(self) -> Tensor:
|
||||
def combined_weight_norms(self) -> Tensor | None:
|
||||
combined_weight_norms = []
|
||||
for lora in self.text_encoder_loras + self.unet_loras:
|
||||
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
|
||||
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else torch.tensor([])
|
||||
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
|
||||
|
||||
|
||||
def load_weights(self, file):
|
||||
|
||||
@@ -6,3 +6,4 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::UserWarning
|
||||
ignore::FutureWarning
|
||||
pythonpath = .
|
||||
|
||||
220
tests/library/test_flux_train_utils.py
Normal file
220
tests/library/test_flux_train_utils.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import pytest
|
||||
import torch
|
||||
from unittest.mock import MagicMock, patch
|
||||
from library.flux_train_utils import (
|
||||
get_noisy_model_input_and_timesteps,
|
||||
)
|
||||
|
||||
# Mock classes and functions
|
||||
class MockNoiseScheduler:
|
||||
def __init__(self, num_train_timesteps=1000):
|
||||
self.config = MagicMock()
|
||||
self.config.num_train_timesteps = num_train_timesteps
|
||||
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
|
||||
|
||||
|
||||
# Create fixtures for commonly used objects
|
||||
@pytest.fixture
|
||||
def args():
|
||||
args = MagicMock()
|
||||
args.timestep_sampling = "uniform"
|
||||
args.weighting_scheme = "uniform"
|
||||
args.logit_mean = 0.0
|
||||
args.logit_std = 1.0
|
||||
args.mode_scale = 1.0
|
||||
args.sigmoid_scale = 1.0
|
||||
args.discrete_flow_shift = 3.1582
|
||||
args.ip_noise_gamma = None
|
||||
args.ip_noise_gamma_random_strength = False
|
||||
return args
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_scheduler():
|
||||
return MockNoiseScheduler(num_train_timesteps=1000)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def latents():
|
||||
return torch.randn(2, 4, 8, 8)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise():
|
||||
return torch.randn(2, 4, 8, 8)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def device():
|
||||
# return "cuda" if torch.cuda.is_available() else "cpu"
|
||||
return "cpu"
|
||||
|
||||
|
||||
# Mock the required functions
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_functions():
|
||||
with (
|
||||
patch("torch.sigmoid", side_effect=torch.sigmoid),
|
||||
patch("torch.rand", side_effect=torch.rand),
|
||||
patch("torch.randn", side_effect=torch.randn),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
# Test different timestep sampling methods
|
||||
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "uniform"
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
assert noisy_input.dtype == dtype
|
||||
assert timesteps.dtype == dtype
|
||||
|
||||
|
||||
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "sigmoid"
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "shift"
|
||||
args.sigmoid_scale = 1.0
|
||||
args.discrete_flow_shift = 3.1582
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
|
||||
args.timestep_sampling = "flux_shift"
|
||||
args.sigmoid_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
|
||||
# Mock the necessary functions for this specific test
|
||||
with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
|
||||
return_value=torch.tensor([0.3, 0.7], device=device)), \
|
||||
patch("library.flux_train_utils.get_sigmas",
|
||||
return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
|
||||
|
||||
args.timestep_sampling = "other" # Will trigger the weighting scheme path
|
||||
args.weighting_scheme = "uniform"
|
||||
args.logit_mean = 0.0
|
||||
args.logit_std = 1.0
|
||||
args.mode_scale = 1.0
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
# Test IP noise options
|
||||
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma = 0.5
|
||||
args.ip_noise_gamma_random_strength = False
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
|
||||
args.ip_noise_gamma = 0.1
|
||||
args.ip_noise_gamma_random_strength = True
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (latents.shape[0],)
|
||||
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
|
||||
|
||||
|
||||
# Test different data types
|
||||
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
|
||||
dtype = torch.float16
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.dtype == dtype
|
||||
assert timesteps.dtype == dtype
|
||||
|
||||
|
||||
# Test different batch sizes
|
||||
def test_different_batch_size(args, noise_scheduler, device):
|
||||
latents = torch.randn(5, 4, 8, 8) # batch size of 5
|
||||
noise = torch.randn(5, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (5,)
|
||||
assert sigmas.shape == (5, 1, 1, 1)
|
||||
|
||||
|
||||
# Test different image sizes
|
||||
def test_different_image_size(args, noise_scheduler, device):
|
||||
latents = torch.randn(2, 4, 16, 16) # larger image size
|
||||
noise = torch.randn(2, 4, 16, 16)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
assert sigmas.shape == (2, 1, 1, 1)
|
||||
|
||||
|
||||
# Test edge cases
|
||||
def test_zero_batch_size(args, noise_scheduler, device):
|
||||
with pytest.raises(AssertionError): # expecting an error with zero batch size
|
||||
latents = torch.randn(0, 4, 8, 8)
|
||||
noise = torch.randn(0, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
|
||||
def test_different_timestep_count(args, device):
|
||||
noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count
|
||||
latents = torch.randn(2, 4, 8, 8)
|
||||
noise = torch.randn(2, 4, 8, 8)
|
||||
dtype = torch.float32
|
||||
|
||||
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
|
||||
|
||||
assert noisy_input.shape == latents.shape
|
||||
assert timesteps.shape == (2,)
|
||||
# Check that timesteps are within the proper range
|
||||
assert torch.all(timesteps < 500)
|
||||
295
tests/library/test_lumina_models.py
Normal file
295
tests/library/test_lumina_models.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from library.lumina_models import (
|
||||
LuminaParams,
|
||||
to_cuda,
|
||||
to_cpu,
|
||||
RopeEmbedder,
|
||||
TimestepEmbedder,
|
||||
modulate,
|
||||
NextDiT,
|
||||
)
|
||||
|
||||
cuda_required = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
|
||||
|
||||
def test_lumina_params():
|
||||
# Test default configuration
|
||||
default_params = LuminaParams()
|
||||
assert default_params.patch_size == 2
|
||||
assert default_params.in_channels == 4
|
||||
assert default_params.axes_dims == [36, 36, 36]
|
||||
assert default_params.axes_lens == [300, 512, 512]
|
||||
|
||||
# Test 2B config
|
||||
config_2b = LuminaParams.get_2b_config()
|
||||
assert config_2b.dim == 2304
|
||||
assert config_2b.in_channels == 16
|
||||
assert config_2b.n_layers == 26
|
||||
assert config_2b.n_heads == 24
|
||||
assert config_2b.cap_feat_dim == 2304
|
||||
|
||||
# Test 7B config
|
||||
config_7b = LuminaParams.get_7b_config()
|
||||
assert config_7b.dim == 4096
|
||||
assert config_7b.n_layers == 32
|
||||
assert config_7b.n_heads == 32
|
||||
assert config_7b.axes_dims == [64, 64, 64]
|
||||
|
||||
|
||||
@cuda_required
|
||||
def test_to_cuda_to_cpu():
|
||||
# Test tensor conversion
|
||||
x = torch.tensor([1, 2, 3])
|
||||
x_cuda = to_cuda(x)
|
||||
x_cpu = to_cpu(x_cuda)
|
||||
assert x.cpu().tolist() == x_cpu.tolist()
|
||||
|
||||
# Test list conversion
|
||||
list_data = [torch.tensor([1]), torch.tensor([2])]
|
||||
list_cuda = to_cuda(list_data)
|
||||
assert all(tensor.device.type == "cuda" for tensor in list_cuda)
|
||||
|
||||
list_cpu = to_cpu(list_cuda)
|
||||
assert all(not tensor.device.type == "cuda" for tensor in list_cpu)
|
||||
|
||||
# Test dict conversion
|
||||
dict_data = {"a": torch.tensor([1]), "b": torch.tensor([2])}
|
||||
dict_cuda = to_cuda(dict_data)
|
||||
assert all(tensor.device.type == "cuda" for tensor in dict_cuda.values())
|
||||
|
||||
dict_cpu = to_cpu(dict_cuda)
|
||||
assert all(not tensor.device.type == "cuda" for tensor in dict_cpu.values())
|
||||
|
||||
|
||||
def test_timestep_embedder():
|
||||
# Test initialization
|
||||
hidden_size = 256
|
||||
freq_emb_size = 128
|
||||
embedder = TimestepEmbedder(hidden_size, freq_emb_size)
|
||||
assert embedder.frequency_embedding_size == freq_emb_size
|
||||
|
||||
# Test timestep embedding
|
||||
t = torch.tensor([0.5, 1.0, 2.0])
|
||||
emb_dim = freq_emb_size
|
||||
embeddings = TimestepEmbedder.timestep_embedding(t, emb_dim)
|
||||
|
||||
assert embeddings.shape == (3, emb_dim)
|
||||
assert embeddings.dtype == torch.float32
|
||||
|
||||
# Ensure embeddings are unique for different input times
|
||||
assert not torch.allclose(embeddings[0], embeddings[1])
|
||||
|
||||
# Test forward pass
|
||||
t_emb = embedder(t)
|
||||
assert t_emb.shape == (3, hidden_size)
|
||||
|
||||
|
||||
def test_rope_embedder_simple():
|
||||
rope_embedder = RopeEmbedder()
|
||||
batch_size, seq_len = 2, 10
|
||||
|
||||
# Create position_ids with valid ranges for each axis
|
||||
position_ids = torch.stack(
|
||||
[
|
||||
torch.zeros(batch_size, seq_len, dtype=torch.int64), # First axis: only 0 is valid
|
||||
torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Second axis: 0-511
|
||||
torch.randint(0, 512, (batch_size, seq_len), dtype=torch.int64), # Third axis: 0-511
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
freqs_cis = rope_embedder(position_ids)
|
||||
# RoPE embeddings work in pairs, so output dimension is half of total axes_dims
|
||||
expected_dim = sum(rope_embedder.axes_dims) // 2 # 128 // 2 = 64
|
||||
assert freqs_cis.shape == (batch_size, seq_len, expected_dim)
|
||||
|
||||
|
||||
def test_modulate():
|
||||
# Test modulation with different scales
|
||||
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||
scale = torch.tensor([1.5, 2.0])
|
||||
|
||||
modulated_x = modulate(x, scale)
|
||||
|
||||
# Check that modulation scales correctly
|
||||
# The function does x * (1 + scale), so:
|
||||
# For scale [1.5, 2.0], (1 + scale) = [2.5, 3.0]
|
||||
expected_x = torch.tensor([[2.5 * 1.0, 2.5 * 2.0], [3.0 * 3.0, 3.0 * 4.0]])
|
||||
# Which equals: [[2.5, 5.0], [9.0, 12.0]]
|
||||
|
||||
assert torch.allclose(modulated_x, expected_x)
|
||||
|
||||
|
||||
def test_nextdit_parameter_count_optimized():
|
||||
# The constraint is: (dim // n_heads) == sum(axes_dims)
|
||||
# So for dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
|
||||
model_small = NextDiT(
|
||||
patch_size=2,
|
||||
in_channels=4, # Smaller
|
||||
dim=120, # 120 // 4 = 30
|
||||
n_layers=2, # Much fewer layers
|
||||
n_heads=4, # Fewer heads
|
||||
n_kv_heads=2,
|
||||
axes_dims=[10, 10, 10], # sum = 30
|
||||
axes_lens=[10, 32, 32], # Smaller
|
||||
)
|
||||
param_count_small = model_small.parameter_count()
|
||||
assert param_count_small > 0
|
||||
|
||||
# For dim=192, n_heads=6: 192//6 = 32, so sum(axes_dims) must = 32
|
||||
model_medium = NextDiT(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
dim=192, # 192 // 6 = 32
|
||||
n_layers=4, # More layers
|
||||
n_heads=6,
|
||||
n_kv_heads=3,
|
||||
axes_dims=[10, 11, 11], # sum = 32
|
||||
axes_lens=[10, 32, 32],
|
||||
)
|
||||
param_count_medium = model_medium.parameter_count()
|
||||
assert param_count_medium > param_count_small
|
||||
print(f"Small model: {param_count_small:,} parameters")
|
||||
print(f"Medium model: {param_count_medium:,} parameters")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_precompute_freqs_cis():
|
||||
# Test precompute_freqs_cis
|
||||
dim = [16, 56, 56]
|
||||
end = [1, 512, 512]
|
||||
theta = 10000.0
|
||||
|
||||
freqs_cis = NextDiT.precompute_freqs_cis(dim, end, theta)
|
||||
|
||||
# Check number of frequency tensors
|
||||
assert len(freqs_cis) == len(dim)
|
||||
|
||||
# Check each frequency tensor
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
assert freqs_cis[i].shape == (e, d // 2)
|
||||
assert freqs_cis[i].dtype == torch.complex128
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def test_nextdit_patchify_and_embed():
|
||||
"""Test the patchify_and_embed method which is crucial for training"""
|
||||
# Create a small NextDiT model for testing
|
||||
# The constraint is: (dim // n_heads) == sum(axes_dims)
|
||||
# For dim=120, n_heads=4: 120//4 = 30, so sum(axes_dims) must = 30
|
||||
model = NextDiT(
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
dim=120, # 120 // 4 = 30
|
||||
n_layers=1, # Minimal layers for faster testing
|
||||
n_refiner_layers=1, # Minimal refiner layers
|
||||
n_heads=4,
|
||||
n_kv_heads=2,
|
||||
axes_dims=[10, 10, 10], # sum = 30
|
||||
axes_lens=[10, 32, 32],
|
||||
cap_feat_dim=120, # Match dim for consistency
|
||||
)
|
||||
|
||||
# Prepare test inputs
|
||||
batch_size = 2
|
||||
height, width = 64, 64 # Must be divisible by patch_size (2)
|
||||
caption_seq_len = 8
|
||||
|
||||
# Create mock inputs
|
||||
x = torch.randn(batch_size, 4, height, width) # Image latents
|
||||
cap_feats = torch.randn(batch_size, caption_seq_len, 120) # Caption features
|
||||
cap_mask = torch.ones(batch_size, caption_seq_len, dtype=torch.bool) # All valid tokens
|
||||
# Make second batch have shorter caption
|
||||
cap_mask[1, 6:] = False # Only first 6 tokens are valid for second batch
|
||||
t = torch.randn(batch_size, 120) # Timestep embeddings
|
||||
|
||||
# Call patchify_and_embed
|
||||
joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
|
||||
x, cap_feats, cap_mask, t
|
||||
)
|
||||
|
||||
# Validate outputs
|
||||
image_seq_len = (height // 2) * (width // 2) # patch_size = 2
|
||||
expected_seq_lengths = [caption_seq_len + image_seq_len, 6 + image_seq_len] # Second batch has shorter caption
|
||||
max_seq_len = max(expected_seq_lengths)
|
||||
|
||||
# Check joint hidden states shape
|
||||
assert joint_hidden_states.shape == (batch_size, max_seq_len, 120)
|
||||
assert joint_hidden_states.dtype == torch.float32
|
||||
|
||||
# Check attention mask shape and values
|
||||
assert attention_mask.shape == (batch_size, max_seq_len)
|
||||
assert attention_mask.dtype == torch.bool
|
||||
# First batch should have all positions valid up to its sequence length
|
||||
assert torch.all(attention_mask[0, : expected_seq_lengths[0]])
|
||||
assert torch.all(~attention_mask[0, expected_seq_lengths[0] :])
|
||||
# Second batch should have all positions valid up to its sequence length
|
||||
assert torch.all(attention_mask[1, : expected_seq_lengths[1]])
|
||||
assert torch.all(~attention_mask[1, expected_seq_lengths[1] :])
|
||||
|
||||
# Check freqs_cis shape
|
||||
assert freqs_cis.shape == (batch_size, max_seq_len, sum(model.axes_dims) // 2)
|
||||
|
||||
# Check effective caption lengths
|
||||
assert l_effective_cap_len == [caption_seq_len, 6]
|
||||
|
||||
# Check sequence lengths
|
||||
assert seq_lengths == expected_seq_lengths
|
||||
|
||||
# Validate that the joint hidden states contain non-zero values where attention mask is True
|
||||
for i in range(batch_size):
|
||||
valid_positions = attention_mask[i]
|
||||
# Check that valid positions have meaningful data (not all zeros)
|
||||
valid_data = joint_hidden_states[i][valid_positions]
|
||||
assert not torch.allclose(valid_data, torch.zeros_like(valid_data))
|
||||
|
||||
# Check that invalid positions are zeros
|
||||
if valid_positions.sum() < max_seq_len:
|
||||
invalid_data = joint_hidden_states[i][~valid_positions]
|
||||
assert torch.allclose(invalid_data, torch.zeros_like(invalid_data))
|
||||
|
||||
|
||||
@torch.no_grad()
def test_nextdit_patchify_and_embed_edge_cases():
    """Edge-case test: patchify_and_embed with a fully masked caption.

    With every caption token masked out, the joint sequence should consist of
    image patch tokens only — caption tokens must not contribute to the
    effective lengths, the joint hidden states, or the attention mask.
    """
    # Create a minimal model so the test stays fast on CPU.
    model = NextDiT(
        patch_size=2,
        in_channels=4,
        dim=60,  # 60 // 3 = 20 per head (n_heads=3)
        n_layers=1,
        n_refiner_layers=1,
        n_heads=3,
        n_kv_heads=1,
        axes_dims=[8, 6, 6],  # sum = 20, must match dim // n_heads
        axes_lens=[10, 16, 16],
        cap_feat_dim=60,
    )

    # Test with empty captions (all tokens masked).
    batch_size = 1
    height, width = 32, 32
    caption_seq_len = 4

    x = torch.randn(batch_size, 4, height, width)
    cap_feats = torch.randn(batch_size, caption_seq_len, 60)
    cap_mask = torch.zeros(batch_size, caption_seq_len, dtype=torch.bool)  # All tokens masked
    t = torch.randn(batch_size, 60)

    joint_hidden_states, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths = model.patchify_and_embed(
        x, cap_feats, cap_mask, t
    )

    # With all captions masked, the effective caption length should be 0.
    assert l_effective_cap_len == [0]

    # Sequence length should just be the image sequence length
    # (patch_size=2 => one token per 2x2 latent patch).
    image_seq_len = (height // 2) * (width // 2)
    assert seq_lengths == [image_seq_len]

    # Joint hidden states should only contain image data.
    assert joint_hidden_states.shape == (batch_size, image_seq_len, 60)
    assert attention_mask.shape == (batch_size, image_seq_len)
    assert torch.all(attention_mask[0])  # All image positions should be valid
|
||||
241
tests/library/test_lumina_train_util.py
Normal file
241
tests/library/test_lumina_train_util.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import pytest
|
||||
import torch
|
||||
import math
|
||||
|
||||
from library.lumina_train_util import (
|
||||
batchify,
|
||||
time_shift,
|
||||
get_lin_function,
|
||||
get_schedule,
|
||||
compute_density_for_timestep_sampling,
|
||||
get_sigmas,
|
||||
compute_loss_weighting_for_sd3,
|
||||
get_noisy_model_input_and_timesteps,
|
||||
apply_model_prediction_type,
|
||||
retrieve_timesteps,
|
||||
)
|
||||
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler
|
||||
|
||||
|
||||
def test_batchify():
    """batchify groups sample prompts, splitting on batch size and parameter changes."""
    simple_prompts = [{"prompt": "test1"}, {"prompt": "test2"}, {"prompt": "test3"}]

    # Without a batch size, prompts with identical parameters land in one batch.
    all_in_one = list(batchify(simple_prompts))
    assert len(all_in_one) == 1
    assert len(all_in_one[0]) == 3

    # An explicit batch size caps each batch; the remainder forms a final batch.
    capped = list(batchify(simple_prompts, batch_size=2))
    assert len(capped) == 2
    assert [len(b) for b in capped] == [2, 1]

    # Prompts whose generation parameters differ must not share a batch.
    mixed_param_prompts = [
        {"prompt": "test1", "width": 512, "height": 512},
        {"prompt": "test2", "width": 512, "height": 512},
        {"prompt": "test3", "width": 1024, "height": 1024},
    ]
    assert len(list(batchify(mixed_param_prompts))) == 2

    # Non-positive batch sizes are rejected.
    for bad_size in (0, -1):
        with pytest.raises(ValueError):
            list(batchify(simple_prompts, batch_size=bad_size))
|
||||
|
||||
|
||||
def test_time_shift():
    """time_shift maps timesteps in [0, 1] back into [0, 1]."""
    # A mid-range timestep stays within the unit interval.
    shifted = time_shift(1.0, 1.0, torch.tensor([0.5]))
    assert 0 <= shifted <= 1

    # The interval endpoints must also map into [0, 1].
    endpoints = time_shift(1.0, 1.0, torch.tensor([0.0, 1.0]))
    assert torch.all(endpoints >= 0)
    assert torch.all(endpoints <= 1)
|
||||
|
||||
|
||||
def test_get_lin_function():
    """get_lin_function returns a linear interpolator through two anchor points."""
    # Default anchors: (256, 0.5) and (4096, 1.15).
    default_fn = get_lin_function()
    assert default_fn(256) == 0.5
    assert default_fn(4096) == 1.15

    # Custom anchor points are honored exactly at both ends.
    custom_fn = get_lin_function(x1=100, x2=1000, y1=0.1, y2=0.9)
    assert custom_fn(100) == 0.1
    assert custom_fn(1000) == 0.9
|
||||
|
||||
|
||||
def test_get_schedule():
    """get_schedule produces a sigma schedule of the requested length in [0, 1]."""
    # The requested step count determines schedule length.
    schedule = get_schedule(num_steps=10, image_seq_len=256)
    assert len(schedule) == 10
    assert all(0 <= sigma <= 1 for sigma in schedule)

    # Length tracks num_steps regardless of image sequence length.
    assert len(get_schedule(num_steps=5, image_seq_len=128)) == 5
    assert len(get_schedule(num_steps=15, image_seq_len=1024)) == 15

    # With shifting disabled, the schedule is a plain linspace from 1 down to 1/steps.
    plain = get_schedule(num_steps=10, image_seq_len=256, shift=False)
    assert torch.allclose(torch.tensor(plain), torch.linspace(1, 1 / 10, 10))
|
||||
|
||||
|
||||
def test_compute_density_for_timestep_sampling():
    """All supported weighting schemes must yield 100 samples inside [0, 1]."""

    def _assert_unit_interval(samples):
        # Shared validation: correct count and every value within [0, 1].
        assert len(samples) == 100
        assert torch.all((samples >= 0) & (samples <= 1))

    # Uniform sampling.
    _assert_unit_interval(compute_density_for_timestep_sampling("uniform", batch_size=100))

    # Logit-normal sampling.
    _assert_unit_interval(
        compute_density_for_timestep_sampling("logit_normal", batch_size=100, logit_mean=0.0, logit_std=1.0)
    )

    # Mode-weighted sampling.
    _assert_unit_interval(compute_density_for_timestep_sampling("mode", batch_size=100, mode_scale=0.5))
|
||||
|
||||
|
||||
def test_get_sigmas():
    """get_sigmas resolves timesteps to sigmas with the requested rank and dtype."""
    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    cpu = torch.device("cpu")
    steps = torch.tensor([100, 500, 900])

    # One non-negative sigma per timestep.
    base_sigmas = get_sigmas(noise_scheduler, steps, cpu)
    assert base_sigmas.shape[0] == 3
    assert torch.all(base_sigmas >= 0)

    # n_dim controls the rank of the returned tensor.
    assert get_sigmas(noise_scheduler, steps, cpu, n_dim=4).ndim == 4

    # The requested dtype is propagated to the result.
    assert get_sigmas(noise_scheduler, steps, cpu, dtype=torch.float16).dtype == torch.float16
|
||||
|
||||
|
||||
def test_compute_loss_weighting_for_sd3():
    """Check the loss-weighting formula for each SD3 weighting scheme."""
    sigmas = torch.tensor([0.1, 0.5, 1.0])

    # sigma_sqrt: weight = 1 / sigma^2.
    sqrt_weights = compute_loss_weighting_for_sd3("sigma_sqrt", sigmas)
    assert torch.allclose(sqrt_weights, 1 / (sigmas**2), rtol=1e-5)

    # cosmap: weight = 2 / (pi * (1 - 2*sigma + 2*sigma^2)).
    denominator = 1 - 2 * sigmas + 2 * sigmas**2
    cosmap_weights = compute_loss_weighting_for_sd3("cosmap", sigmas)
    assert torch.allclose(cosmap_weights, 2 / (math.pi * denominator), rtol=1e-5)

    # Any unrecognized scheme falls back to uniform weighting of 1.
    assert torch.all(compute_loss_weighting_for_sd3("unknown", sigmas) == 1)
|
||||
|
||||
|
||||
def test_apply_model_prediction_type():
    """Each model_prediction_type transforms the raw model prediction differently."""

    class _Args:
        # Minimal stand-in for the parsed CLI namespace.
        model_prediction_type = "raw"
        weighting_scheme = "sigma_sqrt"

    cfg = _Args()
    prediction = torch.tensor([1.0, 2.0, 3.0])
    noisy_input = torch.tensor([0.5, 1.0, 1.5])
    sigmas = torch.tensor([0.1, 0.5, 1.0])

    # "raw": the prediction passes through untouched and no weighting is produced.
    out, weighting = apply_model_prediction_type(cfg, prediction, noisy_input, sigmas)
    assert torch.all(out == prediction)
    assert weighting is None

    # "additive": the noisy input is added back onto the prediction.
    cfg.model_prediction_type = "additive"
    out, _ = apply_model_prediction_type(cfg, prediction, noisy_input, sigmas)
    assert torch.all(out == prediction + noisy_input)

    # "sigma_scaled": prediction is scaled by -sigma, offset by the noisy input,
    # and a loss weighting is returned.
    cfg.model_prediction_type = "sigma_scaled"
    out, weighting = apply_model_prediction_type(cfg, prediction, noisy_input, sigmas)
    assert torch.all(out == prediction * (-sigmas) + noisy_input)
    assert weighting is not None
|
||||
|
||||
|
||||
def test_retrieve_timesteps():
    """retrieve_timesteps builds a schedule and rejects conflicting inputs."""
    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)

    # num_inference_steps drives both the schedule length and the returned count.
    schedule, step_count = retrieve_timesteps(noise_scheduler, num_inference_steps=50)
    assert len(schedule) == 50
    assert step_count == 50

    # Passing explicit timesteps AND sigmas together is ambiguous -> error.
    with pytest.raises(ValueError):
        retrieve_timesteps(noise_scheduler, timesteps=[1, 2, 3], sigmas=[0.1, 0.2, 0.3])
|
||||
|
||||
|
||||
def test_get_noisy_model_input_and_timesteps():
    """The noising helper returns tensors shaped like the latents for every sampler."""

    class _Args:
        # Minimal stand-in for the parsed CLI namespace.
        timestep_sampling = "uniform"
        weighting_scheme = "sigma_sqrt"
        sigmoid_scale = 1.0
        discrete_flow_shift = 6.0

    cfg = _Args()
    noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
    cpu = torch.device("cpu")

    # Mock latents and matching Gaussian noise.
    clean_latents = torch.randn(4, 16, 64, 64)
    gaussian_noise = torch.randn_like(clean_latents)

    # Uniform sampling: output matches latent shape/dtype, one timestep per sample.
    noised, timesteps, sigmas = get_noisy_model_input_and_timesteps(
        cfg, noise_scheduler, clean_latents, gaussian_noise, cpu, torch.float32
    )
    assert noised.shape == clean_latents.shape
    assert timesteps.shape[0] == clean_latents.shape[0]
    assert noised.dtype == torch.float32
    assert timesteps.dtype == torch.float32

    # The remaining sampling modes must satisfy the same shape contract.
    for sampler in ("sigmoid", "shift", "nextdit_shift"):
        cfg.timestep_sampling = sampler
        noised, timesteps, _ = get_noisy_model_input_and_timesteps(
            cfg, noise_scheduler, clean_latents, gaussian_noise, cpu, torch.float32
        )
        assert noised.shape == clean_latents.shape
        assert timesteps.shape[0] == clean_latents.shape[0]
|
||||
112
tests/library/test_lumina_util.py
Normal file
112
tests/library/test_lumina_util.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from torch.nn.modules import conv
|
||||
|
||||
from library import lumina_util
|
||||
|
||||
|
||||
def test_unpack_latents():
    """unpack_latents restores packed patch tokens to a spatial latent layout."""
    # 2 samples, 2x2 packed grid (4 tokens), 16 = 4 channels * 2 * 2 patch.
    packed = torch.randn(2, 4, 16)

    spatial = lumina_util.unpack_latents(packed, 2, 2)

    # Expected layout: [batch, channels, height * patch_h, width * patch_w].
    assert spatial.shape == (2, 4, 4, 4)
|
||||
|
||||
|
||||
def test_pack_latents():
    """pack_latents flattens a spatial latent into patch tokens."""
    # Input layout: [batch, channels, height * patch_h, width * patch_w].
    spatial = torch.randn(2, 4, 4, 4)

    tokens = lumina_util.pack_latents(spatial)

    # Expected layout: [batch, height * width, channels * patch_h * patch_w].
    assert tokens.shape == (2, 4, 16)
|
||||
|
||||
|
||||
def test_convert_diffusers_sd_to_alpha_vllm():
    """Spot-check key renaming in convert_diffusers_sd_to_alpha_vllm.

    Feeds single-key state dicts through the converter and verifies that both
    the renamed key and the tensor contents survive the mapping.

    Fixes: removed a leftover debug print() and dead commented-out code.
    """
    num_double_blocks = 2
    # Predefined test cases based on the actual conversion map.
    test_cases = [
        # Static key conversions with possible list mappings
        {
            "original_keys": ["time_caption_embed.caption_embedder.0.weight"],
            "original_pattern": ["time_caption_embed.caption_embedder.0.weight"],
            "expected_converted_keys": ["cap_embedder.0.weight"],
        },
        {
            "original_keys": ["patch_embedder.proj.weight"],
            "original_pattern": ["patch_embedder.proj.weight"],
            "expected_converted_keys": ["x_embedder.weight"],
        },
        {
            "original_keys": ["transformer_blocks.0.norm1.weight"],
            "original_pattern": ["transformer_blocks.().norm1.weight"],
            "expected_converted_keys": ["layers.0.attention_norm1.weight"],
        },
    ]

    for test_case in test_cases:
        for original_key, original_pattern, expected_converted_key in zip(
            test_case["original_keys"], test_case["original_pattern"], test_case["expected_converted_keys"]
        ):
            # Create a test state dict with a single source key.
            test_sd = {original_key: torch.randn(10, 10)}

            # Convert the state dict.
            converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)

            match_found = False
            if expected_converted_key in converted_sd:
                # The tensor must be carried over unchanged.
                assert torch.allclose(converted_sd[expected_converted_key], test_sd[original_key], atol=1e-6), (
                    f"Tensor mismatch for {original_key}"
                )
                match_found = True
                break

            # Only reached when the expected key was absent -> fail with context.
            assert match_found, f"Failed to convert {original_key}"

        # NOTE(review): asserts the source key also remains in the converted
        # dict -- presumably the converter keeps original/unmatched entries;
        # confirm against convert_diffusers_sd_to_alpha_vllm.
        assert original_key in converted_sd

    # Block-indexed keys: the "()" placeholder expands to each block index.
    block_specific_cases = [
        {
            "original_pattern": "transformer_blocks.().norm1.weight",
            "converted_pattern": "layers.().attention_norm1.weight",
        }
    ]

    for case in block_specific_cases:
        for block_idx in range(2):  # Test multiple block indices
            # Prepare block-specific keys by filling in the index placeholder.
            block_original_key = case["original_pattern"].replace("()", str(block_idx))
            block_converted_key = case["converted_pattern"].replace("()", str(block_idx))

            # Create a test state dict and convert it.
            test_sd = {block_original_key: torch.randn(10, 10)}
            converted_sd = lumina_util.convert_diffusers_sd_to_alpha_vllm(test_sd, num_double_blocks)

            # The renamed key must hold the untouched tensor (a missing key
            # raises KeyError here, which also fails the test).
            assert torch.allclose(converted_sd[block_converted_key], test_sd[block_original_key], atol=1e-6), (
                f"Tensor mismatch for block key {block_original_key}"
            )

            # NOTE(review): as above, the source key is expected to remain too.
            assert block_original_key in converted_sd
|
||||
227
tests/library/test_strategy_lumina.py
Normal file
227
tests/library/test_strategy_lumina.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import os
|
||||
import tempfile
|
||||
import torch
|
||||
import numpy as np
|
||||
from unittest.mock import patch
|
||||
from transformers import Gemma2Model
|
||||
|
||||
from library.strategy_lumina import (
|
||||
LuminaTokenizeStrategy,
|
||||
LuminaTextEncodingStrategy,
|
||||
LuminaTextEncoderOutputsCachingStrategy,
|
||||
LuminaLatentsCachingStrategy,
|
||||
)
|
||||
|
||||
|
||||
class SimpleMockGemma2Model:
    """Lightweight stand-in that avoids initializing the real Gemma2Model.

    Produces random hidden states with the right shapes so encoding-strategy
    tests can run without loading the actual text encoder.
    """

    def __init__(self, hidden_size=2304):
        # Mimic the attributes the strategies read from a real model.
        self.device = torch.device("cpu")
        self._hidden_size = hidden_size
        # torch.compile wrappers expose the wrapped module via _orig_mod;
        # pointing it at self keeps unwrapping code paths working.
        self._orig_mod = self

    def __call__(self, input_ids, attention_mask, output_hidden_states=False, return_dict=False):
        """Return an object exposing a `hidden_states` list of random tensors."""
        n_batch, n_tokens = input_ids.shape

        class _FakeEncoderOutput:
            def __init__(self, hidden_states):
                self.hidden_states = hidden_states

        # Three layers' worth of fake activations, on the input's device.
        fake_layers = []
        for _ in range(3):
            fake_layers.append(
                torch.randn(n_batch, n_tokens, self._hidden_size, device=input_ids.device)
            )
        return _FakeEncoderOutput(fake_layers)
|
||||
|
||||
|
||||
def test_lumina_tokenize_strategy():
    """LuminaTokenizeStrategy right-pads to a default max length of 256."""
    strategy = LuminaTokenizeStrategy(max_length=None)

    # Passing None falls back to the default length and right padding.
    assert strategy.max_length == 256
    assert strategy.tokenizer.padding_side == "right"

    # A single string tokenizes to padded 2-D tensors of identical shape.
    sample_text = "Hello"
    token_ids, mask = strategy.tokenize(sample_text)
    assert token_ids.ndim == 2
    assert mask.ndim == 2
    assert token_ids.shape == mask.shape
    assert token_ids.shape[1] == 256  # padded out to max_length

    # Weighted tokenization yields a single uniform weight tensor of ones.
    token_ids, mask, weights = strategy.tokenize_with_weights(sample_text)
    assert len(weights) == 1
    assert torch.all(weights[0] == 1)
|
||||
|
||||
|
||||
def test_lumina_text_encoding_strategy():
    """Encode tokens through LuminaTextEncodingStrategy using a mocked Gemma2.

    The strategy type-checks its model against Gemma2Model, so the module's
    isinstance is patched to accept the lightweight mock (directly or via a
    torch.compile-style `_orig_mod` wrapper).
    """
    # Create strategies
    tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
    encoding_strategy = LuminaTextEncodingStrategy()

    # Create a mock model (random hidden states, no real weights)
    mock_model = SimpleMockGemma2Model()

    # Patch the isinstance check to accept our simple mock.
    # Capture the builtin BEFORE patching so the fallback still works.
    original_isinstance = isinstance
    with patch("library.strategy_lumina.isinstance") as mock_isinstance:

        def custom_isinstance(obj, class_or_tuple):
            # Treat the mock (or anything wrapping it via _orig_mod) as Gemma2Model.
            if obj is mock_model and class_or_tuple is Gemma2Model:
                return True
            if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
                return True
            # Everything else falls through to the real builtin.
            return original_isinstance(obj, class_or_tuple)

        mock_isinstance.side_effect = custom_isinstance

        # Prepare sample text
        text = "Test encoding strategy"
        tokens, attention_mask = tokenize_strategy.tokenize(text)

        # Perform encoding
        hidden_states, input_ids, attention_masks = encoding_strategy.encode_tokens(
            tokenize_strategy, [mock_model], (tokens, attention_mask)
        )

        # Validate outputs (use the unpatched builtin for the checks)
        assert original_isinstance(hidden_states, torch.Tensor)
        assert original_isinstance(input_ids, torch.Tensor)
        assert original_isinstance(attention_masks, torch.Tensor)

        # Check the shape of the second-to-last hidden state
        assert hidden_states.ndim == 3

        # Test weighted encoding (which falls back to standard encoding for Lumina)
        weights = [torch.ones_like(tokens)]
        hidden_states_w, input_ids_w, attention_masks_w = encoding_strategy.encode_tokens_with_weights(
            tokenize_strategy, [mock_model], (tokens, attention_mask), weights
        )

        # For the mock, we can't guarantee identical outputs since each call returns random tensors
        # Instead, check that the outputs have the same shape and are tensors
        assert hidden_states_w.shape == hidden_states.shape
        assert original_isinstance(hidden_states_w, torch.Tensor)
        assert torch.allclose(input_ids, input_ids_w)  # Input IDs should be the same
        assert torch.allclose(attention_masks, attention_masks_w)  # Attention masks should be the same
|
||||
|
||||
|
||||
def test_lumina_text_encoder_outputs_caching_strategy():
    """Cache text-encoder outputs to disk with a mocked Gemma2 and read them back."""
    # Create a temporary directory for caching
    with tempfile.TemporaryDirectory() as tmpdir:
        # Create a cache file path
        cache_file = os.path.join(tmpdir, "test_outputs.npz")

        # Create the caching strategy
        caching_strategy = LuminaTextEncoderOutputsCachingStrategy(
            cache_to_disk=True,
            batch_size=1,
            skip_disk_cache_validity_check=False,
        )

        # Create a mock class for ImageInfo (only the attributes the strategy reads)
        class MockImageInfo:
            def __init__(self, caption, system_prompt, cache_path):
                self.caption = caption
                self.system_prompt = system_prompt
                self.text_encoder_outputs_npz = cache_path

        # Create a sample input info
        image_info = MockImageInfo("Test caption", "", cache_file)

        # Simulate a batch
        batch = [image_info]

        # Create mock strategies and model
        tokenize_strategy = LuminaTokenizeStrategy(max_length=None)
        encoding_strategy = LuminaTextEncodingStrategy()
        mock_model = SimpleMockGemma2Model()

        # Patch the isinstance check to accept our simple mock.
        # Capture the builtin BEFORE patching so the fallback still works.
        original_isinstance = isinstance
        with patch("library.strategy_lumina.isinstance") as mock_isinstance:

            def custom_isinstance(obj, class_or_tuple):
                # Treat the mock (or a _orig_mod wrapper of it) as Gemma2Model.
                if obj is mock_model and class_or_tuple is Gemma2Model:
                    return True
                if hasattr(obj, "_orig_mod") and obj._orig_mod is mock_model and class_or_tuple is Gemma2Model:
                    return True
                return original_isinstance(obj, class_or_tuple)

            mock_isinstance.side_effect = custom_isinstance

            # Call cache_batch_outputs (writes the npz to disk)
            caching_strategy.cache_batch_outputs(tokenize_strategy, [mock_model], encoding_strategy, batch)

            # Verify the npz file was created
            assert os.path.exists(cache_file), f"Cache file not created at {cache_file}"

            # Verify the is_disk_cached_outputs_expected method
            assert caching_strategy.is_disk_cached_outputs_expected(cache_file)

            # Test loading from npz
            loaded_data = caching_strategy.load_outputs_npz(cache_file)
            assert len(loaded_data) == 3  # hidden_state, input_ids, attention_mask
|
||||
|
||||
|
||||
def test_lumina_latents_caching_strategy():
    """Cache VAE latents to disk with a mocked VAE and read them back."""
    # Create a temporary directory for caching
    with tempfile.TemporaryDirectory() as tmpdir:
        # Prepare a mock absolute path
        abs_path = os.path.join(tmpdir, "test_image.png")

        # Use smaller image size for faster testing
        image_size = (64, 64)

        # Create a smaller dummy image for testing
        test_image = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)

        # Create the caching strategy
        caching_strategy = LuminaLatentsCachingStrategy(cache_to_disk=True, batch_size=1, skip_disk_cache_validity_check=False)

        # Create a simple mock VAE
        class MockVAE:
            def __init__(self):
                self.device = torch.device("cpu")
                self.dtype = torch.float32

            def encode(self, x):
                # Return smaller encoded tensor for faster processing.
                encoded = torch.randn(1, 4, 8, 8, device=x.device)
                # NOTE(review): returns the class object itself (not an
                # instance); calling .to(...) on it invokes the plain lambda
                # and yields `encoded` regardless of arguments. Hacky but
                # sufficient for the caller's `encode(x).to(...)` pattern.
                return type("EncodedLatents", (), {"to": lambda *args, **kwargs: encoded})

        # Prepare a mock batch (only the attributes the strategy reads)
        class MockImageInfo:
            def __init__(self, path, image):
                self.absolute_path = path
                self.image = image
                self.image_path = path
                self.bucket_reso = image_size
                self.resized_size = image_size
                self.resize_interpolation = "lanczos"
                # Specify full path to the latents npz file
                self.latents_npz = os.path.join(tmpdir, f"{os.path.splitext(os.path.basename(path))[0]}_0064x0064_lumina.npz")

        batch = [MockImageInfo(abs_path, test_image)]

        # Call cache_batch_latents (encodes via the mock VAE and writes npz)
        mock_vae = MockVAE()
        caching_strategy.cache_batch_latents(mock_vae, batch, flip_aug=False, alpha_mask=False, random_crop=False)

        # Generate the expected npz path
        npz_path = caching_strategy.get_latents_npz_path(abs_path, image_size)

        # Verify the file was created
        assert os.path.exists(npz_path), f"NPZ file not created at {npz_path}"

        # Verify is_disk_cached_latents_expected
        assert caching_strategy.is_disk_cached_latents_expected(image_size, npz_path, False, False)

        # Test loading from disk
        loaded_data = caching_strategy.load_latents_from_disk(npz_path, image_size)
        assert len(loaded_data) == 5  # Check for 5 expected elements
|
||||
173
tests/test_lumina_train_network.py
Normal file
173
tests/test_lumina_train_network.py
Normal file
@@ -0,0 +1,173 @@
|
||||
import pytest
|
||||
import torch
|
||||
from unittest.mock import MagicMock, patch
|
||||
import argparse
|
||||
|
||||
from library import lumina_models, lumina_util
|
||||
from lumina_train_network import LuminaNetworkTrainer
|
||||
|
||||
|
||||
@pytest.fixture
def lumina_trainer():
    """Provide a fresh LuminaNetworkTrainer instance for each test."""
    return LuminaNetworkTrainer()
|
||||
|
||||
|
||||
@pytest.fixture
def mock_args():
    """Argparse-namespace stand-in with the attributes the trainer reads."""
    args = MagicMock()
    overrides = {
        "pretrained_model_name_or_path": "test_path",
        "disable_mmap_load_safetensors": False,
        "use_flash_attn": False,
        "use_sage_attn": False,
        "fp8_base": False,
        "blocks_to_swap": None,
        "gemma2": "test_gemma2_path",
        "ae": "test_ae_path",
        "cache_text_encoder_outputs": True,
        "cache_text_encoder_outputs_to_disk": False,
        "network_train_unet_only": False,
    }
    # Pin concrete values so attribute reads don't return auto-created mocks.
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)
    return args
|
||||
|
||||
|
||||
@pytest.fixture
def mock_accelerator():
    """Accelerator double: CPU device, prepare/unwrap are identity functions."""
    fake = MagicMock()
    fake.device = torch.device("cpu")
    # prepare() and unwrap_model() should hand objects back untouched.
    fake.prepare.side_effect = lambda obj, **_kwargs: obj
    fake.unwrap_model.side_effect = lambda obj: obj
    return fake
|
||||
|
||||
|
||||
def test_assert_extra_args(lumina_trainer, mock_args):
    """assert_extra_args validates both dataset groups and sets train_gemma2."""

    def _dataset_group_double():
        group = MagicMock()
        group.verify_bucket_reso_steps = MagicMock()
        return group

    train_group = _dataset_group_double()
    val_group = _dataset_group_double()

    # Run with the default mock settings.
    lumina_trainer.assert_extra_args(mock_args, train_group, val_group)

    # Both dataset groups must have had their bucket resolutions verified.
    assert train_group.verify_bucket_reso_steps.call_count > 0
    assert val_group.verify_bucket_reso_steps.call_count > 0

    # Gemma2 is trained exactly when the network is not U-Net-only, and
    # text-encoder output caching stays enabled.
    assert lumina_trainer.train_gemma2 is (not mock_args.network_train_unet_only)
    assert mock_args.cache_text_encoder_outputs is True
|
||||
|
||||
|
||||
def test_load_target_model(lumina_trainer, mock_args, mock_accelerator):
    """load_target_model loads the DiT, Gemma2 and AE via lumina_util helpers."""
    with (
        patch("library.lumina_util.load_lumina_model") as load_model_p,
        patch("library.lumina_util.load_gemma2") as load_gemma2_p,
        patch("library.lumina_util.load_ae") as load_ae_p,
    ):
        # Stub out the three heavyweight loaders with lightweight doubles.
        fake_dit = MagicMock(spec=lumina_models.NextDiT)
        fake_dit.dtype = torch.float32
        fake_gemma2 = MagicMock()
        fake_ae = MagicMock()
        load_model_p.return_value = fake_dit
        load_gemma2_p.return_value = fake_gemma2
        load_ae_p.return_value = fake_ae

        # Exercise load_target_model.
        version, text_encoders, ae, model = lumina_trainer.load_target_model(mock_args, torch.float32, mock_accelerator)

        # The trainer must report the Lumina v2 version string and hand back
        # exactly the objects the loaders produced.
        assert version == lumina_util.MODEL_VERSION_LUMINA_V2
        assert text_encoders == [fake_gemma2]
        assert ae == fake_ae
        assert model == fake_dit

        # Each loader is invoked exactly once.
        load_model_p.assert_called_once()
        load_gemma2_p.assert_called_once()
        load_ae_p.assert_called_once()
|
||||
|
||||
|
||||
def test_get_strategies(lumina_trainer, mock_args):
    """Each strategy factory returns the Lumina-specific implementation."""
    # (factory, expected class name) pairs, checked in the trainer's usual order.
    expectations = [
        (lumina_trainer.get_tokenize_strategy, "LuminaTokenizeStrategy"),
        (lumina_trainer.get_latents_caching_strategy, "LuminaLatentsCachingStrategy"),
        (lumina_trainer.get_text_encoding_strategy, "LuminaTextEncodingStrategy"),
    ]
    for factory, expected_class_name in expectations:
        assert factory(mock_args).__class__.__name__ == expected_class_name
|
||||
|
||||
|
||||
def test_text_encoder_output_caching_strategy(lumina_trainer, mock_args):
    """A caching strategy is created only when TE output caching is enabled."""
    # assert_extra_args must run first so train_gemma2 is initialized.
    train_group = MagicMock()
    train_group.verify_bucket_reso_steps = MagicMock()
    val_group = MagicMock()
    val_group.verify_bucket_reso_steps = MagicMock()
    lumina_trainer.assert_extra_args(mock_args, train_group, val_group)

    # Caching enabled: a Lumina-specific strategy is returned.
    mock_args.skip_cache_check = False
    mock_args.text_encoder_batch_size = 16
    enabled_strategy = lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args)
    assert enabled_strategy.__class__.__name__ == "LuminaTextEncoderOutputsCachingStrategy"
    assert enabled_strategy.cache_to_disk is False  # mirrors cache_text_encoder_outputs_to_disk

    # Caching disabled: no strategy at all.
    mock_args.cache_text_encoder_outputs = False
    assert lumina_trainer.get_text_encoder_outputs_caching_strategy(mock_args) is None
|
||||
|
||||
|
||||
def test_noise_scheduler(lumina_trainer, mock_args):
    """get_noise_scheduler builds a flow-match scheduler and keeps a copy."""
    scheduler = lumina_trainer.get_noise_scheduler(mock_args, torch.device("cpu"))

    assert scheduler.__class__.__name__ == "FlowMatchEulerDiscreteScheduler"
    assert scheduler.num_train_timesteps == 1000
    # The trainer retains its own copy for sigma lookups during training.
    assert hasattr(lumina_trainer, "noise_scheduler_copy")
|
||||
|
||||
|
||||
def test_sai_model_spec(lumina_trainer, mock_args):
    """get_sai_model_spec delegates to train_util with the lumina2 marker."""
    with patch("library.train_util.get_sai_model_spec") as spec_p:
        spec_p.return_value = "test_spec"
        assert lumina_trainer.get_sai_model_spec(mock_args) == "test_spec"
        # Forwarded with the expected positional flags and lumina="lumina2".
        spec_p.assert_called_once_with(None, mock_args, False, True, False, lumina="lumina2")
|
||||
|
||||
|
||||
def test_update_metadata(lumina_trainer, mock_args):
    """update_metadata records every Lumina training hyperparameter key."""
    metadata = {}
    lumina_trainer.update_metadata(metadata, mock_args)

    expected_keys = (
        "ss_weighting_scheme",
        "ss_logit_mean",
        "ss_logit_std",
        "ss_mode_scale",
        "ss_timestep_sampling",
        "ss_sigmoid_scale",
        "ss_model_prediction_type",
        "ss_discrete_flow_shift",
    )
    for key in expected_keys:
        assert key in metadata
|
||||
|
||||
|
||||
def test_is_text_encoder_not_needed_for_training(lumina_trainer, mock_args):
    """The TE can be dropped only when outputs are cached and it isn't trained."""
    # Cached outputs + frozen text encoder -> the encoder is dispensable.
    mock_args.cache_text_encoder_outputs = True
    with patch.object(lumina_trainer, "is_train_text_encoder", return_value=False):
        assert lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) is True

    # Cached outputs but the encoder is being trained -> still needed.
    with patch.object(lumina_trainer, "is_train_text_encoder", return_value=True):
        assert lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) is False

    # No caching at all -> always needed.
    mock_args.cache_text_encoder_outputs = False
    assert lumina_trainer.is_text_encoder_not_needed_for_training(mock_args) is False
|
||||
@@ -389,7 +389,18 @@ class NetworkTrainer:
|
||||
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
|
||||
else:
|
||||
# latentに変換
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
|
||||
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
|
||||
else:
|
||||
chunks = [
|
||||
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
|
||||
]
|
||||
list_latents = []
|
||||
for chunk in chunks:
|
||||
with torch.no_grad():
|
||||
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
|
||||
list_latents.append(chunk)
|
||||
latents = torch.cat(list_latents, dim=0)
|
||||
|
||||
# NaNが含まれていれば警告を表示し0に置き換える
|
||||
if torch.any(torch.isnan(latents)):
|
||||
@@ -1433,11 +1444,13 @@ class NetworkTrainer:
|
||||
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
|
||||
else:
|
||||
if hasattr(network, "weight_norms"):
|
||||
mean_norm = network.weight_norms().mean().item()
|
||||
mean_grad_norm = network.grad_norms().mean().item()
|
||||
mean_combined_norm = network.combined_weight_norms().mean().item()
|
||||
weight_norms = network.weight_norms()
|
||||
maximum_norm = weight_norms.max().item() if weight_norms.numel() > 0 else None
|
||||
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
|
||||
grad_norms = network.grad_norms()
|
||||
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
|
||||
combined_weight_norms = network.combined_weight_norms()
|
||||
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
|
||||
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
|
||||
keys_scaled = None
|
||||
max_mean_logs = {}
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user