mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
load Diffusers format, check schnell/dev
This commit is contained in:
@@ -419,9 +419,6 @@ if __name__ == "__main__":
|
||||
steps = args.steps
|
||||
guidance_scale = args.guidance
|
||||
|
||||
name = "schnell" if "schnell" in args.ckpt_path else "dev" # TODO change this to a more robust way
|
||||
is_schnell = name == "schnell"
|
||||
|
||||
def is_fp8(dt):
|
||||
return dt in [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
|
||||
|
||||
@@ -455,12 +452,8 @@ if __name__ == "__main__":
|
||||
# if is_fp8(t5xxl_dtype):
|
||||
# t5xxl = accelerator.prepare(t5xxl)
|
||||
|
||||
t5xxl_max_length = 256 if is_schnell else 512
|
||||
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
|
||||
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
|
||||
|
||||
# DiT
|
||||
model = flux_utils.load_flow_model(name, args.ckpt_path, None, loading_device)
|
||||
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
|
||||
model.eval()
|
||||
logger.info(f"Casting model to {flux_dtype}")
|
||||
model.to(flux_dtype) # make sure model is dtype
|
||||
@@ -469,8 +462,12 @@ if __name__ == "__main__":
|
||||
# if args.offload:
|
||||
# model = model.to("cpu")
|
||||
|
||||
t5xxl_max_length = 256 if is_schnell else 512
|
||||
tokenize_strategy = strategy_flux.FluxTokenizeStrategy(t5xxl_max_length)
|
||||
encoding_strategy = strategy_flux.FluxTextEncodingStrategy()
|
||||
|
||||
# AE
|
||||
ae = flux_utils.load_ae(name, args.ae, ae_dtype, loading_device)
|
||||
ae = flux_utils.load_ae(args.ae, ae_dtype, loading_device)
|
||||
ae.eval()
|
||||
# if is_fp8(ae_dtype):
|
||||
# ae = accelerator.prepare(ae)
|
||||
|
||||
Reference in New Issue
Block a user