mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
save cn
This commit is contained in:
@@ -266,7 +266,7 @@ def train(args):
|
||||
flux.to(accelerator.device)
|
||||
|
||||
# load controlnet
|
||||
controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
controlnet = flux_utils.load_controlnet(args.controlnet, torch.float32, "cpu", args.disable_mmap_load_safetensors)
|
||||
controlnet.train()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -613,9 +613,6 @@ def train(args):
|
||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
|
||||
)
|
||||
# if args.full_fp16:
|
||||
# text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
# TODO: check
|
||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
|
||||
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
||||
@@ -733,7 +730,7 @@ def train(args):
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(flux),
|
||||
accelerator.unwrap_model(controlnet),
|
||||
)
|
||||
optimizer_train_fn()
|
||||
|
||||
@@ -759,19 +756,18 @@ def train(args):
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
optimizer_eval_fn()
|
||||
# TODO: save cn models
|
||||
# if args.save_every_n_epochs is not None:
|
||||
# if accelerator.is_main_process:
|
||||
# flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||
# args,
|
||||
# True,
|
||||
# accelerator,
|
||||
# save_dtype,
|
||||
# epoch,
|
||||
# num_train_epochs,
|
||||
# global_step,
|
||||
# accelerator.unwrap_model(flux),
|
||||
# )
|
||||
if args.save_every_n_epochs is not None:
|
||||
if accelerator.is_main_process:
|
||||
flux_train_utils.save_flux_model_on_epoch_end_or_stepwise(
|
||||
args,
|
||||
True,
|
||||
accelerator,
|
||||
save_dtype,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
accelerator.unwrap_model(controlnet),
|
||||
)
|
||||
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
|
||||
@@ -791,7 +787,7 @@ def train(args):
|
||||
del accelerator # この後メモリを使うのでこれは消す
|
||||
|
||||
if is_main_process:
|
||||
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, flux)
|
||||
flux_train_utils.save_flux_model_on_train_end(args, save_dtype, epoch, global_step, controlnet)
|
||||
logger.info("model saved.")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user