mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Update FLUX.1 support for compact models
This commit is contained in:
@@ -137,7 +137,7 @@ def train(args):
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
|
||||
|
||||
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
|
||||
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
|
||||
if args.debug_dataset:
|
||||
if args.cache_text_encoder_outputs:
|
||||
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
|
||||
@@ -181,7 +181,7 @@ def train(args):
|
||||
# load VAE for caching latents
|
||||
ae = None
|
||||
if cache_latents:
|
||||
ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
ae.to(accelerator.device, dtype=weight_dtype)
|
||||
ae.requires_grad_(False)
|
||||
ae.eval()
|
||||
@@ -510,8 +510,8 @@ def train(args):
|
||||
library.adafactor_fused.patch_adafactor_fused(optimizer)
|
||||
|
||||
blocks_to_swap = args.blocks_to_swap
|
||||
num_double_blocks = 19 # len(flux.double_blocks)
|
||||
num_single_blocks = 38 # len(flux.single_blocks)
|
||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||
handled_unit_indices = set()
|
||||
|
||||
@@ -603,8 +603,8 @@ def train(args):
|
||||
parameter_optimizer_map = {}
|
||||
|
||||
blocks_to_swap = args.blocks_to_swap
|
||||
num_double_blocks = 19 # len(flux.double_blocks)
|
||||
num_single_blocks = 38 # len(flux.single_blocks)
|
||||
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
|
||||
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
|
||||
num_block_units = num_double_blocks + num_single_blocks // 2
|
||||
|
||||
n = 1 # only asynchronous purpose, no need to increase this number
|
||||
|
||||
Reference in New Issue
Block a user