Mirror of https://github.com/kohya-ss/sd-scripts.git, synced 2026-04-15 08:36:41 +00:00.
Commit: "Minimal Example of Flex Training". This commit is contained in:
@@ -20,6 +20,8 @@ from library import device_utils
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
from networks import oft_flux
|
||||
|
||||
from library.flux_utils import bypass_flux_guidance, restore_flux_guidance
|
||||
|
||||
init_ipex()
|
||||
|
||||
|
||||
@@ -151,6 +153,9 @@ def do_sample(
|
||||
logger.info(f"num_steps: {num_steps}")
|
||||
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
||||
|
||||
# bypass guidance module
|
||||
bypass_flux_guidance(model)
|
||||
|
||||
# denoise initial noise
|
||||
if accelerator:
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
@@ -364,6 +369,9 @@ def generate_image(
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
||||
|
||||
# restore guidance module
|
||||
restore_flux_guidance(model)
|
||||
|
||||
# save image
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
@@ -425,6 +425,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
return model_pred
|
||||
|
||||
flux_utils.bypass_flux_guidance(unet)
|
||||
|
||||
model_pred = call_dit(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
@@ -439,6 +441,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
|
||||
flux_utils.restore_flux_guidance(unet)
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
|
||||
@@ -24,6 +24,32 @@ MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
MODEL_NAME_DEV = "dev"
|
||||
MODEL_NAME_SCHNELL = "schnell"
|
||||
|
||||
def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
    """Replacement forward for the combined time/text embedder that skips guidance.

    Sums only the timestep embedding and the pooled text projection; the
    `guidance` argument is accepted for signature compatibility but never used.
    Returns the conditioning tensor of shape (N, D).
    """
    # Project the raw timestep, then embed it in the text projection's dtype.
    proj = self.time_proj(timestep)
    time_emb = self.timestep_embedder(proj.to(dtype=pooled_projection.dtype))  # (N, D)
    # `guidance` is deliberately ignored -- that is the point of the bypass.
    text_emb = self.text_embedder(pooled_projection)
    return time_emb + text_emb
|
||||
|
||||
# bypass the forward function
def bypass_flux_guidance(transformer):
    """Monkey-patch `transformer.time_text_embed` so guidance is ignored.

    The original forward is stashed under `_bfg_orig_forward` so that
    `restore_flux_guidance` can undo the patch later. Calling this is a
    no-op when the patch is already in place, or when the embedder has no
    `guidance_embedder` (i.e. a model without a guidance input).
    """
    embedder = transformer.time_text_embed
    # Already patched: bailing out keeps the real forward recoverable.
    if hasattr(embedder, '_bfg_orig_forward'):
        return
    # Nothing to bypass if the model has no guidance embedding at all.
    if not hasattr(embedder, 'guidance_embedder'):
        return
    embedder._bfg_orig_forward = embedder.forward
    embedder.forward = partial(guidance_embed_bypass_forward, embedder)
|
||||
|
||||
# restore the forward function
def restore_flux_guidance(transformer):
    """Undo `bypass_flux_guidance`, reinstating the saved forward.

    Safe to call when no bypass is active: in that case nothing happens.
    """
    embedder = transformer.time_text_embed
    if not hasattr(embedder, '_bfg_orig_forward'):
        return
    embedder.forward = embedder._bfg_orig_forward
    del embedder._bfg_orig_forward
|
||||
|
||||
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||
"""
|
||||
@@ -60,6 +86,7 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int
|
||||
|
||||
is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
|
||||
is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
|
||||
# is_schnell = True
|
||||
|
||||
# check number of double and single blocks
|
||||
if not is_diffusers:
|
||||
|
||||
Reference in New Issue
Block a user