Mirror of https://github.com/kohya-ss/sd-scripts.git
Fix bug in FLUX multi GPU training
@@ -57,19 +57,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         name = self.get_flux_model_name(args)

         # if we load to cpu, flux.to(fp8) takes a long time
-        model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
+        model = flux_utils.load_flow_model(
+            name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
+        )

         if args.split_mode:
             model = self.prepare_split_model(model, weight_dtype, accelerator)

-        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
+        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         clip_l.eval()

         # loading t5xxl to cpu takes a long time, so we should load to gpu in future
-        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
+        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         t5xxl.eval()

-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
+        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
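For context, disabling mmap makes the safetensors load read the whole checkpoint file into memory instead of memory-mapping it. A minimal sketch of what such a loader can look like, using a hypothetical helper name load_safetensors_state_dict (the repository's actual helper and flag plumbing may differ):

    from safetensors.torch import load, load_file

    def load_safetensors_state_dict(path: str, device: str = "cpu", disable_mmap: bool = False):
        # Hypothetical helper for illustration; not the repository's exact code.
        if disable_mmap:
            # Read the entire file into memory and deserialize it there,
            # skipping the memory-mapped code path.
            with open(path, "rb") as f:
                state_dict = load(f.read())
            return {k: v.to(device) for k, v in state_dict.items()}
        # Default: memory-mapped load directly onto the target device.
        return load_file(path, device=device)

Each load_* call in the diff above forwards args.disable_mmap_load_safetensors so that every component (flow model, CLIP-L, T5-XXL, and the autoencoder) honors the same loading mode.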