This commit is contained in:
minux302
2024-11-21 15:55:27 +00:00
parent 31ca899b6b
commit 0b5229a955
2 changed files with 15 additions and 20 deletions

View File

@@ -266,7 +266,7 @@ def train(args):
flux.to(accelerator.device)
# load controlnet
controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors)
controlnet.train()
if args.gradient_checkpointing:
@@ -613,9 +613,6 @@ def train(args):
text_encoder_conds = text_encoding_strategy.encode_tokens(
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
)
# if args.full_fp16:
# text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# TODO: check
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
@@ -733,7 +730,7 @@ def train(args):
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(flux),
accelerator.unwrap_model(controlnet),
)
optimizer_train_fn()
@@ -759,19 +756,18 @@ def train(args):
accelerator.wait_for_everyone()
optimizer_eval_fn()
# TODO: save cn models
# if args.save_every_n_epochs is not None:
# if accelerator.is_main_process:
# flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
# args,
# True,
# accelerator,
# save_dtype,
# epoch,
# num_train_epochs,
# global_step,
# accelerator.unwrap_model(flux),
# )
if args.save_every_n_epochs is not None:
if accelerator.is_main_process:
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
args,
True,
accelerator,
save_dtype,
epoch,
num_train_epochs,
global_step,
accelerator.unwrap_model(controlnet),
)
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
@@ -791,7 +787,7 @@ def train(args):
del accelerator # この後メモリを使うのでこれは消す
if is_main_process:
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux)
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet)
logger.info("model saved.")