Merge branch 'sd3' into multi-gpu-caching

This commit is contained in:
kohya-ss
2024-10-13 11:52:42 +09:00
4 changed files with 82 additions and 18 deletions

View File

@@ -141,7 +141,7 @@ def train(args):
train_dataset_group.verify_bucket_reso_steps(16) # TODO: confirm that this value is OK
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
@@ -514,8 +514,8 @@ def train(args):
library.adafactor_fused.patch_adafactor_fused(optimizer)
blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
handled_unit_indices = set()
@@ -607,8 +607,8 @@ def train(args):
parameter_optimizer_map = {}
blocks_to_swap = args.blocks_to_swap
num_double_blocks = 19 # len(flux.double_blocks)
num_single_blocks = 38 # len(flux.single_blocks)
num_double_blocks = len(accelerator.unwrap_model(flux).double_blocks)
num_single_blocks = len(accelerator.unwrap_model(flux).single_blocks)
num_block_units = num_double_blocks + num_single_blocks // 2
n = 1 # only asynchronous purpose, no need to increase this number