load Diffusers format, check schnell/dev

This commit is contained in:
Kohya S
2024-10-06 21:32:21 +09:00
parent ba08a89894
commit 83e3048cb0
6 changed files with 196 additions and 111 deletions

View File

@@ -419,9 +419,6 @@ if __name__ == "__main__":
steps = args.steps
guidance_scale = args.guidance
name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way
is_schnell = name == "schnell"
def is_fp8(dt):
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
@@ -455,12 +452,8 @@ if __name__ == "__main__":
# if is_fp8(t5xxl_dtype):
# t5xxl = accelerator.prepare(t5xxl)
t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# DiT
model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device)
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype
@@ -469,8 +462,12 @@ if __name__ == "__main__":
# if args.offload:
# model = model.to("cpu")
t5xxl_max_length = 256 if is_schnell else 512
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
# AE
ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
ae.eval()
# if is_fp8(ae_dtype):
# ae = accelerator.prepare(ae)