mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
333 lines
14 KiB
Python
333 lines
14 KiB
Python
import argparse
import copy
import math
import random
from typing import Any

import torch
from accelerate import Accelerator

from library.device_utils import init_ipex, clean_memory_on_device

# init_ipex() is deliberately called BEFORE the remaining `library` imports below; keep this
# ordering (NOTE(review): presumably those modules rely on what it sets up — confirm before moving).
init_ipex()

from library import flux_models, flux_utils, sd3_train_utils, sd3_utils, sdxl_model_util, sdxl_train_util, strategy_flux, train_util
import train_network
from library.utils import setup_logging

# Configure logging before the module-level logger is created.
setup_logging()
import logging

logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FluxNetworkTrainer(train_network.NetworkTrainer):
    """Network trainer specialised for FLUX.1 models.

    Overrides the ``train_network.NetworkTrainer`` hooks that differ for Flux: model loading,
    tokenize / text-encoding / caching strategies, text-encoder output caching, the flow-matching
    noise scheduler and the noise-prediction / target computation.
    """

    def __init__(self):
        super().__init__()

    def assert_extra_args(self, args, train_dataset_group):
        """Validate Flux-specific argument combinations on top of the base checks.

        Raises:
            AssertionError: if text encoder outputs are cached while the dataset uses caption
                augmentation that prevents caching, or while the text encoder network is trained.
        """
        super().assert_extra_args(args, train_dataset_group)

        if args.cache_text_encoder_outputs:
            assert (
                train_dataset_group.is_text_encoder_output_cacheable()
            ), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

        assert (
            args.network_train_unet_only or not args.cache_text_encoder_outputs
        ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"

        train_dataset_group.verify_bucket_reso_steps(32)

    def load_target_model(self, args, weight_dtype, accelerator):
        """Load CLIP-L, T5-XXL, the Flux transformer and the autoencoder, all onto the CPU.

        Returns:
            ``(flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model)``.
        """
        # currently offload to cpu for some models
        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
        clip_l.eval()

        # TODO: loading t5xxl to cpu takes a long time, so we should load to gpu in future
        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
        t5xxl.eval()

        # TODO: detect the variant in a more robust way than a substring of the path
        name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
        # if we load to cpu, flux.to(fp8) takes a long time
        model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")

        return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model

    def get_tokenize_strategy(self, args):
        return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)

    def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
        """Return ``[clip_l tokenizer, t5xxl tokenizer]`` from the tokenize strategy."""
        return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]

    def get_latents_caching_strategy(self, args):
        latents_caching_strategy = strategy_flux.FluxLatentsCachingStrategy(args.cache_latents_to_disk, args.vae_batch_size, False)
        return latents_caching_strategy

    def get_text_encoding_strategy(self, args):
        return strategy_flux.FluxTextEncodingStrategy()

    def get_models_for_text_encoding(self, args, accelerator, text_encoders):
        """Return the text encoders unchanged (both are used for encoding)."""
        return text_encoders

    def get_text_encoder_outputs_caching_strategy(self, args):
        """Return a caching strategy when ``--cache_text_encoder_outputs`` is set, else ``None``."""
        if args.cache_text_encoder_outputs:
            return strategy_flux.FluxTextEncoderOutputsCachingStrategy(args.cache_text_encoder_outputs_to_disk, None, False)
        else:
            return None

    def cache_text_encoder_outputs_if_needed(
        self, args, accelerator: Accelerator, unet, vae, text_encoders, dataset: train_util.DatasetGroup, weight_dtype
    ):
        """Cache text encoder outputs for ``dataset``, or keep the text encoders on the device.

        With ``--cache_text_encoder_outputs``, the VAE and transformer are temporarily moved to the
        CPU (unless ``--lowram``), the text encoders are moved to ``accelerator.device`` in
        ``weight_dtype`` to build the cache, then moved back to the CPU and the VAE/transformer are
        restored to their original devices. Otherwise the text encoders are moved to
        ``accelerator.device`` because their outputs are computed at every step.
        """
        if args.cache_text_encoder_outputs:
            if not args.lowram:
                # reduce memory consumption
                logger.info("move vae and unet to cpu to save memory")
                org_vae_device = vae.device
                org_unet_device = unet.device
                vae.to("cpu")
                unet.to("cpu")
                clean_memory_on_device(accelerator.device)

            # When the TE is not trained it is not prepared by accelerate, so use explicit autocast
            logger.info("move text encoders to gpu")
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device, dtype=weight_dtype)
            with accelerator.autocast():
                dataset.new_cache_text_encoder_outputs(text_encoders, accelerator.is_main_process)
            accelerator.wait_for_everyone()

            logger.info("move text encoders back to cpu")
            text_encoders[0].to("cpu")  # , dtype=torch.float32) # Text Encoder doesn't work with fp16 on CPU
            text_encoders[1].to("cpu")  # , dtype=torch.float32)
            clean_memory_on_device(accelerator.device)

            if not args.lowram:
                logger.info("move vae and unet back to original device")
                vae.to(org_vae_device)
                unet.to(org_unet_device)
        else:
            # Outputs are fetched from the Text Encoders every step, so keep them on the GPU
            text_encoders[0].to(accelerator.device, dtype=weight_dtype)
            text_encoders[1].to(accelerator.device, dtype=weight_dtype)

    def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
        """Image sampling is not implemented for Flux yet; this is a no-op."""
        pass

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        """Create a 1000-step flow-matching Euler scheduler with shift 3.0.

        A deep copy is kept on ``self.noise_scheduler_copy``; ``get_noise_pred_and_target`` reads
        its ``timesteps``, ``sigmas`` and ``config`` from that copy.
        """
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler

    def encode_images_to_latents(self, args, accelerator, vae, images):
        # NOTE(review): this expects a diffusers-style `.latent_dist` on the encode() result;
        # confirm that the autoencoder returned by flux_utils.load_ae provides it.
        return vae.encode(images).latent_dist.sample()

    def shift_scale_latents(self, args, latents):
        """Return the latents unchanged (no extra shift/scale is applied here)."""
        return latents

    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet: flux_models.Flux,
        network,
        weight_dtype,
        train_unet,
    ):
        """Sample timesteps, noise the latents, run the Flux model and build the loss target.

        Returns:
            ``(model_pred, target, timesteps, None, weighting)``: ``model_pred`` is the model output
            unpacked back to the latent layout, ``target`` is the flow-matching velocity
            ``noise - latents`` and ``weighting`` is the loss weight for ``args.weighting_scheme``.
            NOTE(review): the meaning of the fourth (always ``None``) element is defined by the
            base trainer — confirm there.
        """
        # copied from sd3_train.py and modified

        def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
            """Look up each timestep's sigma and unsqueeze it to ``n_dim`` dims for broadcasting."""
            sigmas = self.noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
            schedule_timesteps = self.noise_scheduler_copy.timesteps.to(accelerator.device)
            timesteps = timesteps.to(accelerator.device)
            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

            sigma = sigmas[step_indices].flatten()
            while len(sigma.shape) < n_dim:
                sigma = sigma.unsqueeze(-1)
            return sigma

        def compute_density_for_timestep_sampling(
            weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
        ):
            """Compute the density for sampling the timesteps when doing SD3 training.

            Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

            SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
            """
            if weighting_scheme == "logit_normal":
                # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
                u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
                u = torch.nn.functional.sigmoid(u)
            elif weighting_scheme == "mode":
                u = torch.rand(size=(batch_size,), device="cpu")
                u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
            else:
                u = torch.rand(size=(batch_size,), device="cpu")
            return u

        def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
            """Computes loss weighting scheme for SD3 training.

            Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

            SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
            """
            if weighting_scheme == "sigma_sqrt":
                weighting = (sigmas**-2.0).float()
            elif weighting_scheme == "cosmap":
                bot = 1 - 2 * sigmas + 2 * sigmas**2
                weighting = 2 / (math.pi * bot)
            else:
                weighting = torch.ones_like(sigmas)
            return weighting

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # Sample a random timestep for each image
        # for weighting schemes where we sample timesteps non-uniformly
        u = compute_density_for_timestep_sampling(
            weighting_scheme=args.weighting_scheme,
            batch_size=bsz,
            logit_mean=args.logit_mean,
            logit_std=args.logit_std,
            mode_scale=args.mode_scale,
        )
        num_train_timesteps = self.noise_scheduler_copy.config.num_train_timesteps
        # u is nominally in [0, 1), but u == 0 under "mode" maps to exactly 1.0, float32 sigmoid can
        # saturate to exactly 1.0 under "logit_normal", and a large mode_scale can push u outside
        # [0, 1]. Clamp so the index always addresses a schedule entry instead of raising
        # (index == num_train_timesteps) or silently wrapping around (negative index).
        indices = (u * num_train_timesteps).long().clamp(0, num_train_timesteps - 1)
        timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=accelerator.device)

        # Add noise according to flow matching.
        sigmas = get_sigmas(timesteps, n_dim=latents.ndim, dtype=weight_dtype)
        noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

        # pack latents and get img_ids
        packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input)  # b, c, h*2, w*2 -> b, h*w, c*4
        packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
        img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)

        # get guidance; float() keeps the tensor floating point even if the value reached args as an
        # int (argparse's type=float only applies to command-line values)
        guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

        # ensure the model inputs will require grad
        if args.gradient_checkpointing:
            # Flag packed_noisy_model_input, the tensor actually fed to the model: it was built before
            # this point, so flagging noisy_model_input would not reach it. requires_grad_ raises on
            # non-floating-point tensors, so those are skipped.
            for t in (packed_noisy_model_input, *text_encoder_conds, img_ids, guidance_vec):
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)

        # Predict the noise residual
        l_pooled, t5_out, txt_ids = text_encoder_conds

        with accelerator.autocast():
            # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
            model_pred = unet(
                img=packed_noisy_model_input,
                img_ids=img_ids,
                txt=t5_out,
                txt_ids=txt_ids,
                y=l_pooled,
                timesteps=timesteps / 1000,
                guidance=guidance_vec,
            )

        # unpack latents
        model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)

        # The raw model output is compared against the velocity target below. SD3-style output
        # preconditioning (model_pred * -sigmas + noisy_model_input) turns the prediction into an
        # x0 estimate, which only pairs with target == latents, so it is not applied here.

        # these weighting schemes use a uniform timestep sampling
        # and instead post-weight the loss
        weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)

        # flow matching loss: this is different from SD3
        target = noise - latents

        return model_pred, target, timesteps, None, weighting

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        """Return the loss unchanged."""
        return loss
|
|
|
|
|
|
def setup_parser() -> argparse.ArgumentParser:
    """Return the argument parser for Flux network training.

    Extends ``train_network.setup_parser()`` with Flux model paths, text-encoder output caching
    switches, timestep-sampling / loss-weighting options and the guidance scale.
    """
    parser = train_network.setup_parser()
    parser.add_argument("--clip_l", type=str, help="path to clip_l")
    parser.add_argument("--t5xxl", type=str, help="path to t5xxl")
    parser.add_argument("--ae", type=str, help="path to ae")
    # NOTE(review): nothing in this file reads args.apply_t5_attn_mask; confirm it is consumed elsewhere.
    parser.add_argument("--apply_t5_attn_mask", action="store_true")
    parser.add_argument(
        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
    )
    parser.add_argument(
        "--cache_text_encoder_outputs_to_disk",
        action="store_true",
        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
    )

    # copy from Diffusers
    parser.add_argument(
        "--weighting_scheme",
        type=str,
        default="none",
        choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
        help="timestep sampling scheme (logit_normal, mode) or loss weighting scheme (sigma_sqrt, cosmap);"
        " 'none' samples timesteps uniformly with unit loss weights.",
    )
    parser.add_argument(
        "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.5,
        help="value passed to the model's guidance input during training"
        " (the FLUX.1 dev variant is a guidance distilled model)",
    )
    return parser
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse and validate the command line, let train_util post-process the namespace
    # (presumably applying a config file), then run Flux network training.
    parser = setup_parser()
    args = parser.parse_args()
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    FluxNetworkTrainer().train(args)
|