update FLUX LoRA training

This commit is contained in:
Kohya S
2024-08-10 23:42:05 +09:00
parent 358f13f2c9
commit 8a0f12dde8
7 changed files with 148 additions and 39 deletions

View File

@@ -226,6 +226,12 @@ class NetworkTrainer:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
return loss
def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
def update_metadata(self, metadata, args):
pass
# endregion
def train(self, args):
@@ -521,10 +527,13 @@ class NetworkTrainer:
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn
unet.to(accelerator.device) # this makes faster `to(dtype)` below
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
unet.requires_grad_(False)
unet.to(dtype=unet_weight_dtype) # this takes long time and large memory
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
@@ -718,8 +727,11 @@ class NetworkTrainer:
"ss_loss_type": args.loss_type,
"ss_huber_schedule": args.huber_schedule,
"ss_huber_c": args.huber_c,
"ss_fp8_base": args.fp8_base,
}
self.update_metadata(metadata, args) # architecture specific metadata
if use_user_config:
# save metadata of multiple datasets
# NOTE: pack "ss_datasets" value as json one time
@@ -964,7 +976,7 @@ class NetworkTrainer:
metadata["ss_epoch"] = str(epoch_no)
metadata_to_save = minimum_metadata if args.no_metadata else metadata
sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
sai_metadata = self.get_sai_model_spec(args)
metadata_to_save.update(sai_metadata)
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)