mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 23:01:22 +00:00
update
@@ -108,14 +108,6 @@ def load_gemma2(
    logger.info(f"Loaded Gemma2: {info}")
    return gemma2


-def prepare_img_ids(batch_size: int, packed_latent_height: int, packed_latent_width: int):
-    img_ids = torch.zeros(packed_latent_height, packed_latent_width, 3)
-    img_ids[..., 1] = img_ids[..., 1] + torch.arange(packed_latent_height)[:, None]
-    img_ids[..., 2] = img_ids[..., 2] + torch.arange(packed_latent_width)[None, :]
-    img_ids = einops.repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
-    return img_ids


def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
    """
    x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2
@@ -53,7 +53,7 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
        self.train_gemma2 = not args.network_train_unet_only

    def load_target_model(self, args, weight_dtype, accelerator):
-        loading_dtype = None if args.fp8 else weight_dtype
+        loading_dtype = None if args.fp8_base else weight_dtype

        model = lumina_util.load_lumina_model(
            args.pretrained_model_name_or_path,
@@ -67,8 +67,12 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
        # model.enable_block_swap(args.blocks_to_swap, accelerator.device)
        # self.is_swapping_blocks = True

-        gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
-        ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
+        gemma2 = lumina_util.load_gemma2(
+            args.gemma2, weight_dtype, "cpu"
+        )
+        ae = lumina_util.load_ae(
+            args.ae, weight_dtype, "cpu"
+        )

        return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model
@@ -168,11 +172,174 @@ class LuminaNetworkTrainer(train_network.NetworkTrainer):
    def shift_scale_latents(self, args, latents):
        return latents

    def get_noise_pred_and_target(
        self,
        args,
        accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds,
        unet: lumina_models.NextDiT,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # get noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = (
            flux_train_utils.get_noisy_model_input_and_timesteps(
                args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
            )
        )
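        # Note (sketch, per flux_train_utils): the returned timesteps lie in [0, 1000] (the
        # sampled flow time scaled by 1000), and sigmas may be None for some timestep_sampling
        # modes; hence the "timesteps / 1000" and the "if sigmas is not None" guards below.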

        # pack latents and get img_ids - this part can be kept because NextDiT also expects packed input
        packed_noisy_model_input = lumina_util.pack_latents(noisy_model_input)
        packed_latent_height, packed_latent_width = (
            noisy_model_input.shape[2] // 2,
            noisy_model_input.shape[3] // 2,
        )
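        # Packing sketch (inferred from unpack_latents's docstring): pack_latents folds each
        # 2x2 spatial patch of the [b, c, H, W] latents into the channel dim, giving a token
        # sequence of shape [b, (H/2 * W/2), c * 4]; the packed grid is therefore H // 2 by
        # W // 2, and unpack_latents reverses the mapping after the model call.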

        # ensure the hidden state will require grad
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)

        # Unpack Gemma2 outputs
        gemma2_hidden_states, gemma2_attn_mask, input_ids = text_encoder_conds
        if not args.apply_gemma2_attn_mask:
            gemma2_attn_mask = None

        def call_dit(img, gemma2_hidden_states, input_ids, timesteps, gemma2_attn_mask):
            with torch.set_grad_enabled(is_train), accelerator.autocast():
                # NextDiT forward expects (x, t, cap_feats, cap_mask)
                model_pred = unet(
                    x=img,  # packed latents
                    t=timesteps / 1000,  # timesteps must be divided by 1000 to match what the model expects
                    cap_feats=gemma2_hidden_states,  # Gemma2 hidden states serve as the caption features
                    cap_mask=gemma2_attn_mask,  # Gemma2 attention mask
                )
            return model_pred
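        # Note: call_dit receives input_ids but does not forward it; NextDiT only consumes
        # x, t, cap_feats, and cap_mask.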

        model_pred = call_dit(
            img=packed_noisy_model_input,
            gemma2_hidden_states=gemma2_hidden_states,
            input_ids=input_ids,
            timesteps=timesteps,
            gemma2_attn_mask=gemma2_attn_mask,
        )

        # unpack latents
        model_pred = lumina_util.unpack_latents(
            model_pred, packed_latent_height, packed_latent_width
        )

        # apply model prediction type
        model_pred, weighting = flux_train_utils.apply_model_prediction_type(
            args, model_pred, noisy_model_input, sigmas
        )
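        # Note (sketch): weighting is typically None; for the "sigma_scaled" model prediction
        # type it holds SD3-style per-sample loss weights derived from the sigmas and the
        # configured weighting scheme, which the caller applies to the loss.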

        # flow matching loss: this is different from SD3
        target = noise - latents
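        # Sketch: the noisy input is x_t = (1 - sigma) * latents + sigma * noise, so
        # d x_t / d sigma = noise - latents; the network is trained to predict this constant
        # velocity along the straight noise-to-data path.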

        # differential output preservation
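        # (sketch) for samples whose custom_attributes set "diff_output_preservation", the LoRA
        # multiplier is temporarily set to 0 so the frozen base model produces a prior prediction,
        # which then replaces the flow-matching target for those samples.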
if "custom_attributes" in batch:
|
||||
diff_output_pr_indices = []
|
||||
for i, custom_attributes in enumerate(batch["custom_attributes"]):
|
||||
if (
|
||||
"diff_output_preservation" in custom_attributes
|
||||
and custom_attributes["diff_output_preservation"]
|
||||
):
|
||||
diff_output_pr_indices.append(i)
|
||||
|
||||
if len(diff_output_pr_indices) > 0:
|
||||
network.set_multiplier(0.0)
|
||||
with torch.no_grad():
|
||||
model_pred_prior = call_dit(
|
||||
img=packed_noisy_model_input[diff_output_pr_indices],
|
||||
gemma2_hidden_states=gemma2_hidden_states[
|
||||
diff_output_pr_indices
|
||||
],
|
||||
input_ids=input_ids[diff_output_pr_indices],
|
||||
timesteps=timesteps[diff_output_pr_indices],
|
||||
gemma2_attn_mask=(
|
||||
gemma2_attn_mask[diff_output_pr_indices]
|
||||
if gemma2_attn_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
network.set_multiplier(1.0)
|
||||
|
||||
model_pred_prior = lumina_util.unpack_latents(
|
||||
model_pred_prior, packed_latent_height, packed_latent_width
|
||||
)
|
||||
model_pred_prior, _ = flux_train_utils.apply_model_prediction_type(
|
||||
args,
|
||||
model_pred_prior,
|
||||
noisy_model_input[diff_output_pr_indices],
|
||||
sigmas[diff_output_pr_indices] if sigmas is not None else None,
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||

    def post_process_loss(self, loss, args, timesteps, noise_scheduler):
        return loss

    def get_sai_model_spec(self, args):
-        return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
+        return train_util.get_sai_model_spec(
+            None, args, False, True, False, lumina="lumina2"
+        )

    def update_metadata(self, metadata, args):
        metadata["ss_apply_gemma2_attn_mask"] = args.apply_gemma2_attn_mask
        metadata["ss_weighting_scheme"] = args.weighting_scheme
        metadata["ss_logit_mean"] = args.logit_mean
        metadata["ss_logit_std"] = args.logit_std
        metadata["ss_mode_scale"] = args.mode_scale
        metadata["ss_guidance_scale"] = args.guidance_scale
        metadata["ss_timestep_sampling"] = args.timestep_sampling
        metadata["ss_sigmoid_scale"] = args.sigmoid_scale
        metadata["ss_model_prediction_type"] = args.model_prediction_type
        metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

    def is_text_encoder_not_needed_for_training(self, args):
        return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

    def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
        text_encoder.model.embed_tokens.requires_grad_(True)

    def prepare_text_encoder_fp8(
        self, index, text_encoder, te_weight_dtype, weight_dtype
    ):
        logger.info(
            f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}"
        )
        text_encoder.to(te_weight_dtype)  # fp8
        text_encoder.model.embed_tokens.to(dtype=weight_dtype)

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        if not self.is_swapping_blocks:
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # when swapping blocks, prepare the model without moving it to the device
        nextdit: lumina_models.NextDiT = unet
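        # (sketch) with block swapping enabled, device_placement is [False], so
        # accelerator.prepare() wraps the model without moving it to the GPU; the blocks to be
        # swapped stay on CPU and everything else is moved explicitly below.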
        nextdit = accelerator.prepare(
            nextdit, device_placement=[not self.is_swapping_blocks]
        )
        accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(
            accelerator.device
        )  # reduce peak memory usage
        accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()

        return nextdit


def setup_parser() -> argparse.ArgumentParser: