mirror of https://github.com/kohya-ss/sd-scripts.git

clean code and add finetune code
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
 # region sample images
 
 
-def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]:
+def batchify(
+    prompt_dicts, batch_size=None
+) -> Generator[list[dict[str, str]], None, None]:
     """
     Group prompt dictionaries into batches with configurable batch size.
 
@@ -64,7 +66,15 @@ def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], N
         seed = int(seed) if seed is not None else None
 
         # Create a key based on the parameters
-        key = (width, height, guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg)
+        key = (
+            width,
+            height,
+            guidance_scale,
+            seed,
+            sample_steps,
+            cfg_trunc_ratio,
+            renorm_cfg,
+        )
 
         # Add the prompt_dict to the corresponding batch
         if key not in batches:
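For reference, a minimal standalone sketch of the grouping performed here, with a hypothetical name `batchify_sketch` and a reduced key (the real key above also carries guidance_scale, sample_steps, cfg_trunc_ratio and renorm_cfg):

from collections import defaultdict

def batchify_sketch(prompt_dicts, batch_size=None):
    # bucket prompts by shared generation parameters, then emit fixed-size chunks
    buckets = defaultdict(list)
    for d in prompt_dicts:
        buckets[(d.get("width"), d.get("height"), d.get("seed"))].append(d)
    for group in buckets.values():
        step = batch_size or len(group)
        for i in range(0, len(group), step):
            yield group[i : i + step]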
@@ -131,7 +141,9 @@ def sample_images(
         if epoch is None or epoch % args.sample_every_n_epochs != 0:
             return
     else:
-        if global_step % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
+        if (
+            global_step % args.sample_every_n_steps != 0 or epoch is not None
+        ):  # steps is not divisible or end of epoch
             return
 
     assert (
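The gating above reads more plainly as a predicate; this paraphrase is only for orientation and does not claim to be the function's actual structure:

def should_sample(global_step, epoch, every_n_steps, every_n_epochs):
    # epoch-end calls sample only on epoch multiples; step-based calls sample
    # only on step multiples, and never when the call also marks an epoch end
    if every_n_epochs is not None:
        return epoch is not None and epoch % every_n_epochs == 0
    return epoch is None and global_step % every_n_steps == 0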
@@ -139,12 +151,21 @@ def sample_images(
     ), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください"
 
     logger.info("")
-    logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}")
-    if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None:
-        logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
+    logger.info(
+        f"generating sample images at step / サンプル画像生成 ステップ: {global_step}"
+    )
+    if (
+        not os.path.isfile(args.sample_prompts)
+        and sample_prompts_gemma2_outputs is None
+    ):
+        logger.error(
+            f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}"
+        )
         return
 
-    distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
+    distributed_state = (
+        PartialState()
+    )  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
 
     # unwrap nextdit and gemma2_model
     nextdit = accelerator.unwrap_model(nextdit)
@@ -163,7 +184,9 @@ def sample_images(
     rng_state = torch.get_rng_state()
     cuda_rng_state = None
     try:
-        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+        cuda_rng_state = (
+            torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+        )
     except Exception:
         pass
 
@@ -194,7 +217,9 @@ def sample_images(
     for i in range(distributed_state.num_processes):
         per_process_prompts.append(prompts[i :: distributed_state.num_processes])
 
-    with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
+    with distributed_state.split_between_processes(
+        per_process_prompts
+    ) as prompt_dict_lists:
        # TODO: batch prompts together with buckets of image sizes
        for prompt_dicts in batchify(prompt_dict_lists[0], batch_size):
            sample_image_inference(
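The `prompts[i :: distributed_state.num_processes]` slicing above deals prompts out round-robin, so each rank receives an (almost) equal share:

prompts = ["p0", "p1", "p2", "p3", "p4"]
num_processes = 2
per_process = [prompts[i::num_processes] for i in range(num_processes)]
assert per_process == [["p0", "p2", "p4"], ["p1", "p3"]]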
@@ -289,7 +314,9 @@ def sample_image_inference(
     if prompt_replacement is not None:
         prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
         if negative_prompt is not None:
-            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+            negative_prompt = negative_prompt.replace(
+                prompt_replacement[0], prompt_replacement[1]
+            )
 
     if negative_prompt is None:
         negative_prompt = ""
@@ -314,17 +341,26 @@ def sample_image_inference(
         gemma2_conds = sample_prompts_gemma2_outputs[prompt]
         logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
 
-    if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
+    if (
+        sample_prompts_gemma2_outputs
+        and negative_prompt in sample_prompts_gemma2_outputs
+    ):
         neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
-        logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
+        logger.info(
+            f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
+        )
 
     # Load sample prompts from Gemma 2
     if gemma2_model is not None:
         tokens_and_masks = tokenize_strategy.tokenize(prompt)
-        gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+        gemma2_conds = encoding_strategy.encode_tokens(
+            tokenize_strategy, [gemma2_model], tokens_and_masks
+        )
 
         tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
-        neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+        neg_gemma2_conds = encoding_strategy.encode_tokens(
+            tokenize_strategy, [gemma2_model], tokens_and_masks
+        )
 
     # Unpack Gemma2 outputs
     gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
@@ -340,10 +376,18 @@ def sample_image_inference(
     )
 
     # Stack conditioning
-    cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device)
-    cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device)
-    uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device)
-    uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device)
+    cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(
+        accelerator.device
+    )
+    cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(
+        accelerator.device
+    )
+    uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(
+        accelerator.device
+    )
+    uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(
+        accelerator.device
+    )
 
     # sample image
     weight_dtype = vae.dtype  # TOFO give dtype as argument
@@ -362,7 +406,9 @@ def sample_image_inference(
     noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1)
 
     scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0)
-    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps)
+    timesteps, num_inference_steps = retrieve_timesteps(
+        scheduler, num_inference_steps=sample_steps
+    )
 
     # if controlnet_image is not None:
     #     controlnet_image = Image.open(controlnet_image).convert("RGB")
@@ -422,7 +468,9 @@ def sample_image_inference(
         import wandb
 
         # not to commit images to avoid inconsistency between training and logging steps
-        wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption
+        wandb_tracker.log(
+            {f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False
+        )  # positive prompt as a caption
 
     vae.to(org_vae_device)
     clean_memory_on_device(accelerator.device)
@@ -437,7 +485,9 @@ def time_shift(mu: float, sigma: float, t: torch.Tensor):
     return t
 
 
-def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]:
+def get_lin_function(
+    x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
+) -> Callable[[float], float]:
     """
     Get linear function
 
@@ -481,7 +531,9 @@ def get_schedule(
     # shifting the schedule to favor high timesteps for higher signal images
     if shift:
         # eastimate mu based on linear estimation between two points
-        mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len)
+        mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
+            image_seq_len
+        )
         timesteps = time_shift(mu, 1.0, timesteps)
 
     return timesteps.tolist()
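For orientation, a sketch of the shift machinery used here, assuming the usual flow-matching forms of these two helpers (the authoritative bodies are the ones defined in this file):

import math
import torch

def get_lin_function(x1=256.0, x2=4096.0, y1=0.5, y2=1.15):
    # line through (x1, y1) and (x2, y2); maps image sequence length to mu
    m = (y2 - y1) / (x2 - x1)
    return lambda x: m * (x - x1) + y1

def time_shift(mu, sigma, t):
    # pushes timesteps toward 1.0 as mu grows
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

mu = get_lin_function(y1=0.5, y2=1.15)(1024.0)
shifted = time_shift(mu, 1.0, torch.linspace(1.0, 0.05, 5))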
@@ -520,9 +572,13 @@ def retrieve_timesteps(
         second element is the number of inference steps.
     """
     if timesteps is not None and sigmas is not None:
-        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+        raise ValueError(
+            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+        )
     if timesteps is not None:
-        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        accepts_timesteps = "timesteps" in set(
+            inspect.signature(scheduler.set_timesteps).parameters.keys()
+        )
         if not accepts_timesteps:
             raise ValueError(
                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -532,7 +588,9 @@ def retrieve_timesteps(
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
     elif sigmas is not None:
-        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        accept_sigmas = "sigmas" in set(
+            inspect.signature(scheduler.set_timesteps).parameters.keys()
+        )
         if not accept_sigmas:
             raise ValueError(
                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
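The `accepts_timesteps` / `accept_sigmas` checks above are plain capability detection: the custom schedule is only forwarded if the scheduler's `set_timesteps` signature declares the parameter. A self-contained illustration with a hypothetical scheduler:

import inspect

class ToyScheduler:
    def set_timesteps(self, num_inference_steps=None, timesteps=None, device=None):
        pass

accepts_timesteps = "timesteps" in inspect.signature(ToyScheduler.set_timesteps).parameters
accept_sigmas = "sigmas" in inspect.signature(ToyScheduler.set_timesteps).parameters
assert accepts_timesteps and not accept_sigmas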
@@ -593,7 +651,9 @@ def denoise(
         # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
         current_timestep = 1 - t / scheduler.config.num_train_timesteps
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device)
+        current_timestep = current_timestep * torch.ones(
+            img.shape[0], device=img.device
+        )
 
         noise_pred_cond = model(
             img,
@@ -610,12 +670,20 @@ def denoise(
             cap_feats=neg_txt,  # Gemma2 hidden states as caption features
             cap_mask=neg_txt_mask.to(dtype=torch.int32),  # Gemma2 attention mask
         )
-        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+        noise_pred = noise_pred_uncond + guidance_scale * (
+            noise_pred_cond - noise_pred_uncond
+        )
         # apply normalization after classifier-free guidance
         if float(renorm_cfg) > 0.0:
-            cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True)
+            cond_norm = torch.linalg.vector_norm(
+                noise_pred_cond,
+                dim=tuple(range(1, len(noise_pred_cond.shape))),
+                keepdim=True,
+            )
             max_new_norm = cond_norm * float(renorm_cfg)
-            noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True)
+            noise_norm = torch.linalg.vector_norm(
+                noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True
+            )
             if noise_norm >= max_new_norm:
                 noise_pred = noise_pred * (max_new_norm / noise_norm)
             else:
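The renorm block above caps the norm of the guided prediction at `renorm_cfg` times the conditional prediction's norm, which limits how far classifier-free guidance can push a sample off-distribution. A toy check with batch size 1 (shapes are illustrative):

import torch

noise_pred_cond = torch.randn(1, 4, 8, 8)
noise_pred = noise_pred_cond * 3.0  # pretend guidance tripled the norm
renorm_cfg = 1.0

cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=(1, 2, 3), keepdim=True)
noise_norm = torch.linalg.vector_norm(noise_pred, dim=(1, 2, 3), keepdim=True)
max_new_norm = cond_norm * renorm_cfg
if noise_norm >= max_new_norm:
    noise_pred = noise_pred * (max_new_norm / noise_norm)
# noise_pred's norm now equals noise_pred_cond's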
@@ -640,7 +708,11 @@ def denoise(
 
 # region train
 def get_sigmas(
-    noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32
+    noise_scheduler: FlowMatchEulerDiscreteScheduler,
+    timesteps: Tensor,
+    device: torch.device,
+    n_dim=4,
+    dtype=torch.float32,
 ) -> Tensor:
     """
     Get sigmas for timesteps
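What `get_sigmas` has to produce: the sigma for each sampled timestep, broadcast to the latents' rank. A sketch in the style of the common diffusers training helper (the body in this file may differ in details):

import torch

def get_sigmas_sketch(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32):
    sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
    sigma = sigmas[step_indices].flatten()
    while len(sigma.shape) < n_dim:
        sigma = sigma.unsqueeze(-1)  # (bsz,) -> (bsz, 1, 1, 1) for n_dim=4
    return sigma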
@@ -667,7 +739,11 @@ def get_sigmas(
 
 
 def compute_density_for_timestep_sampling(
-    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+    weighting_scheme: str,
+    batch_size: int,
+    logit_mean: float = None,
+    logit_std: float = None,
+    mode_scale: float = None,
 ):
     """
     Compute the density for sampling the timesteps when doing SD3 training.
@@ -688,7 +764,9 @@ def compute_density_for_timestep_sampling(
     """
     if weighting_scheme == "logit_normal":
         # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+        u = torch.normal(
+            mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
+        )
         u = torch.nn.functional.sigmoid(u)
     elif weighting_scheme == "mode":
         u = torch.rand(size=(batch_size,), device="cpu")
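The "logit_normal" branch above is SD3's timestep density in two lines: draw u ~ N(logit_mean, logit_std), then squash through a sigmoid so samples concentrate mid-schedule:

import torch

batch_size, logit_mean, logit_std = 8, 0.0, 1.0
u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
u = torch.sigmoid(u)  # values in (0, 1), densest near 0.5 when logit_mean is 0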
@@ -722,7 +800,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
     return weighting
 
 
-def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]:
+def get_noisy_model_input_and_timesteps(
+    args, noise_scheduler, latents, noise, device, dtype
+) -> Tuple[Tensor, Tensor, Tensor]:
     """
     Get noisy model input and timesteps.
 
@@ -753,27 +833,27 @@ def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, d
 
         timesteps = t * 1000.0
         t = t.view(-1, 1, 1, 1)
-        noisy_model_input = (1 - t) * latents + t * noise
+        noisy_model_input = (1 - t) * noise + t * latents
     elif args.timestep_sampling == "shift":
         shift = args.discrete_flow_shift
         logits_norm = torch.randn(bsz, device=device)
-        logits_norm = logits_norm * args.sigmoid_scale  # larger scale for more uniform sampling
+        logits_norm = (
+            logits_norm * args.sigmoid_scale
+        )  # larger scale for more uniform sampling
         timesteps = logits_norm.sigmoid()
         timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
 
         t = timesteps.view(-1, 1, 1, 1)
         timesteps = timesteps * 1000.0
-        noisy_model_input = (1 - t) * latents + t * noise
+        noisy_model_input = (1 - t) * noise + t * latents
     elif args.timestep_sampling == "nextdit_shift":
-        logits_norm = torch.randn(bsz, device=device)
-        logits_norm = logits_norm * args.sigmoid_scale  # larger scale for more uniform sampling
-        timesteps = logits_norm.sigmoid()
-        mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
-        timesteps = time_shift(mu, 1.0, timesteps)
+        t = torch.rand((bsz,), device=device)
+        mu = get_lin_function(y1=0.5, y2=1.15)((h // 16) * (w // 16))  # lumina use //16
+        t = time_shift(mu, 1.0, t)
 
-        t = timesteps.view(-1, 1, 1, 1)
-        timesteps = timesteps * 1000.0
-        noisy_model_input = (1 - t) * latents + t * noise
+        timesteps = t * 1000.0
+        t = t.view(-1, 1, 1, 1)
+        noisy_model_input = (1 - t) * noise + t * latents
     else:
         # Sample a random timestep for each image
         # for weighting schemes where we sample timesteps non-uniformly
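The sign flips in this hunk are the substantive change: with `noisy_model_input = (1 - t) * noise + t * latents`, t=0 yields pure noise and t=1 the clean latents, matching the "Lumina uses t=0 as the noise and t=1 as the image" comment in `denoise`. A quick check:

import torch

latents = torch.full((1,), 5.0)
noise = torch.zeros(1)
for t in (0.0, 1.0):
    x_t = (1 - t) * noise + t * latents
    print(t, x_t.item())  # 0.0 -> 0.0 (pure noise), 1.0 -> 5.0 (pure latents)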
@@ -788,8 +868,10 @@ def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, d
         timesteps = noise_scheduler.timesteps[indices].to(device=device)
 
         # Add noise according to flow matching.
-        sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
-        noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
+        sigmas = get_sigmas(
+            noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
+        )
+        noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise
 
     return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
 
@@ -821,7 +903,9 @@ def apply_model_prediction_type(
 
         # these weighting schemes use a uniform timestep sampling
         # and instead post-weight the loss
-        weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+        weighting = compute_loss_weighting_for_sd3(
+            weighting_scheme=args.weighting_scheme, sigmas=sigmas
+        )
 
     return model_pred, weighting
 
@@ -863,15 +947,27 @@ def save_models(
 
 
 def save_lumina_model_on_train_end(
-    args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT
+    args: argparse.Namespace,
+    save_dtype: torch.dtype,
+    epoch: int,
+    global_step: int,
+    lumina: lumina_models.NextDiT,
 ):
     def sd_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = train_util.get_sai_model_spec(
-            None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2"
+            None,
+            args,
+            False,
+            False,
+            False,
+            is_stable_diffusion_ckpt=True,
+            lumina="lumina2",
         )
         save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
 
-    train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
+    train_util.save_sd_model_on_train_end_common(
+        args, True, True, epoch, global_step, sd_saver, None
+    )
 
 
 # epoch and step saving are unified here because the metadata includes epoch/step and the arguments are the same
@@ -901,7 +997,15 @@ def save_lumina_model_on_epoch_end_or_stepwise(
     """
 
     def sd_saver(ckpt_file: str, epoch_no: int, global_step: int):
-        sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
+        sai_metadata = train_util.get_sai_model_spec(
+            {},
+            args,
+            False,
+            False,
+            False,
+            is_stable_diffusion_ckpt=True,
+            lumina="lumina2",
+        )
         save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)
 
     train_util.save_sd_model_on_epoch_end_or_stepwise_common(
@@ -927,7 +1031,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
         type=str,
         help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提",
     )
-    parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
+    parser.add_argument(
+        "--ae",
+        type=str,
+        help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)",
+    )
     parser.add_argument(
         "--gemma2_max_token_length",
         type=int,