mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
refactor SD3 CLIP to transformers etc.
This commit is contained in:
@@ -11,8 +11,8 @@ from safetensors.torch import save_file
|
||||
from accelerate import Accelerator, PartialState
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextModelWithProjection, T5EncoderModel
|
||||
|
||||
from library import sd3_models, sd3_utils, strategy_base, train_util
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
@@ -28,60 +28,16 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .sdxl_train_util import match_mixed_precision
|
||||
|
||||
|
||||
def load_target_model(
|
||||
model_type: str,
|
||||
args: argparse.Namespace,
|
||||
state_dict: dict,
|
||||
accelerator: Accelerator,
|
||||
attn_mode: str,
|
||||
model_dtype: Optional[torch.dtype],
|
||||
device: Optional[torch.device],
|
||||
) -> Union[
|
||||
sd3_models.MMDiT,
|
||||
Optional[sd3_models.SDClipModel],
|
||||
Optional[sd3_models.SDXLClipG],
|
||||
Optional[sd3_models.T5XXLModel],
|
||||
sd3_models.SDVAE,
|
||||
]:
|
||||
loading_device = device if device is not None else (accelerator.device if args.lowram else "cpu")
|
||||
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
if pi == accelerator.state.local_process_index:
|
||||
logger.info(f"loading model for process {accelerator.state.local_process_index}/{accelerator.state.num_processes}")
|
||||
|
||||
if model_type == "mmdit":
|
||||
model = sd3_utils.load_mmdit(state_dict, attn_mode, model_dtype, loading_device)
|
||||
elif model_type == "clip_l":
|
||||
model = sd3_utils.load_clip_l(state_dict, args.clip_l, attn_mode, model_dtype, loading_device)
|
||||
elif model_type == "clip_g":
|
||||
model = sd3_utils.load_clip_g(state_dict, args.clip_g, attn_mode, model_dtype, loading_device)
|
||||
elif model_type == "t5xxl":
|
||||
model = sd3_utils.load_t5xxl(state_dict, args.t5xxl, attn_mode, model_dtype, loading_device)
|
||||
elif model_type == "vae":
|
||||
model = sd3_utils.load_vae(state_dict, args.vae, model_dtype, loading_device)
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
|
||||
# work on low-ram device: models are already loaded on accelerator.device, but we ensure they are on device
|
||||
if args.lowram:
|
||||
model = model.to(accelerator.device)
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
return model
|
||||
from library import sd3_models, sd3_utils, strategy_base, train_util
|
||||
|
||||
|
||||
def save_models(
|
||||
ckpt_path: str,
|
||||
mmdit: sd3_models.MMDiT,
|
||||
vae: sd3_models.SDVAE,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel],
|
||||
mmdit: Optional[sd3_models.MMDiT],
|
||||
vae: Optional[sd3_models.SDVAE],
|
||||
clip_l: Optional[CLIPTextModelWithProjection],
|
||||
clip_g: Optional[CLIPTextModelWithProjection],
|
||||
t5xxl: Optional[T5EncoderModel],
|
||||
sai_metadata: Optional[dict],
|
||||
save_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
@@ -101,24 +57,35 @@ def save_models(
|
||||
update_sd("model.diffusion_model.", mmdit.state_dict())
|
||||
update_sd("first_stage_model.", vae.state_dict())
|
||||
|
||||
if clip_l is not None:
|
||||
update_sd("text_encoders.clip_l.", clip_l.state_dict())
|
||||
if clip_g is not None:
|
||||
update_sd("text_encoders.clip_g.", clip_g.state_dict())
|
||||
if t5xxl is not None:
|
||||
update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
|
||||
# do not support unified checkpoint format for now
|
||||
# if clip_l is not None:
|
||||
# update_sd("text_encoders.clip_l.", clip_l.state_dict())
|
||||
# if clip_g is not None:
|
||||
# update_sd("text_encoders.clip_g.", clip_g.state_dict())
|
||||
# if t5xxl is not None:
|
||||
# update_sd("text_encoders.t5xxl.", t5xxl.state_dict())
|
||||
|
||||
save_file(state_dict, ckpt_path, metadata=sai_metadata)
|
||||
|
||||
if clip_l is not None:
|
||||
clip_l_path = ckpt_path.replace(".safetensors", "_clip_l.safetensors")
|
||||
save_file(clip_l.state_dict(), clip_l_path)
|
||||
if clip_g is not None:
|
||||
clip_g_path = ckpt_path.replace(".safetensors", "_clip_g.safetensors")
|
||||
save_file(clip_g.state_dict(), clip_g_path)
|
||||
if t5xxl is not None:
|
||||
t5xxl_path = ckpt_path.replace(".safetensors", "_t5xxl.safetensors")
|
||||
save_file(t5xxl.state_dict(), t5xxl_path)
|
||||
|
||||
|
||||
def save_sd3_model_on_train_end(
|
||||
args: argparse.Namespace,
|
||||
save_dtype: torch.dtype,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel],
|
||||
clip_l: Optional[CLIPTextModelWithProjection],
|
||||
clip_g: Optional[CLIPTextModelWithProjection],
|
||||
t5xxl: Optional[T5EncoderModel],
|
||||
mmdit: sd3_models.MMDiT,
|
||||
vae: sd3_models.SDVAE,
|
||||
):
|
||||
@@ -141,9 +108,9 @@ def save_sd3_model_on_epoch_end_or_stepwise(
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
clip_l: sd3_models.SDClipModel,
|
||||
clip_g: sd3_models.SDXLClipG,
|
||||
t5xxl: Optional[sd3_models.T5XXLModel],
|
||||
clip_l: Optional[CLIPTextModelWithProjection],
|
||||
clip_g: Optional[CLIPTextModelWithProjection],
|
||||
t5xxl: Optional[T5EncoderModel],
|
||||
mmdit: sd3_models.MMDiT,
|
||||
vae: sd3_models.SDVAE,
|
||||
):
|
||||
@@ -208,23 +175,27 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
||||
help="T5-XXL model path. if not specified, use ckpt's state_dict / T5-XXLモデルのパス。指定しない場合はckptのstate_dictを使用",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_clip", action="store_true", help="save CLIP models to checkpoint / CLIPモデルをチェックポイントに保存する"
|
||||
"--save_clip",
|
||||
action="store_true",
|
||||
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_t5xxl", action="store_true", help="save T5-XXL model to checkpoint / T5-XXLモデルをチェックポイントに保存する"
|
||||
"--save_t5xxl",
|
||||
action="store_true",
|
||||
help="[DOES NOT WORK] unified checkpoint is not supported / 統合チェックポイントはまだサポートされていません",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--t5xxl_device",
|
||||
type=str,
|
||||
default=None,
|
||||
help="T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
|
||||
help="[DOES NOT WORK] not supported yet. T5-XXL device. if not specified, use accelerator's device / T5-XXLデバイス。指定しない場合はacceleratorのデバイスを使用",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_dtype",
|
||||
type=str,
|
||||
default=None,
|
||||
help="T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用",
|
||||
help="[DOES NOT WORK] not supported yet. T5-XXL dtype. if not specified, use default dtype (from mixed precision) / T5-XXL dtype。指定しない場合はデフォルトのdtype(mixed precisionから)を使用",
|
||||
)
|
||||
|
||||
# copy from Diffusers
|
||||
@@ -233,16 +204,25 @@ def add_sd3_training_arguments(parser: argparse.ArgumentParser):
|
||||
type=str,
|
||||
default="logit_normal",
|
||||
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"],
|
||||
help="weighting scheme for timestep distribution and loss / タイムステップ分布と損失のための重み付けスキーム",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
|
||||
"--logit_mean",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="mean to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合の平均",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logit_std",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="std to use when using the `'logit_normal'` weighting scheme for timestep distribution. / タイムステップ分布のために`'logit_normal'`重み付けスキームを使用する場合のstd",
|
||||
)
|
||||
parser.add_argument("--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme.")
|
||||
parser.add_argument(
|
||||
"--mode_scale",
|
||||
type=float,
|
||||
default=1.29,
|
||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`. / モード重み付けスキームのスケール。`'mode'`を`weighting_scheme`として使用する場合のみ有効",
|
||||
)
|
||||
|
||||
|
||||
@@ -283,7 +263,7 @@ def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCachin
|
||||
# temporary copied from sd3_minimal_inferece.py
|
||||
|
||||
|
||||
def get_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
|
||||
def get_all_sigmas(sampling: sd3_utils.ModelSamplingDiscreteFlow, steps):
|
||||
start = sampling.timestep(sampling.sigma_max)
|
||||
end = sampling.timestep(sampling.sigma_min)
|
||||
timesteps = torch.linspace(start, end, steps)
|
||||
@@ -327,7 +307,7 @@ def do_sample(
|
||||
|
||||
model_sampling = sd3_utils.ModelSamplingDiscreteFlow(shift=3.0) # 3.0 is for SD3
|
||||
|
||||
sigmas = get_sigmas(model_sampling, steps).to(device)
|
||||
sigmas = get_all_sigmas(model_sampling, steps).to(device)
|
||||
|
||||
noise_scaled = model_sampling.noise_scaling(sigmas[0], noise, latent, max_denoise(model_sampling, sigmas))
|
||||
|
||||
@@ -371,37 +351,6 @@ def do_sample(
|
||||
return x
|
||||
|
||||
|
||||
def load_prompts(prompt_file: str) -> List[Dict]:
|
||||
# read prompts
|
||||
if prompt_file.endswith(".txt"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
prompts = [line.strip() for line in lines if len(line.strip()) > 0 and line[0] != "#"]
|
||||
elif prompt_file.endswith(".toml"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
data = toml.load(f)
|
||||
prompts = [dict(**data["prompt"], **subset) for subset in data["prompt"]["subset"]]
|
||||
elif prompt_file.endswith(".json"):
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
prompts = json.load(f)
|
||||
|
||||
# preprocess prompts
|
||||
for i in range(len(prompts)):
|
||||
prompt_dict = prompts[i]
|
||||
if isinstance(prompt_dict, str):
|
||||
from library.train_util import line_to_prompt_dict
|
||||
|
||||
prompt_dict = line_to_prompt_dict(prompt_dict)
|
||||
prompts[i] = prompt_dict
|
||||
assert isinstance(prompt_dict, dict)
|
||||
|
||||
# Adds an enumerator to the dict based on prompt position. Used later to name image files. Also cleanup of extra data in original prompt dict.
|
||||
prompt_dict["enum"] = i
|
||||
prompt_dict.pop("subset", None)
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
def sample_images(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
@@ -440,7 +389,7 @@ def sample_images(
|
||||
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
|
||||
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
|
||||
|
||||
prompts = load_prompts(args.sample_prompts)
|
||||
prompts = train_util.load_prompts(args.sample_prompts)
|
||||
|
||||
save_dir = args.output_dir + "/sample"
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
@@ -510,7 +459,7 @@ def sample_image_inference(
|
||||
accelerator: Accelerator,
|
||||
args: argparse.Namespace,
|
||||
mmdit: sd3_models.MMDiT,
|
||||
text_encoders: List[Union[sd3_models.SDClipModel, sd3_models.SDXLClipG, sd3_models.T5XXLModel]],
|
||||
text_encoders: List[Union[CLIPTextModelWithProjection, T5EncoderModel]],
|
||||
vae: sd3_models.SDVAE,
|
||||
save_dir,
|
||||
prompt_dict,
|
||||
@@ -568,7 +517,7 @@ def sample_image_inference(
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(prompt)
|
||||
te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
|
||||
|
||||
lg_out, t5_out, pooled = te_outputs
|
||||
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = te_outputs
|
||||
cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
# encode negative prompts
|
||||
@@ -578,7 +527,7 @@ def sample_image_inference(
|
||||
l_tokens, g_tokens, t5_tokens = tokenize_strategy.tokenize(negative_prompt)
|
||||
neg_te_outputs = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, [l_tokens, g_tokens, t5_tokens])
|
||||
|
||||
lg_out, t5_out, pooled = neg_te_outputs
|
||||
lg_out, t5_out, pooled, l_attn_mask, g_attn_mask, t5_attn_mask = neg_te_outputs
|
||||
neg_cond = encoding_strategy.concat_encodings(lg_out, t5_out, pooled)
|
||||
|
||||
# sample image
|
||||
@@ -609,14 +558,9 @@ def sample_image_inference(
|
||||
wandb_tracker = accelerator.get_tracker("wandb")
|
||||
|
||||
import wandb
|
||||
|
||||
# not to commit images to avoid inconsistency between training and logging steps
|
||||
wandb_tracker.log(
|
||||
{f"sample_{i}": wandb.Image(
|
||||
image,
|
||||
caption=prompt # positive prompt as a caption
|
||||
)},
|
||||
commit=False
|
||||
)
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
# region Diffusers
|
||||
@@ -886,4 +830,78 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
|
||||
def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
|
||||
sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
|
||||
schedule_timesteps = noise_scheduler.timesteps.to(device)
|
||||
timesteps = timesteps.to(device)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < n_dim:
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
|
||||
def compute_density_for_timestep_sampling(
|
||||
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
||||
):
|
||||
"""Compute the density for sampling the timesteps when doing SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "logit_normal":
|
||||
# See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
|
||||
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
|
||||
u = torch.nn.functional.sigmoid(u)
|
||||
elif weighting_scheme == "mode":
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
|
||||
else:
|
||||
u = torch.rand(size=(batch_size,), device="cpu")
|
||||
return u
|
||||
|
||||
|
||||
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
||||
"""Computes loss weighting scheme for SD3 training.
|
||||
|
||||
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
||||
|
||||
SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
|
||||
"""
|
||||
if weighting_scheme == "sigma_sqrt":
|
||||
weighting = (sigmas**-2.0).float()
|
||||
elif weighting_scheme == "cosmap":
|
||||
bot = 1 - 2 * sigmas + 2 * sigmas**2
|
||||
weighting = 2 / (math.pi * bot)
|
||||
else:
|
||||
weighting = torch.ones_like(sigmas)
|
||||
return weighting
|
||||
|
||||
|
||||
def get_noisy_model_input_and_timesteps(
|
||||
args, noise_scheduler, latents, noise, device, dtype
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
bsz = latents.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
# for weighting schemes where we sample timesteps non-uniformly
|
||||
u = compute_density_for_timestep_sampling(
|
||||
weighting_scheme=args.weighting_scheme,
|
||||
batch_size=bsz,
|
||||
logit_mean=args.logit_mean,
|
||||
logit_std=args.logit_std,
|
||||
mode_scale=args.mode_scale,
|
||||
)
|
||||
indices = (u * noise_scheduler.config.num_train_timesteps).long()
|
||||
timesteps = noise_scheduler.timesteps[indices].to(device=device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
|
||||
return noisy_model_input, timesteps, sigmas
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user