mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
update FLUX LoRA training
This commit is contained in:
@@ -226,6 +226,12 @@ class NetworkTrainer:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
return loss
|
||||
|
||||
def get_sai_model_spec(self, args):
|
||||
return train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
|
||||
|
||||
def update_metadata(self, metadata, args):
|
||||
pass
|
||||
|
||||
# endregion
|
||||
|
||||
def train(self, args):
|
||||
@@ -521,10 +527,13 @@ class NetworkTrainer:
|
||||
unet_weight_dtype = torch.float8_e4m3fn
|
||||
te_weight_dtype = torch.float8_e4m3fn
|
||||
|
||||
unet.to(accelerator.device) # this makes faster `to(dtype)` below
|
||||
# unet.to(accelerator.device) # this makes faster `to(dtype)` below, but consumes 23 GB VRAM
|
||||
# unet.to(dtype=unet_weight_dtype) # without moving to gpu, this takes a lot of time and main memory
|
||||
|
||||
unet.to(accelerator.device, dtype=unet_weight_dtype) # this seems to be safer than above
|
||||
|
||||
unet.requires_grad_(False)
|
||||
unet.to(dtype=unet_weight_dtype) # this takes long time and large memory
|
||||
unet.to(dtype=unet_weight_dtype)
|
||||
for t_enc in text_encoders:
|
||||
t_enc.requires_grad_(False)
|
||||
|
||||
@@ -718,8 +727,11 @@ class NetworkTrainer:
|
||||
"ss_loss_type": args.loss_type,
|
||||
"ss_huber_schedule": args.huber_schedule,
|
||||
"ss_huber_c": args.huber_c,
|
||||
"ss_fp8_base": args.fp8_base,
|
||||
}
|
||||
|
||||
self.update_metadata(metadata, args) # architecture specific metadata
|
||||
|
||||
if use_user_config:
|
||||
# save metadata of multiple datasets
|
||||
# NOTE: pack "ss_datasets" value as json one time
|
||||
@@ -964,7 +976,7 @@ class NetworkTrainer:
|
||||
metadata["ss_epoch"] = str(epoch_no)
|
||||
|
||||
metadata_to_save = minimum_metadata if args.no_metadata else metadata
|
||||
sai_metadata = train_util.get_sai_model_spec(None, args, self.is_sdxl, True, False)
|
||||
sai_metadata = self.get_sai_model_spec(args)
|
||||
metadata_to_save.update(sai_metadata)
|
||||
|
||||
unwrapped_nw.save_weights(ckpt_file, save_dtype, metadata_to_save)
|
||||
|
||||
Reference in New Issue
Block a user