load Diffusers format, check schnell/dev

This commit is contained in:
Kohya S
2024-10-06 21:32:21 +09:00
parent ba08a89894
commit 83e3048cb0
6 changed files with 196 additions and 111 deletions

View File

@@ -137,6 +137,7 @@ def train(args):
train_dataset_group.verify_bucket_reso_steps(16) # TODO これでいいか確認
_, is_schnell, _ = flux_utils.check_flux_state_dict_diffusers_schnell(args.pretrained_model_name_or_path)
if args.debug_dataset:
if args.cache_text_encoder_outputs:
strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
@@ -144,9 +145,8 @@ def train(args):
args.cache_text_encoder_outputs_to_disk, args.text_encoder_batch_size, False, False
)
)
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
t5xxl_max_token_length = (
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if name == "schnell" else 512)
args.t5xxl_max_token_length if args.t5xxl_max_token_length is not None else (256 if is_schnell else 512)
)
strategy_base.TokenizeStrategy.set_strategy(strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length))
@@ -177,12 +177,11 @@ def train(args):
weight_dtype, save_dtype = train_util.prepare_dtype(args)
# モデルを読み込む
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
# load VAE for caching latents
ae = None
if cache_latents:
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae = flux_utils.load_ae( args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
ae.to(accelerator.device, dtype=weight_dtype)
ae.requires_grad_(False)
ae.eval()
@@ -196,7 +195,7 @@ def train(args):
# prepare tokenize strategy
if args.t5xxl_max_token_length is None:
if name == "schnell":
if is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
@@ -258,8 +257,8 @@ def train(args):
clean_memory_on_device(accelerator.device)
# load FLUX
flux = flux_utils.load_flow_model(
name, args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
)
if args.gradient_checkpointing:
@@ -294,7 +293,7 @@ def train(args):
if not cache_latents:
# load VAE here if not cached
ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu")
ae.requires_grad_(False)
ae.eval()
ae.to(accelerator.device, dtype=weight_dtype)