Update FLUX.1 support for compact models

This commit is contained in:
Kohya S
2024-10-12 21:49:07 +09:00
parent ecaea909b1
commit e277b5789e
4 changed files with 82 additions and 18 deletions

View File

@@ -137,7 +137,7 @@ def train(args):
train_dataset_group.verify_bucket_reso_steps(16) # TODO confirm that this is OK
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
@@ -181,7 +181,7 @@ def train(args):
# load VAE for caching latents
ae = None
if cache_latents:
ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
@@ -510,8 +510,8 @@ def train(args):
library.adafactor_fused.patch_adafactor_fused(optimizer)
blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set()
@@ -603,8 +603,8 @@ def train(args):
parameter_optimizer_map = {}
blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
n = 1 # only for asynchronous purposes; no need to increase this number