Fix bug in FLUX multi GPU training

kohya-ss
2024-08-22 12:37:41 +09:00
parent e1cd19c0c0
commit 98c91a7625
8 changed files with 156 additions and 38 deletions


@@ -57,19 +57,21 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         name = self.get_flux_model_name(args)

         # if we load to cpu, flux.to(fp8) takes a long time
-        model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
+        model = flux_utils.load_flow_model(
+            name, args.pretrained_model_name_or_path, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
+        )

         if args.split_mode:
             model = self.prepare_split_model(model, weight_dtype, accelerator)

-        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu")
+        clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         clip_l.eval()

         # loading t5xxl to cpu takes a long time, so we should load to gpu in future
-        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu")
+        t5xxl = flux_utils.load_t5xxl(args.t5xxl, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
         t5xxl.eval()

-        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu")
+        ae = flux_utils.load_ae(name, args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

         return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
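For context on what disable_mmap buys here: safetensors' load_file() memory-maps the checkpoint, and memory-mapped reads of the same large file from several ranks can stall multi-GPU startup on some filesystems. Below is a minimal sketch of how such a flag can be honored, assuming the flux_utils loaders delegate to a safetensors helper; the function load_safetensors is hypothetical and stands in for whatever the repository actually does, not a copy of it.

    import torch
    from safetensors.torch import load, load_file

    def load_safetensors(path: str, device: str, disable_mmap: bool = False) -> dict[str, torch.Tensor]:
        # Hypothetical helper: illustrates the disable_mmap behavior only.
        if disable_mmap:
            # Read the whole checkpoint into RAM and parse from the buffer,
            # avoiding the memory map that load_file() would create.
            with open(path, "rb") as f:
                state_dict = load(f.read())
            return {k: v.to(device) for k, v in state_dict.items()}
        # Default path: memory-mapped load directly onto the target device.
        return load_file(path, device=device)

The diff threads args.disable_mmap_load_safetensors through every loader (flow model, CLIP-L, T5-XXL, and the autoencoder), so a multi-GPU run can opt in once via the corresponding command-line flag instead of patching each call site.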