Merge branch 'sd3' into no-grad-block-swap

This commit is contained in:
Kohya S
2025-05-24 18:50:25 +09:00
25 changed files with 1111 additions and 158 deletions

3
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,3 @@
# These are supported funding model platforms
github: kohya-ss

View File

@@ -12,6 +12,9 @@ on:
- dev
- sd3
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
permissions: read-all
jobs:
build:
runs-on: ${{ matrix.os }}
@@ -40,7 +43,7 @@ jobs:
- name: Install dependencies
run: |
# Pre-install torch to pin version (requirements.txt has dependencies like transformers which requires pytorch)
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision==0.19.0 pytest==8.3.4
pip install dadaptation==3.2 torch==${{ matrix.pytorch-version }} torchvision pytest==8.3.4
pip install -r requirements.txt
- name: Test with pytest

View File

@@ -12,6 +12,9 @@ on:
- synchronize
- reopened
# CKV2_GHA_1: "Ensure top-level permissions are not set to write-all"
permissions: read-all
jobs:
build:
runs-on: ubuntu-latest

View File

@@ -9,11 +9,35 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)
### Recent Updates
May 1, 2025:
- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
- If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.
Apr 27, 2025:
- FLUX.1 training now supports CFG scale in the sample generation during training. Please use `--g` option, to specify the CFG scale (note that `--l` is used as the embedded guidance scale.) PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
- See [here](#sample-image-generation-during-training) for details.
- If you have any issues with this, please let us know.
Apr 6, 2025:
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
Mar 30, 2025:
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).
Mar 20, 2025:
- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).
- For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`.
Mar 6, 2025:
- Added a utility script to merge the weights of SD3's DiT, VAE (optional), CLIP-L, CLIP-G, and T5XXL into a single .safetensors file. Run `tools/merge_sd3_safetensors.py`. See `--help` for usage. PR [#1960](https://github.com/kohya-ss/sd-scripts/pull/1960)
@@ -748,6 +772,8 @@ Not available yet.
[__Change History__](#change-history) is moved to the bottom of the page.
更新履歴は[ページ末尾](#change-history)に移しました。
Latest update: 2025-03-21 (Version 0.9.1)
[日本語版READMEはこちら](./README-ja.md)
The development version is in the `dev` branch. Please check the dev branch for the latest changes.
@@ -855,6 +881,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o
(Single GPU with id `0` will be used.)
## DeepSpeed installation (experimental, Linux or WSL2 only)
To install DeepSpeed, run the following command in your activated virtual environment:
```bash
pip install deepspeed==0.16.7
```
## Upgrade
When a new release comes out you can upgrade your repo with the following command:
@@ -891,6 +925,11 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser
## Change History
### Mar 21, 2025 / 2025-03-21 Version 0.9.1
- Fixed a bug where some of LoRA modules for CLIP Text Encoder were not trained. Thank you Nekotekina for PR [#1964](https://github.com/kohya-ss/sd-scripts/pull/1964)
- The LoRA modules for CLIP Text Encoder are now 264 modules, which is the same as before. Only 88 modules were trained in the previous version.
### Jan 17, 2025 / 2025-01-17 Version 0.9.0
- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries.
@@ -1324,11 +1363,13 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b
Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.
* `--n` Negative prompt up to the next option.
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* In guidance distillation models like FLUX.1, this value is used as the embedded guidance scale for backward compatibility.
* `--g` Specifies the CFG scale for the models with embedded guidance scale. The default is `1.0`, `1.0` means no CFG. In general, should not be changed unless you train the un-distilled FLUX.1 models.
* `--s` Specifies the number of steps in the generation.
The prompt weighting such as `( )` and `[ ]` are working.

View File

@@ -152,6 +152,7 @@ These options are related to subset configuration.
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` | (not specified) | o | o | o |
* `num_repeats`
* Specifies the number of repeats for images in a subset. This is equivalent to `--dataset_repeats` in fine-tuning but can be specified for any training method.
@@ -165,6 +166,8 @@ These options are related to subset configuration.
* Specifies an additional separator. The part separated by this separator is treated as one tag and is shuffled and dropped. It is then replaced by `caption_separator`. For example, if you specify `aaa;;;bbb;;;ccc`, it will be replaced by `aaa,bbb,ccc` or dropped together.
* `enable_wildcard`
* Enables wildcard notation. This will be explained later.
* `resize_interpolation`
* Specifies the interpolation method used when resizing images. Normally, there is no need to specify this. The following options can be specified: `lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box`. By default (when not specified), `area` is used for downscaling, and `lanczos` is used for upscaling. If this option is specified, the same interpolation method will be used for both upscaling and downscaling. When `lanczos` or `box` is specified, PIL is used; for other options, OpenCV is used.
### DreamBooth-specific options

View File

@@ -144,6 +144,7 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
| `keep_tokens_separator` | `“|||”` | o | o | o |
| `secondary_separator` | `“;;;”` | o | o | o |
| `enable_wildcard` | `true` | o | o | o |
| `resize_interpolation` |(通常は設定しません) | o | o | o |
* `num_repeats`
* サブセットの画像の繰り返し回数を指定します。fine tuning における `--dataset_repeats` に相当しますが、`num_repeats` はどの学習方法でも指定可能です。
@@ -162,6 +163,9 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学
* `enable_wildcard`
* ワイルドカード記法および複数行キャプションを有効にします。ワイルドカード記法、複数行キャプションについては後述します。
* `resize_interpolation`
* 画像のリサイズ時に使用する補間方法を指定します。通常は指定しなくて構いません。`lanczos`, `nearest`, `bilinear`, `linear`, `bicubic`, `cubic`, `area`, `box` が指定可能です。デフォルト(未指定時)は、縮小時は `area`、拡大時は `lanczos` になります。このオプションを指定すると、拡大時・縮小時とも同じ補間方法が使用されます。`lanczos``box`を指定するとPILが、それ以外を指定するとOpenCVが使用されます。
### DreamBooth 方式専用のオプション
DreamBooth 方式のオプションは、サブセット向けオプションのみ存在します。

View File

@@ -11,7 +11,7 @@ from PIL import Image
from tqdm import tqdm
import library.train_util as train_util
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
@@ -42,10 +42,7 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)
image = image.astype(np.float32)
return image
@@ -100,15 +97,19 @@ def main(args):
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
repo_id=args.repo_id,
filename=file,
subfolder=SUB_DIR,
cache_dir=os.path.join(model_location, SUB_DIR),
local_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
hf_hub_download(
repo_id=args.repo_id,
filename=file,
local_dir=model_location,
force_download=True,
)
else:
logger.info("using existing wd14 tagger model")

View File

@@ -75,6 +75,7 @@ class BaseSubsetParams:
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
@dataclass
@@ -106,7 +107,7 @@ class BaseDatasetParams:
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None
@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -196,6 +197,7 @@ class ConfigSanitizer:
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
"resize_interpolation": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -241,6 +243,7 @@ class ConfigSanitizer:
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
}
# options handled by argparse but not handled by user config
@@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
""")
@@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
resize_interpolation: {subset.resize_interpolation}
custom_attributes: {subset.custom_attributes}
"""), " ")

View File

@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator
from .utils import setup_logging
from .device_utils import get_preferred_device
setup_logging()
import logging
@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"
for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
if wrap_model_forward_with_torch_autocast:
model = self.__wrap_model_with_torch_autocast(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"
self.models.update(torch.nn.ModuleDict({key: model}))
def __wrap_model_with_torch_autocast(self, model):
if isinstance(model, torch.nn.ModuleList):
model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
else:
model = self.__wrap_model_forward_with_torch_autocast(model)
return model
def __wrap_model_forward_with_torch_autocast(self, model):
assert hasattr(model, "forward"), f"model must have a forward method."
forward_fn = model.forward
def forward(*args, **kwargs):
try:
device_type = model.device.type
except AttributeError:
logger.warning(
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
"to determine the device_type for torch.autocast()."
)
device_type = get_preferred_device().type
with torch.autocast(device_type = device_type):
return forward_fn(*args, **kwargs)
model.forward = forward
return model
def get_models(self):
return self.models
ds_model = DeepSpeedWrapper(**models)
return ds_model

View File

@@ -40,7 +40,7 @@ def sample_images(
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
controlnet=None
controlnet=None,
):
if steps == 0:
if not args.sample_at_first:
@@ -101,7 +101,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
@@ -125,7 +125,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
)
torch.set_rng_state(rng_state)
@@ -147,14 +147,16 @@ def sample_image_inference(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
# TODO refactor variable names
cfg_scale = prompt_dict.get("guidance_scale", 1.0)
emb_guidance_scale = prompt_dict.get("scale", 3.5)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
@@ -162,8 +164,8 @@ def sample_image_inference(
if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if seed is not None:
torch.manual_seed(seed)
@@ -173,16 +175,21 @@ def sample_image_inference(
torch.seed()
torch.cuda.seed()
# if negative_prompt is None:
# negative_prompt = ""
if negative_prompt is None:
negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}")
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
logger.info(f"embedded guidance scale: {emb_guidance_scale}")
if cfg_scale != 1.0:
logger.info(f"CFG scale: {cfg_scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
@@ -191,26 +198,37 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()
text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
# encode negative prompts
if cfg_scale != 1.0:
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
neg_t5_attn_mask = (
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
)
neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
else:
neg_cond = None
# sample image
weight_dtype = ae.dtype # TOFO give dtype as argument
@@ -235,7 +253,20 @@ def sample_image_inference(
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = denoise(
flux,
noise,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps=timesteps,
guidance=emb_guidance_scale,
t5_attn_mask=t5_attn_mask,
controlnet=controlnet,
controlnet_img=controlnet_image,
neg_cond=neg_cond,
)
x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)
@@ -305,22 +336,24 @@ def denoise(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt: torch.Tensor, # t5_out
txt_ids: torch.Tensor,
vec: torch.Tensor,
vec: torch.Tensor, # l_pooled
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
do_cfg = neg_cond is not None
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
if controlnet is not None:
block_samples, block_single_samples = controlnet(
img=img,
@@ -336,20 +369,48 @@ def denoise(
else:
block_samples = None
block_single_samples = None
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
if not do_cfg:
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)
img = img + (t_prev - t_curr) * pred
else:
cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
# TODO is it ok to use the same block samples for both cond and uncond?
block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
block_single_samples = (
None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
)
nc_c_pred = model(
img=torch.cat([img, img], dim=0),
img_ids=torch.cat([img_ids, img_ids], dim=0),
txt=torch.cat([neg_t5_out, txt], dim=0),
txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
y=torch.cat([neg_l_pooled, vec], dim=0),
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=nc_c_t5_attn_mask,
)
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
pred = neg_pred + (pred - neg_pred) * cfg_scale
img = img + (t_prev - t_curr) * pred
model.prepare_block_swap_before_forward()
return img
@@ -366,8 +427,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
@@ -410,42 +469,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)
timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)
t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -456,12 +507,24 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)
# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
@@ -567,7 +630,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors"
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス*.sftまたは*.safetensors",
)
parser.add_argument(
"--t5xxl_max_token_length",

186
library/jpeg_xl_util.py Normal file
View File

@@ -0,0 +1,186 @@
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
# Added partial read support for up to 200x speedup
import os
from typing import List, Tuple
class JXLBitstream:
"""
A stream of bits with methods for easy handling.
"""
def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
self.shift = 0
self.bitstream = bytearray()
self.file = file
self.offset = offset
self.offsets = offsets
if self.offsets:
self.offset = self.offsets[0][1]
self.previous_data_len = 0
self.index = 0
self.file.seek(self.offset)
def get_bits(self, length: int = 1) -> int:
if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
self.partial_to_read_length = length
if self.shift < self.previous_data_len + self.offsets[self.index][2]:
self.partial_read(0, length)
self.bitstream.extend(self.file.read(self.partial_to_read_length))
else:
self.bitstream.extend(self.file.read(length))
bitmask = 2**length - 1
bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
self.shift += length
return bits
def partial_read(self, current_length: int, length: int) -> None:
self.previous_data_len += self.offsets[self.index][2]
to_read_length = self.previous_data_len - (self.shift + current_length)
self.bitstream.extend(self.file.read(to_read_length))
current_length += to_read_length
self.partial_to_read_length -= to_read_length
self.index += 1
self.file.seek(self.offsets[self.index][1])
if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
self.partial_read(current_length, length)
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
"""
Decodes the actual codestream.
JXL codestream specification: http://www-internal/2022/18181-1
"""
# Convert codestream to int within an object to get some handy methods.
codestream = JXLBitstream(file, offset=offset, offsets=offsets)
# Skip signature
codestream.get_bits(16)
# SizeHeader
div8 = codestream.get_bits(1)
if div8:
height = 8 * (1 + codestream.get_bits(5))
else:
distribution = codestream.get_bits(2)
match distribution:
case 0:
height = 1 + codestream.get_bits(9)
case 1:
height = 1 + codestream.get_bits(13)
case 2:
height = 1 + codestream.get_bits(18)
case 3:
height = 1 + codestream.get_bits(30)
ratio = codestream.get_bits(3)
if div8 and not ratio:
width = 8 * (1 + codestream.get_bits(5))
elif not ratio:
distribution = codestream.get_bits(2)
match distribution:
case 0:
width = 1 + codestream.get_bits(9)
case 1:
width = 1 + codestream.get_bits(13)
case 2:
width = 1 + codestream.get_bits(18)
case 3:
width = 1 + codestream.get_bits(30)
else:
match ratio:
case 1:
width = height
case 2:
width = (height * 12) // 10
case 3:
width = (height * 4) // 3
case 4:
width = (height * 3) // 2
case 5:
width = (height * 16) // 9
case 6:
width = (height * 5) // 4
case 7:
width = (height * 2) // 1
return width, height
def decode_container(file) -> Tuple[int,int]:
"""
Parses the ISOBMFF container, extracts the codestream, and decodes it.
JXL container specification: http://www-internal/2022/18181-2
"""
def parse_box(file, file_start: int) -> dict:
file.seek(file_start)
LBox = int.from_bytes(file.read(4), "big")
XLBox = None
if 1 < LBox <= 8:
raise ValueError(f"Invalid LBox at byte {file_start}.")
if LBox == 1:
file.seek(file_start + 8)
XLBox = int.from_bytes(file.read(8), "big")
if XLBox <= 16:
raise ValueError(f"Invalid XLBox at byte {file_start}.")
if XLBox:
header_length = 16
box_length = XLBox
else:
header_length = 8
if LBox == 0:
box_length = os.fstat(file.fileno()).st_size - file_start
else:
box_length = LBox
file.seek(file_start + 4)
box_type = file.read(4)
file.seek(file_start)
return {
"length": box_length,
"type": box_type,
"offset": header_length,
}
file.seek(0)
# Reject files missing required boxes. These two boxes are required to be at
# the start and contain no values, so we can manually check there presence.
# Signature box. (Redundant as has already been checked.)
if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
raise ValueError("Invalid signature box.")
# File Type box.
if file.read(20) != bytes.fromhex(
"00000014 66747970 6A786C20 00000000 6A786C20"
):
raise ValueError("Invalid file type box.")
offset = 0
offsets = []
data_offset_not_found = True
container_pointer = 32
file_size = os.fstat(file.fileno()).st_size
while data_offset_not_found:
box = parse_box(file, container_pointer)
match box["type"]:
case b"jxlc":
offset = container_pointer + box["offset"]
data_offset_not_found = False
case b"jxlp":
file.seek(container_pointer + box["offset"])
index = int.from_bytes(file.read(4), "big")
offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
container_pointer += box["length"]
if container_pointer >= file_size:
data_offset_not_found = False
if offsets:
offsets.sort(key=lambda i: i[0])
file.seek(0)
return decode_codestream(file, offset=offset, offsets=offsets)
def get_jxl_size(path: str) -> Tuple[int,int]:
with open(path, "rb") as file:
if file.read(2) == bytes.fromhex("FF0A"):
return decode_codestream(file)
return decode_container(file)

View File

@@ -74,7 +74,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image, validate_interpolation_fn
setup_logging()
import logging
@@ -113,14 +113,16 @@ except:
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
from library.jpeg_xl_util import get_jxl_size
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
pass
# JPEG-XL on Windows
# JPEG-XL on Linux and Windows
try:
import pillow_jxl
from library.jpeg_xl_util import get_jxl_size
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
@@ -205,6 +207,7 @@ class ImageInfo:
self.text_encoder_pool2: Optional[torch.Tensor] = None
self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.resize_interpolation: Optional[str] = None
class BucketManager:
@@ -429,6 +432,7 @@ class BaseSubset:
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -459,6 +463,8 @@ class BaseSubset:
self.validation_seed = validation_seed
self.validation_split = validation_split
self.resize_interpolation = resize_interpolation
class DreamBoothSubset(BaseSubset):
def __init__(
@@ -490,6 +496,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -517,6 +524,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)
self.is_reg = is_reg
@@ -559,6 +567,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
@@ -586,6 +595,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)
self.metadata_file = metadata_file
@@ -624,6 +634,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
@@ -651,6 +662,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)
self.conditioning_data_dir = conditioning_data_dir
@@ -671,6 +683,7 @@ class BaseDataset(torch.utils.data.Dataset):
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None
) -> None:
super().__init__()
@@ -705,6 +718,10 @@ class BaseDataset(torch.utils.data.Dataset):
self.image_transforms = IMAGE_TRANSFORMS
if resize_interpolation is not None:
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
self.resize_interpolation = resize_interpolation
self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}
@@ -1043,8 +1060,11 @@ class BaseDataset(torch.utils.data.Dataset):
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
if len(img_ar_errors) == 0:
mean_img_ar_error = 0 # avoid NaN
else:
img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")
@@ -1448,6 +1468,8 @@ class BaseDataset(torch.utils.data.Dataset):
)
def get_image_size(self, image_path):
if image_path.endswith(".jxl") or image_path.endswith(".JXL"):
return get_jxl_size(image_path)
# return imagesize.get(image_path)
image_size = imagesize.get(image_path)
if image_size[0] <= 0:
@@ -1494,7 +1516,7 @@ class BaseDataset(torch.utils.data.Dataset):
nh = int(height * scale + 0.5)
nw = int(width * scale + 0.5)
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
face_cx = int(face_cx * scale + 0.5)
face_cy = int(face_cy * scale + 0.5)
height, width = nh, nw
@@ -1591,7 +1613,7 @@ class BaseDataset(torch.utils.data.Dataset):
if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
)
else:
if face_cx > 0: # 顔位置情報あり
@@ -1852,8 +1874,9 @@ class DreamBoothDataset(BaseDataset):
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
assert resolution is not None, f"resolution is required / resolution解像度指定は必須です"
@@ -2078,6 +2101,7 @@ class DreamBoothDataset(BaseDataset):
for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2133,8 +2157,9 @@ class FineTuningDataset(BaseDataset):
debug_dataset: bool,
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
self.batch_size = batch_size
@@ -2360,9 +2385,10 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)
db_subsets = []
for subset in subsets:
@@ -2394,6 +2420,7 @@ class ControlNetDataset(BaseDataset):
subset.caption_suffix,
subset.token_warmup_min,
subset.token_warmup_step,
resize_interpolation=subset.resize_interpolation,
)
db_subsets.append(db_subset)
@@ -2412,6 +2439,7 @@ class ControlNetDataset(BaseDataset):
debug_dataset,
validation_split,
validation_seed,
resize_interpolation,
)
# config_util等から参照される値をいれておく若干微妙なのでなんとかしたい
@@ -2420,7 +2448,8 @@ class ControlNetDataset(BaseDataset):
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
self.validation_seed = validation_seed
self.resize_interpolation = resize_interpolation
# assert all conditioning data exists
missing_imgs = []
@@ -2508,9 +2537,8 @@ class ControlNetDataset(BaseDataset):
assert (
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
cond_img = cv2.resize(
cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA
) # INTER_AREAでやりたいのでcv2でリサイズ
cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2524,7 +2552,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)
if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2921,17 +2949,13 @@ def load_image(image_path, alpha=False):
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height) # size before resize
if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
if image_width > resized_size[0] and image_height > resized_size[1]:
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)
image_height, image_width = image.shape[0:2]
@@ -2976,7 +3000,7 @@ def load_images_and_masks_for_caching(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
original_sizes.append(original_size)
crop_ltrbs.append(crop_ltrb)
@@ -3017,7 +3041,7 @@ def cache_batch_latents(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)
info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -4495,7 +4519,13 @@ def add_dataset_arguments(
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)
parser.add_argument(
"--resize_interpolation",
type=str,
default=None,
choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"],
help="Resize interpolation when required. Default: area Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてサイズ補間を変更します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area",
)
parser.add_argument(
"--token_warmup_min",
type=int,
@@ -5468,6 +5498,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
def patch_accelerator_for_fp16_training(accelerator):
from accelerate import DistributedType
if accelerator.distributed_type == DistributedType.DEEPSPEED:
return
org_unscale_grads = accelerator.scaler._unscale_grads_
def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
@@ -6151,6 +6186,11 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["scale"] = float(m.group(1))
continue
m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
if m: # guidance scale
prompt_dict["guidance_scale"] = float(m.group(1))
continue
m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
prompt_dict["negative_prompt"] = m.group(1)
@@ -6533,3 +6573,4 @@ class LossRecorder:
if losses == 0:
return 0
return self.loss_total / losses

View File

@@ -16,7 +16,6 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file
def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)
setup_logging()
logger = logging.getLogger(__name__)
# endregion
@@ -378,7 +379,7 @@ def load_safetensors(
# region Image utils
def pil_resize(image, size, interpolation=Image.LANCZOS):
def pil_resize(image, size, interpolation):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False
if has_alpha:
@@ -386,7 +387,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
resized_pil = pil_image.resize(size, interpolation)
resized_pil = pil_image.resize(size, resample=interpolation)
# Convert back to cv2 format
if has_alpha:
@@ -397,6 +398,117 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2
def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
"""
Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.
Args:
image: numpy.ndarray
width: int Original image width
height: int Original image height
resized_width: int Resized image width
resized_height: int Resized image height
resize_interpolation: Optional[str] Resize interpolation method "lanczos", "area", "bilinear", "bicubic", "nearest", "box"
Returns:
image
"""
# Ensure all size parameters are actual integers
width = int(width)
height = int(height)
resized_width = int(resized_width)
resized_height = int(resized_height)
if resize_interpolation is None:
if width >= resized_width and height >= resized_height:
resize_interpolation = "area"
else:
resize_interpolation = "lanczos"
# we use PIL for lanczos (for backward compatibility) and box, cv2 for others
use_pil = resize_interpolation in ["lanczos", "lanczos4", "box"]
resized_size = (resized_width, resized_height)
if use_pil:
interpolation = get_pil_interpolation(resize_interpolation)
image = pil_resize(image, resized_size, interpolation=interpolation)
logger.debug(f"resize image using {resize_interpolation} (PIL)")
else:
interpolation = get_cv2_interpolation(resize_interpolation)
image = cv2.resize(image, resized_size, interpolation=interpolation)
logger.debug(f"resize image using {resize_interpolation} (cv2)")
return image
def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
"""
Convert interpolation value to cv2 interpolation integer
https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
"""
if interpolation is None:
return None
if interpolation == "lanczos" or interpolation == "lanczos4":
# Lanczos interpolation over 8x8 neighborhood
return cv2.INTER_LANCZOS4
elif interpolation == "nearest":
# Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
return cv2.INTER_NEAREST_EXACT
elif interpolation == "bilinear" or interpolation == "linear":
# bilinear interpolation
return cv2.INTER_LINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# bicubic interpolation
return cv2.INTER_CUBIC
elif interpolation == "area":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
elif interpolation == "box":
# resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
return cv2.INTER_AREA
else:
return None
def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
"""
Convert interpolation value to PIL interpolation
https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
"""
if interpolation is None:
return None
if interpolation == "lanczos":
return Image.Resampling.LANCZOS
elif interpolation == "nearest":
# Pick one nearest pixel from the input image. Ignore all other input pixels.
return Image.Resampling.NEAREST
elif interpolation == "bilinear" or interpolation == "linear":
# For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
return Image.Resampling.BILINEAR
elif interpolation == "bicubic" or interpolation == "cubic":
# For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
return Image.Resampling.BICUBIC
elif interpolation == "area":
# Image.Resampling.BOX may be more appropriate if upscaling
# Area interpolation is related to cv2.INTER_AREA
# Produces a sharper image than Resampling.BILINEAR, doesnt have dislocations on local level like with Resampling.BOX.
return Image.Resampling.HAMMING
elif interpolation == "box":
# Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
return Image.Resampling.BOX
else:
return None
def validate_interpolation_fn(interpolation_str: str) -> bool:
"""
Check if a interpolation function is supported
"""
return interpolation_str in ["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box"]
# endregion
# TODO make inf_utils.py

View File

@@ -268,7 +268,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh
class DyLoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -866,7 +866,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -278,7 +278,7 @@ def merge_lora_weights(pipe, weights_sd: Dict, multiplier: float = 1.0):
class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -755,7 +755,7 @@ class LoRANetwork(torch.nn.Module):
UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPSdpaAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

View File

@@ -9,11 +9,13 @@
import math
import os
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
from torch import Tensor
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
rank_dropout=None,
module_dropout=None,
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout
self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta
if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape
def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward
del self.org_module
def forward(self, x):
@@ -140,7 +155,17 @@ class LoRAModule(torch.nn.Module):
lx = self.lora_up(lx)
return org_forwarded + lx * self.multiplier * scale
# LoRA Gradient-Guided Perturbation Optimization
if self.training and self.ggpo_sigma is not None and self.ggpo_beta is not None and self.combined_weight_norms is not None and self.grad_norms is not None:
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms ** 2)) + (self.ggpo_beta * (self.grad_norms ** 2))
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
else:
return org_forwarded + lx * self.multiplier * scale
else:
lxs = [lora_down(x) for lora_down in self.lora_down]
@@ -167,6 +192,116 @@ class LoRAModule(torch.nn.Module):
return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale
@torch.no_grad()
def initialize_norm_cache(self, org_module_weight: Tensor):
# Choose a reasonable sample size
n_rows = org_module_weight.shape[0]
sample_size = min(1000, n_rows) # Cap at 1000 samples or use all if smaller
# Sample random indices across all rows
indices = torch.randperm(n_rows)[:sample_size]
# Convert to a supported data type first, then index
# Use float32 for indexing operations
weights_float32 = org_module_weight.to(dtype=torch.float32)
sampled_weights = weights_float32[indices].to(device=self.device)
# Calculate sampled norms
sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)
# Store the mean norm as our estimate
self.org_weight_norm_estimate = sampled_norms.mean()
# Optional: store standard deviation for confidence intervals
self.org_weight_norm_std = sampled_norms.std()
# Free memory
del sampled_weights, weights_float32
@torch.no_grad()
def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
# Calculate the true norm (this will be slow but it's just for validation)
true_norms = []
chunk_size = 1024 # Process in chunks to avoid OOM
for i in range(0, org_module_weight.shape[0], chunk_size):
end_idx = min(i + chunk_size, org_module_weight.shape[0])
chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
true_norms.append(chunk_norms.cpu())
del chunk
true_norms = torch.cat(true_norms, dim=0)
true_mean_norm = true_norms.mean().item()
# Compare with our estimate
estimated_norm = self.org_weight_norm_estimate.item()
# Calculate error metrics
absolute_error = abs(true_mean_norm - estimated_norm)
relative_error = absolute_error / true_mean_norm * 100 # as percentage
if verbose:
logger.info(f"True mean norm: {true_mean_norm:.6f}")
logger.info(f"Estimated norm: {estimated_norm:.6f}")
logger.info(f"Absolute error: {absolute_error:.6f}")
logger.info(f"Relative error: {relative_error:.2f}%")
return {
'true_mean_norm': true_mean_norm,
'estimated_norm': estimated_norm,
'absolute_error': absolute_error,
'relative_error': relative_error
}
@torch.no_grad()
def update_norms(self):
# Not running GGPO so not currently running update norms
if self.ggpo_beta is None or self.ggpo_sigma is None:
return
# only update norms when we are training
if self.training is False:
return
module_weights = self.lora_up.weight @ self.lora_down.weight
module_weights.mul(self.scale)
self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt((self.org_weight_norm_estimate**2) +
torch.sum(module_weights**2, dim=1, keepdim=True))
@torch.no_grad()
def update_grad_norms(self):
if self.training is False:
print(f"skipping update_grad_norms for {self.lora_name}")
return
lora_down_grad = None
lora_up_grad = None
for name, param in self.named_parameters():
if name == "lora_down.weight":
lora_down_grad = param.grad
elif name == "lora_up.weight":
lora_up_grad = param.grad
# Calculate gradient norms if we have both gradients
if lora_down_grad is not None and lora_up_grad is not None:
with torch.autocast(self.device.type):
approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)
@property
def device(self):
return next(self.parameters()).device
@property
def dtype(self):
return next(self.parameters()).dtype
class LoRAInfModule(LoRAModule):
def __init__(
@@ -420,6 +555,16 @@ def create_network(
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False
ggpo_beta = kwargs.get("ggpo_beta", None)
ggpo_sigma = kwargs.get("ggpo_sigma", None)
if ggpo_beta is not None:
ggpo_beta = float(ggpo_beta)
if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)
# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -449,6 +594,8 @@ def create_network(
in_dims=in_dims,
train_double_block_indices=train_double_block_indices,
train_single_block_indices=train_single_block_indices,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
verbose=verbose,
)
@@ -561,6 +708,8 @@ class LoRANetwork(torch.nn.Module):
in_dims: Optional[List[int]] = None,
train_double_block_indices: Optional[List[bool]] = None,
train_single_block_indices: Optional[List[bool]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -599,10 +748,16 @@ class LoRANetwork(torch.nn.Module):
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )
if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")
if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")
if train_t5xxl:
logger.info(f"train T5XXL as well")
@@ -722,6 +877,8 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
)
loras.append(lora)
@@ -790,6 +947,36 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled
def update_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_norms()
def update_grad_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_grad_norms()
def grad_norms(self) -> Tensor | None:
grad_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
grad_norms.append(lora.grad_norms.mean(dim=0))
return torch.stack(grad_norms) if len(grad_norms) > 0 else None
def weight_norms(self) -> Tensor | None:
weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
weight_norms.append(lora.weight_norms.mean(dim=0))
return torch.stack(weight_norms) if len(weight_norms) > 0 else None
def combined_weight_norms(self) -> Tensor | None:
combined_weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None
def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file

View File

@@ -6,3 +6,4 @@ filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning
pythonpath = .

View File

@@ -7,9 +7,11 @@ opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.4
pytorch-optimizer==3.5.0
prodigy-plus-schedule-free==1.9.0
prodigyopt==1.1.2
tensorboard
safetensors==0.4.4
# gradio==3.16.2

View File

@@ -24,7 +24,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.is_sdxl = True
def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)
if args.cache_text_encoder_outputs:

View File

@@ -0,0 +1,220 @@
import pytest
import torch
from unittest.mock import MagicMock, patch
from library.flux_train_utils import (
get_noisy_model_input_and_timesteps,
)
# Mock classes and functions
class MockNoiseScheduler:
def __init__(self, num_train_timesteps=1000):
self.config = MagicMock()
self.config.num_train_timesteps = num_train_timesteps
self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
# Create fixtures for commonly used objects
@pytest.fixture
def args():
args = MagicMock()
args.timestep_sampling = "uniform"
args.weighting_scheme = "uniform"
args.logit_mean = 0.0
args.logit_std = 1.0
args.mode_scale = 1.0
args.sigmoid_scale = 1.0
args.discrete_flow_shift = 3.1582
args.ip_noise_gamma = None
args.ip_noise_gamma_random_strength = False
return args
@pytest.fixture
def noise_scheduler():
return MockNoiseScheduler(num_train_timesteps=1000)
@pytest.fixture
def latents():
return torch.randn(2, 4, 8, 8)
@pytest.fixture
def noise():
return torch.randn(2, 4, 8, 8)
@pytest.fixture
def device():
# return "cuda" if torch.cuda.is_available() else "cpu"
return "cpu"
# Mock the required functions
@pytest.fixture(autouse=True)
def mock_functions():
with (
patch("torch.sigmoid", side_effect=torch.sigmoid),
patch("torch.rand", side_effect=torch.rand),
patch("torch.randn", side_effect=torch.randn),
):
yield
# Test different timestep sampling methods
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "uniform"
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "sigmoid"
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "shift"
args.sigmoid_scale = 1.0
args.discrete_flow_shift = 3.1582
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
args.timestep_sampling = "flux_shift"
args.sigmoid_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
# Mock the necessary functions for this specific test
with patch("library.flux_train_utils.compute_density_for_timestep_sampling",
return_value=torch.tensor([0.3, 0.7], device=device)), \
patch("library.flux_train_utils.get_sigmas",
return_value=torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)):
args.timestep_sampling = "other" # Will trigger the weighting scheme path
args.weighting_scheme = "uniform"
args.logit_mean = 0.0
args.logit_std = 1.0
args.mode_scale = 1.0
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
# Test IP noise options
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma = 0.5
args.ip_noise_gamma_random_strength = False
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
args.ip_noise_gamma = 0.1
args.ip_noise_gamma_random_strength = True
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (latents.shape[0],)
assert sigmas.shape == (latents.shape[0], 1, 1, 1)
# Test different data types
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
dtype = torch.float16
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.dtype == dtype
assert timesteps.dtype == dtype
# Test different batch sizes
def test_different_batch_size(args, noise_scheduler, device):
latents = torch.randn(5, 4, 8, 8) # batch size of 5
noise = torch.randn(5, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (5,)
assert sigmas.shape == (5, 1, 1, 1)
# Test different image sizes
def test_different_image_size(args, noise_scheduler, device):
latents = torch.randn(2, 4, 16, 16) # larger image size
noise = torch.randn(2, 4, 16, 16)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
assert sigmas.shape == (2, 1, 1, 1)
# Test edge cases
def test_zero_batch_size(args, noise_scheduler, device):
with pytest.raises(AssertionError): # expecting an error with zero batch size
latents = torch.randn(0, 4, 8, 8)
noise = torch.randn(0, 4, 8, 8)
dtype = torch.float32
get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
def test_different_timestep_count(args, device):
noise_scheduler = MockNoiseScheduler(num_train_timesteps=500) # different timestep count
latents = torch.randn(2, 4, 8, 8)
noise = torch.randn(2, 4, 8, 8)
dtype = torch.float32
noisy_input, timesteps, sigmas = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
assert noisy_input.shape == latents.shape
assert timesteps.shape == (2,)
# Check that timesteps are within the proper range
assert torch.all(timesteps < 500)

View File

@@ -15,7 +15,7 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -170,12 +170,9 @@ def process(args):
scale = max(cur_crop_width / w, cur_crop_height / h)
if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
rw = int(w * scale + .5)
rh = int(h * scale + .5)
face_img = resize_image(face_img, w, h, rw, rh)
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)

View File

@@ -6,7 +6,7 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)
# Select interpolation method
if interpolation == 'lanczos4':
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA
# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))
# Resize image
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
else:
new_height, new_width = img.shape[0:2]
@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / サイズの補方法')
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')

View File

@@ -69,13 +69,20 @@ class NetworkTrainer:
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}
if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/average_key_norm"] = mean_norm
logs["max_norm/max_key_norm"] = maximum_norm
if mean_norm is not None:
logs["norm/avg_key_norm"] = mean_norm
if mean_grad_norm is not None:
logs["norm/avg_grad_norm"] = mean_grad_norm
if mean_combined_norm is not None:
logs["norm/avg_combined_norm"] = mean_combined_norm
lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs):
@@ -382,7 +389,18 @@ class NetworkTrainer:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
# latentに変換
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -651,6 +669,10 @@ class NetworkTrainer:
return
network_has_multiplier = hasattr(network, "set_multiplier")
# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
# if not hasattr(network, "prepare_network"):
# network.prepare_network = lambda args: None
if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
@@ -1017,6 +1039,7 @@ class NetworkTrainer:
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
}
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1042,6 +1065,7 @@ class NetworkTrainer:
"max_bucket_reso": dataset.max_bucket_reso,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
}
subsets_metadata = []
@@ -1059,6 +1083,7 @@ class NetworkTrainer:
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
"resize_interpolation": subset.resize_interpolation,
}
image_dir_or_metadata_file = None
@@ -1400,6 +1425,11 @@ class NetworkTrainer:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
if hasattr(network, "update_grad_norms"):
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
@@ -1408,9 +1438,25 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
if hasattr(network, "weight_norms"):
weight_norms = network.weight_norms()
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
grad_norms = network.grad_norms()
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
combined_weight_norms = network.combined_weight_norms()
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
keys_scaled = None
max_mean_logs = {}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -1442,14 +1488,21 @@ class NetworkTrainer:
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
progress_bar.set_postfix(**{**max_mean_logs, **logs})
if is_tracking:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
self.step_logging(accelerator, logs, global_step, epoch + 1)