mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 00:49:40 +00:00
Merge 99338a204f into 3e6935a07e
This commit is contained in:
@@ -20,6 +20,8 @@ from library import device_utils
|
||||
from library.device_utils import init_ipex, get_preferred_device
|
||||
from networks import oft_flux
|
||||
|
||||
from library.flux_utils import bypass_flux_guidance, restore_flux_guidance
|
||||
|
||||
init_ipex()
|
||||
|
||||
|
||||
@@ -151,6 +153,10 @@ def do_sample(
|
||||
logger.info(f"num_steps: {num_steps}")
|
||||
timesteps = get_schedule(num_steps, img.shape[1], shift=not is_schnell)
|
||||
|
||||
# bypass guidance module
|
||||
if args.bypass_flux_guidance:
|
||||
bypass_flux_guidance(model)
|
||||
|
||||
# denoise initial noise
|
||||
if accelerator:
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
@@ -364,6 +370,10 @@ def generate_image(
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
img = Image.fromarray((127.5 * (x + 1.0)).float().cpu().numpy().astype(np.uint8)[0])
|
||||
|
||||
# restore guidance module
|
||||
if args.bypass_flux_guidance:
|
||||
restore_flux_guidance(model)
|
||||
|
||||
# save image
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
@@ -643,6 +643,9 @@ def train(args):
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
if not args.apply_t5_attn_mask:
|
||||
t5_attn_mask = None
|
||||
|
||||
if args.bypass_flux_guidance:
|
||||
flux_utils.bypass_flux_guidance(flux)
|
||||
|
||||
with accelerator.autocast():
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
@@ -660,6 +663,9 @@ def train(args):
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
|
||||
if args.bypass_flux_guidance:
|
||||
flux_utils.restore_flux_guidance(flux)
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
|
||||
@@ -395,6 +395,8 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
return model_pred
|
||||
if args.bypass_flux_guidance:
|
||||
flux_utils.bypass_flux_guidance(unet)
|
||||
|
||||
model_pred = call_dit(
|
||||
img=packed_noisy_model_input,
|
||||
@@ -409,6 +411,9 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
# unpack latents
|
||||
model_pred = flux_utils.unpack_latents(model_pred, packed_latent_height, packed_latent_width)
|
||||
|
||||
if args.bypass_flux_guidance:
|
||||
flux_utils.restore_flux_guidance(unet)
|
||||
|
||||
# apply model prediction type
|
||||
model_pred, weighting = flux_train_utils.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)
|
||||
|
||||
@@ -680,3 +680,10 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
default=3.0,
|
||||
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
|
||||
)
|
||||
|
||||
# bypass guidance module for flux
|
||||
parser.add_argument(
|
||||
"--bypass_flux_guidance"
|
||||
, action="store_true"
|
||||
, help="bypass flux guidance module for Flex.1-Alpha Training / Flex.1-Alpha トレーニング用バイパス フラックス ガイダンス モジュール"
|
||||
)
|
||||
|
||||
@@ -24,6 +24,13 @@ MODEL_VERSION_FLUX_V1 = "flux1"
|
||||
MODEL_NAME_DEV = "dev"
|
||||
MODEL_NAME_SCHNELL = "schnell"
|
||||
|
||||
# bypass guidance
|
||||
def bypass_flux_guidance(transformer):
    """Turn off the guidance embedding on *transformer*.

    Clears ``transformer.params.guidance_embed`` so the model skips its
    guidance module (used for Flex.1-Alpha training per the CLI help text).
    Re-enable afterwards with :func:`restore_flux_guidance`.
    """
    setattr(transformer.params, "guidance_embed", False)
|
||||
|
||||
# restore the guidance embedding flag
|
||||
def restore_flux_guidance(transformer):
    """Turn the guidance embedding back on for *transformer*.

    Sets ``transformer.params.guidance_embed`` to ``True``, undoing a prior
    :func:`bypass_flux_guidance` call.
    """
    setattr(transformer.params, "guidance_embed", True)
|
||||
|
||||
def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user