clean code and add finetune code

Author: sdbds
Date: 2025-02-26 11:20:03 +08:00
parent 5f9047c8cf
commit ce37c08b9a
3 changed files with 1118 additions and 84 deletions


@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)

 # region sample images
-def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], None, None]:
+def batchify(
+    prompt_dicts, batch_size=None
+) -> Generator[list[dict[str, str]], None, None]:
     """
     Group prompt dictionaries into batches with configurable batch size.
@@ -64,7 +66,15 @@ def batchify(prompt_dicts, batch_size=None) -> Generator[list[dict[str, str]], N
         seed = int(seed) if seed is not None else None
         # Create a key based on the parameters
-        key = (width, height, guidance_scale, seed, sample_steps, cfg_trunc_ratio, renorm_cfg)
+        key = (
+            width,
+            height,
+            guidance_scale,
+            seed,
+            sample_steps,
+            cfg_trunc_ratio,
+            renorm_cfg,
+        )
         # Add the prompt_dict to the corresponding batch
         if key not in batches:
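The hunk is cut off here, but the pattern it reformats is worth spelling out: prompts are bucketed under a tuple of their generation parameters, so only prompts that share identical settings are generated in the same batch. A minimal, self-contained sketch of that bucketing, assuming the diff's dict fields; the fallback to one bucket-sized batch when batch_size is None is my assumption, not confirmed by the visible lines:

import itertools

def batchify_sketch(prompt_dicts, batch_size=None):
    # Bucket prompt dicts by identical sampling parameters (illustrative key),
    # then yield fixed-size batches from each bucket.
    batches = {}
    for d in prompt_dicts:
        key = (d.get("width"), d.get("height"), d.get("seed"))
        batches.setdefault(key, []).append(d)
    for bucket in batches.values():
        it = iter(bucket)
        while batch := list(itertools.islice(it, batch_size or len(bucket))):
            yield batch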
@@ -131,7 +141,9 @@ def sample_images(
         if epoch is None or epoch % args.sample_every_n_epochs != 0:
             return
     else:
-        if global_step % args.sample_every_n_steps != 0 or epoch is not None:  # steps is not divisible or end of epoch
+        if (
+            global_step % args.sample_every_n_steps != 0 or epoch is not None
+        ):  # steps is not divisible or end of epoch
             return

     assert (
@@ -139,12 +151,21 @@ def sample_images(
     ), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください"

     logger.info("")
-    logger.info(f"generating sample images at step / サンプル画像生成 ステップ: {global_step}")
-    if not os.path.isfile(args.sample_prompts) and sample_prompts_gemma2_outputs is None:
-        logger.error(f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}")
+    logger.info(
+        f"generating sample images at step / サンプル画像生成 ステップ: {global_step}"
+    )
+    if (
+        not os.path.isfile(args.sample_prompts)
+        and sample_prompts_gemma2_outputs is None
+    ):
+        logger.error(
+            f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}"
+        )
         return

-    distributed_state = PartialState()  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here
+    distributed_state = (
+        PartialState()
+    )  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here

     # unwrap nextdit and gemma2_model
     nextdit = accelerator.unwrap_model(nextdit)
@@ -163,7 +184,9 @@ def sample_images(
     rng_state = torch.get_rng_state()
     cuda_rng_state = None
     try:
-        cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+        cuda_rng_state = (
+            torch.cuda.get_rng_state() if torch.cuda.is_available() else None
+        )
     except Exception:
         pass
@@ -194,7 +217,9 @@ def sample_images(
         for i in range(distributed_state.num_processes):
             per_process_prompts.append(prompts[i :: distributed_state.num_processes])

-        with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
+        with distributed_state.split_between_processes(
+            per_process_prompts
+        ) as prompt_dict_lists:
             # TODO: batch prompts together with buckets of image sizes
             for prompt_dicts in batchify(prompt_dict_lists[0], batch_size):
                 sample_image_inference(
@@ -289,7 +314,9 @@ def sample_image_inference(
     if prompt_replacement is not None:
         prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
         if negative_prompt is not None:
-            negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
+            negative_prompt = negative_prompt.replace(
+                prompt_replacement[0], prompt_replacement[1]
+            )

     if negative_prompt is None:
         negative_prompt = ""
@@ -314,17 +341,26 @@ def sample_image_inference(
             gemma2_conds = sample_prompts_gemma2_outputs[prompt]
             logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")
-        if sample_prompts_gemma2_outputs and negative_prompt in sample_prompts_gemma2_outputs:
+        if (
+            sample_prompts_gemma2_outputs
+            and negative_prompt in sample_prompts_gemma2_outputs
+        ):
             neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
-            logger.info(f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}")
+            logger.info(
+                f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
+            )

         # Load sample prompts from Gemma 2
         if gemma2_model is not None:
             tokens_and_masks = tokenize_strategy.tokenize(prompt)
-            gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+            gemma2_conds = encoding_strategy.encode_tokens(
+                tokenize_strategy, [gemma2_model], tokens_and_masks
+            )
             tokens_and_masks = tokenize_strategy.tokenize(negative_prompt)
-            neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2_model], tokens_and_masks)
+            neg_gemma2_conds = encoding_strategy.encode_tokens(
+                tokenize_strategy, [gemma2_model], tokens_and_masks
+            )

     # Unpack Gemma2 outputs
     gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
@@ -340,10 +376,18 @@ def sample_image_inference(
     )

     # Stack conditioning
-    cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(accelerator.device)
-    cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(accelerator.device)
-    uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(accelerator.device)
-    uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(accelerator.device)
+    cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(
+        accelerator.device
+    )
+    cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(
+        accelerator.device
+    )
+    uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(
+        accelerator.device
+    )
+    uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(
+        accelerator.device
+    )

     # sample image
     weight_dtype = vae.dtype  # TODO: give dtype as argument
@@ -362,7 +406,9 @@ def sample_image_inference(
         noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1)

     scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0)
-    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=sample_steps)
+    timesteps, num_inference_steps = retrieve_timesteps(
+        scheduler, num_inference_steps=sample_steps
+    )

     # if controlnet_image is not None:
     #     controlnet_image = Image.open(controlnet_image).convert("RGB")
@@ -422,7 +468,9 @@ def sample_image_inference(
             import wandb

             # not to commit images to avoid inconsistency between training and logging steps
-            wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False)  # positive prompt as a caption
+            wandb_tracker.log(
+                {f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False
+            )  # positive prompt as a caption

     vae.to(org_vae_device)
     clean_memory_on_device(accelerator.device)
@@ -437,7 +485,9 @@ def time_shift(mu: float, sigma: float, t: torch.Tensor):
     return t

-def get_lin_function(x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15) -> Callable[[float], float]:
+def get_lin_function(
+    x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
+) -> Callable[[float], float]:
     """
     Get linear function
@@ -481,7 +531,9 @@ def get_schedule(
     # shifting the schedule to favor high timesteps for higher signal images
     if shift:
         # estimate mu based on linear interpolation between two points
-        mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len)
+        mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(
+            image_seq_len
+        )
         timesteps = time_shift(mu, 1.0, timesteps)

     return timesteps.tolist()
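For context on the hunk above: get_lin_function interpolates the shift parameter mu linearly between (x1=256, y1=base_shift) and (x2=4096, y2=max_shift) as a function of image sequence length, and time_shift then warps the uniform schedule toward the noisy end for larger images. A minimal sketch of the arithmetic; the diff only shows time_shift's signature, so the flux-style body exp(mu) / (exp(mu) + (1/t - 1)**sigma) is an assumption:

import math

def get_lin_function_sketch(x1=256.0, x2=4096.0, y1=0.5, y2=1.15):
    # Line through (x1, y1) and (x2, y2).
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def time_shift_sketch(mu, sigma, t):
    # Assumed flux-style warp; not confirmed by the visible hunks.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# A 1024x1024 image patchified at //16 gives a 64*64 = 4096-token sequence,
# so mu lands at the y2 end and t=0.5 is pushed to roughly 0.76.
mu = get_lin_function_sketch()(4096)
print(round(mu, 3), round(time_shift_sketch(mu, 1.0, 0.5), 3))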
@@ -520,9 +572,13 @@ def retrieve_timesteps(
         second element is the number of inference steps.
     """
     if timesteps is not None and sigmas is not None:
-        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+        raise ValueError(
+            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
+        )
     if timesteps is not None:
-        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        accepts_timesteps = "timesteps" in set(
+            inspect.signature(scheduler.set_timesteps).parameters.keys()
+        )
         if not accepts_timesteps:
             raise ValueError(
                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -532,7 +588,9 @@ def retrieve_timesteps(
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
     elif sigmas is not None:
-        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        accept_sigmas = "sigmas" in set(
+            inspect.signature(scheduler.set_timesteps).parameters.keys()
+        )
         if not accept_sigmas:
             raise ValueError(
                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -593,7 +651,9 @@ def denoise(
         # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
         current_timestep = 1 - t / scheduler.config.num_train_timesteps
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        current_timestep = current_timestep * torch.ones(img.shape[0], device=img.device)
+        current_timestep = current_timestep * torch.ones(
+            img.shape[0], device=img.device
+        )

         noise_pred_cond = model(
             img,
@@ -610,12 +670,20 @@ def denoise(
             cap_feats=neg_txt,  # Gemma2 hidden states as caption features
             cap_mask=neg_txt_mask.to(dtype=torch.int32),  # Gemma2 attention mask
         )
-        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+        noise_pred = noise_pred_uncond + guidance_scale * (
+            noise_pred_cond - noise_pred_uncond
+        )

         # apply normalization after classifier-free guidance
         if float(renorm_cfg) > 0.0:
-            cond_norm = torch.linalg.vector_norm(noise_pred_cond, dim=tuple(range(1, len(noise_pred_cond.shape))), keepdim=True)
+            cond_norm = torch.linalg.vector_norm(
+                noise_pred_cond,
+                dim=tuple(range(1, len(noise_pred_cond.shape))),
+                keepdim=True,
+            )
             max_new_norm = cond_norm * float(renorm_cfg)
-            noise_norm = torch.linalg.vector_norm(noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True)
+            noise_norm = torch.linalg.vector_norm(
+                noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True
+            )
             if noise_norm >= max_new_norm:
                 noise_pred = noise_pred * (max_new_norm / noise_norm)
             else:
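The renormalization in this hunk caps the norm of the CFG-combined prediction at renorm_cfg times the norm of the conditional prediction, guarding against the norm blow-up that large guidance scales can cause. A tensor-level sketch of the same steps, generalized per-sample with torch.where (the hunk's scalar comparison suggests a batch of one); names and shapes are illustrative:

import torch

def renorm_guidance(cond, uncond, guidance_scale, renorm_cfg):
    # Standard classifier-free guidance combination.
    pred = uncond + guidance_scale * (cond - uncond)
    if renorm_cfg > 0.0:
        # Per-sample norms over all non-batch dimensions.
        dims = tuple(range(1, pred.ndim))
        cond_norm = torch.linalg.vector_norm(cond, dim=dims, keepdim=True)
        pred_norm = torch.linalg.vector_norm(pred, dim=dims, keepdim=True)
        max_norm = cond_norm * renorm_cfg
        # Scale down only where the combined prediction exceeds the cap.
        pred = torch.where(pred_norm > max_norm, pred * (max_norm / pred_norm), pred)
    return pred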
@@ -640,7 +708,11 @@ def denoise(

 # region train
 def get_sigmas(
-    noise_scheduler: FlowMatchEulerDiscreteScheduler, timesteps: Tensor, device: torch.device, n_dim=4, dtype=torch.float32
+    noise_scheduler: FlowMatchEulerDiscreteScheduler,
+    timesteps: Tensor,
+    device: torch.device,
+    n_dim=4,
+    dtype=torch.float32,
 ) -> Tensor:
     """
     Get sigmas for timesteps
@@ -667,7 +739,11 @@ def get_sigmas(

 def compute_density_for_timestep_sampling(
-    weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
+    weighting_scheme: str,
+    batch_size: int,
+    logit_mean: float = None,
+    logit_std: float = None,
+    mode_scale: float = None,
 ):
     """
     Compute the density for sampling the timesteps when doing SD3 training.
@@ -688,7 +764,9 @@ def compute_density_for_timestep_sampling(
     """
     if weighting_scheme == "logit_normal":
         # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
-        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
+        u = torch.normal(
+            mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
+        )
         u = torch.nn.functional.sigmoid(u)
     elif weighting_scheme == "mode":
         u = torch.rand(size=(batch_size,), device="cpu")
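The logit_normal branch above implements the SD3 paper's lognorm(0.00, 1.00) density: a normal draw squashed through a sigmoid, which concentrates sampled timesteps around the middle of the schedule rather than spreading them uniformly. A quick sketch showing the effect, with the SD3 default mean and std:

import torch

# 10k draws: logit-normal concentrates mass near t = 0.5, unlike uniform.
u = torch.sigmoid(torch.normal(mean=0.0, std=1.0, size=(10_000,)))
middle = ((u > 0.25) & (u < 0.75)).float().mean()
print(f"{middle:.0%} of samples fall in (0.25, 0.75)")  # ~73% vs 50% for uniform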
@@ -722,7 +800,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor
     return weighting

-def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype) -> Tuple[Tensor, Tensor, Tensor]:
+def get_noisy_model_input_and_timesteps(
+    args, noise_scheduler, latents, noise, device, dtype
+) -> Tuple[Tensor, Tensor, Tensor]:
     """
     Get noisy model input and timesteps.
@@ -753,27 +833,27 @@ def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, d
         timesteps = t * 1000.0
         t = t.view(-1, 1, 1, 1)
-        noisy_model_input = (1 - t) * latents + t * noise
+        noisy_model_input = (1 - t) * noise + t * latents
     elif args.timestep_sampling == "shift":
         shift = args.discrete_flow_shift
         logits_norm = torch.randn(bsz, device=device)
-        logits_norm = logits_norm * args.sigmoid_scale  # larger scale for more uniform sampling
+        logits_norm = (
+            logits_norm * args.sigmoid_scale
+        )  # larger scale for more uniform sampling
         timesteps = logits_norm.sigmoid()
         timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

         t = timesteps.view(-1, 1, 1, 1)
         timesteps = timesteps * 1000.0
-        noisy_model_input = (1 - t) * latents + t * noise
+        noisy_model_input = (1 - t) * noise + t * latents
     elif args.timestep_sampling == "nextdit_shift":
-        logits_norm = torch.randn(bsz, device=device)
-        logits_norm = logits_norm * args.sigmoid_scale  # larger scale for more uniform sampling
-        timesteps = logits_norm.sigmoid()
-        mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
-        timesteps = time_shift(mu, 1.0, timesteps)
+        t = torch.rand((bsz,), device=device)
+        mu = get_lin_function(y1=0.5, y2=1.15)((h // 16) * (w // 16))  # lumina uses //16
+        t = time_shift(mu, 1.0, t)

-        t = timesteps.view(-1, 1, 1, 1)
-        timesteps = timesteps * 1000.0
-        noisy_model_input = (1 - t) * latents + t * noise
+        timesteps = t * 1000.0
+        t = t.view(-1, 1, 1, 1)
+        noisy_model_input = (1 - t) * noise + t * latents
     else:
         # Sample a random timestep for each image
         # for weighting schemes where we sample timesteps non-uniformly
@@ -788,8 +868,10 @@ def get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, d
         timesteps = noise_scheduler.timesteps[indices].to(device=device)

         # Add noise according to flow matching.
-        sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
-        noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
+        sigmas = get_sigmas(
+            noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
+        )
+        noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise

     return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
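The substantive change in the last two hunks is the interpolation direction: the old code treated t=1 as pure noise, while the new code follows the NextDiT/Lumina convention where t=0 is noise and t=1 is the clean image, the same convention the denoise loop's comment mentions when it reverses the timestep. A minimal endpoint check of the new convention:

import torch

latents = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(latents)

def noisy_input(t: float) -> torch.Tensor:
    # Convention after this commit: t=0 -> pure noise, t=1 -> clean latents.
    return (1 - t) * noise + t * latents

assert torch.equal(noisy_input(0.0), noise)
assert torch.equal(noisy_input(1.0), latents)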
@@ -821,7 +903,9 @@ def apply_model_prediction_type(
         # these weighting schemes use a uniform timestep sampling
         # and instead post-weight the loss
-        weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+        weighting = compute_loss_weighting_for_sd3(
+            weighting_scheme=args.weighting_scheme, sigmas=sigmas
+        )

     return model_pred, weighting
@@ -863,15 +947,27 @@ def save_models(

 def save_lumina_model_on_train_end(
-    args: argparse.Namespace, save_dtype: torch.dtype, epoch: int, global_step: int, lumina: lumina_models.NextDiT
+    args: argparse.Namespace,
+    save_dtype: torch.dtype,
+    epoch: int,
+    global_step: int,
+    lumina: lumina_models.NextDiT,
 ):
     def sd_saver(ckpt_file, epoch_no, global_step):
         sai_metadata = train_util.get_sai_model_spec(
-            None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2"
+            None,
+            args,
+            False,
+            False,
+            False,
+            is_stable_diffusion_ckpt=True,
+            lumina="lumina2",
         )
         save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)

-    train_util.save_sd_model_on_train_end_common(args, True, True, epoch, global_step, sd_saver, None)
+    train_util.save_sd_model_on_train_end_common(
+        args, True, True, epoch, global_step, sd_saver, None
+    )

 # epoch and step saving are merged because the metadata includes epoch/step and the arguments are the same
@@ -901,7 +997,15 @@ def save_lumina_model_on_epoch_end_or_stepwise(
     """

     def sd_saver(ckpt_file: str, epoch_no: int, global_step: int):
-        sai_metadata = train_util.get_sai_model_spec({}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2")
+        sai_metadata = train_util.get_sai_model_spec(
+            {},
+            args,
+            False,
+            False,
+            False,
+            is_stable_diffusion_ckpt=True,
+            lumina="lumina2",
+        )
         save_models(ckpt_file, lumina, sai_metadata, save_dtype, args.mem_eff_save)

     train_util.save_sd_model_on_epoch_end_or_stepwise_common(
@@ -927,7 +1031,11 @@ def add_lumina_train_arguments(parser: argparse.ArgumentParser):
         type=str,
         help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス*.sftまたは*.safetensors、float16が前提",
     )
-    parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
+    parser.add_argument(
+        "--ae",
+        type=str,
+        help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors",
+    )
     parser.add_argument(
         "--gemma2_max_token_length",
         type=int,