fix flux fine tuning to work

This commit is contained in:
kohya-ss
2024-08-17 15:54:32 +09:00
parent 400955d3ea
commit 25f77f6ef0
2 changed files with 6 additions and 4 deletions

View File

@@ -674,9 +674,7 @@ def train(args):
# if is_main_process:
flux = accelerator.unwrap_model(flux)
clip_l = accelerator.unwrap_model(clip_l)
clip_g = accelerator.unwrap_model(clip_g)
if t5xxl is not None:
t5xxl = accelerator.unwrap_model(t5xxl)
t5xxl = accelerator.unwrap_model(t5xxl)
accelerator.end_training()
@@ -686,7 +684,7 @@ def train(args):
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux, ae)
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux)
logger.info("model saved.")