diff --git a/fine_tune.py b/fine_tune.py
index 08afedc2..b82a67ae 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -250,36 +250,32 @@ def train(args):
         unet.to(weight_dtype)
         text_encoder.to(weight_dtype)
 
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
     if args.deepspeed:
         if args.train_text_encoder:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        if args.optimizer_type.lower().endswith("schedulefree"):
-            ds_model, optimizer, train_dataloader = accelerator.prepare(
-                ds_model, optimizer, train_dataloader
-            )
-        else:
-            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                ds_model, optimizer, train_dataloader, lr_scheduler
-            )
+        ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
         training_models = [ds_model]
     else:
         # acceleratorがなんかよろしくやってくれるらしい
         if args.train_text_encoder:
-            if args.optimizer_type.lower().endswith("schedulefree"):
-                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
-                    unet, text_encoder, optimizer, train_dataloader
-                )
-            else:
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-                )
+            unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
         else:
-            if args.optimizer_type.lower().endswith("schedulefree"):
-                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
-            else:
-                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする
     if args.full_fp16:
@@ -337,8 +333,7 @@ def train(args):
             m.train()
 
         for step, batch in enumerate(train_dataloader):
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.train()
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(*training_models):
                 with torch.no_grad():
@@ -369,7 +364,9 @@ def train(args):
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -383,7 +380,9 @@ def train(args):
                 if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -395,7 +394,9 @@ def train(args): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -405,12 +406,10 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - if not args.optimizer_type.lower().endswith("schedulefree"): - lr_scheduler.step() + lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op optimizer.zero_grad(set_to_none=True) - if (args.optimizer_type.lower().endswith("schedulefree")): - optimizer.eval() + optimizer_eval_if_needed() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -490,7 +489,7 @@ def train(args): accelerator.end_training() - if is_main_process and (args.save_state or args.save_state_on_train_end): + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/library/train_util.py b/library/train_util.py index 1c04fbfe..31e37bf7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4087,17 +4087,17 @@ def get_optimizer(args, trainable_params): logger.info(f"use AdamW optimizer | {optimizer_kwargs}") optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) - + elif optimizer_type.endswith("schedulefree".lower()): try: import schedulefree as sf except ImportError: raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") if optimizer_type == "AdamWScheduleFree".lower(): - optimizer_class = sf.AdamWScheduleFree - logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") elif optimizer_type == "SGDScheduleFree".lower(): - optimizer_class = sf.SGDScheduleFree + optimizer_class = sf.SGDScheduleFree logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") else: raise ValueError(f"Unknown optimizer type: {optimizer_type}") @@ -4131,6 +4131,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ Unified API to get any scheduler from its name. 
""" + # supports schedule free optimizer + if args.optimizer_type.lower().endswith("schedulefree"): + # return dummy scheduler: it has 'step' method but does nothing + logger.info("use dummy scheduler for schedule free optimizer / schedule free optimizer用のダミースケジューラを使用します") + lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT](optimizer) + lr_scheduler.step = lambda: None + return lr_scheduler + name = args.lr_scheduler num_warmup_steps: Optional[int] = args.lr_warmup_steps num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps @@ -4265,7 +4273,7 @@ def load_tokenizer(args: argparse.Namespace): return tokenizer -def prepare_accelerator(args: argparse.Namespace): +def prepare_accelerator(args: argparse.Namespace) -> Accelerator: """ this function also prepares deepspeed plugin """ diff --git a/requirements.txt b/requirements.txt index e99775b8..e5fee6cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -accelerate==0.25.0 +accelerate==0.29.2 transformers==4.36.2 diffusers[torch]==0.25.0 ftfy==6.1.1 diff --git a/sdxl_train.py b/sdxl_train.py index a78687fd..8944f3a0 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -407,6 +407,7 @@ def train(args): text_encoder1.text_model.encoder.layers[-1].requires_grad_(False) text_encoder1.text_model.final_layer_norm.requires_grad_(False) + use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree") if args.deepspeed: ds_model = deepspeed_utils.prepare_deepspeed_model( args, @@ -415,14 +416,9 @@ def train(args): text_encoder2=text_encoder2 if train_text_encoder2 else None, ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 - if args.optimizer_type.lower().endswith("schedulefree"): - ds_model, optimizer, train_dataloader = accelerator.prepare( - ds_model, optimizer, train_dataloader - ) - else: - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - ds_model, optimizer, train_dataloader, lr_scheduler - ) + ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader) + if not use_schedule_free_optimizer: + lr_scheduler = accelerator.prepare(lr_scheduler) training_models = [ds_model] else: @@ -433,10 +429,17 @@ def train(args): text_encoder1 = accelerator.prepare(text_encoder1) if train_text_encoder2: text_encoder2 = accelerator.prepare(text_encoder2) - if args.optimizer_type.lower().endswith("schedulefree"): - optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader) - else: - optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + if not use_schedule_free_optimizer: + lr_scheduler = accelerator.prepare(lr_scheduler) + optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader) + + # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used + if use_schedule_free_optimizer: + optimizer_train_if_needed = lambda: optimizer.train() + optimizer_eval_if_needed = lambda: optimizer.eval() + else: + optimizer_train_if_needed = lambda: None + optimizer_eval_if_needed = lambda: None # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: @@ -511,8 +514,7 @@ def train(args): m.train() for step, batch in enumerate(train_dataloader): - if (args.optimizer_type.lower().endswith("schedulefree")): - optimizer.train() + optimizer_train_if_needed() current_step.value 
             with accelerator.accumulate(*training_models):
                 if "latents" in batch and batch["latents"] is not None:
@@ -592,7 +594,9 @@ def train(args):
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
@@ -610,7 +614,9 @@ def train(args):
                     or args.masked_loss
                 ):
                     # do not mean over batch dimension for snr weight or scale v-pred loss
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -626,7 +632,9 @@ def train(args):
                     loss = loss.mean()  # mean over batch dimension
                 else:
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
+                    )
                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -636,12 +644,10 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
-                if not args.optimizer_type.lower().endswith("schedulefree"):
-                    lr_scheduler.step()
+                lr_scheduler.step()  # if schedule-free optimizer is used, this is a no-op
                 optimizer.zero_grad(set_to_none=True)
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.eval()
+            optimizer_eval_if_needed()
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -750,7 +756,7 @@ def train(args):
     accelerator.end_training()
 
-    if args.save_state or args.save_state_on_train_end:
+    if args.save_state or args.save_state_on_train_end:
         train_util.save_state_on_train_end(args, accelerator)
 
     del accelerator  # この後メモリを使うのでこれは消す
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index d2b578c9..6e0c2c8a 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -15,6 +15,7 @@ from tqdm import tqdm
 import torch
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -286,19 +287,22 @@ def train(args):
         unet.to(weight_dtype)
     # acceleratorがなんかよろしくやってくれるらしい
-    if args.optimizer_type.lower().endswith("schedulefree"):
-        unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
+    unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
+    if not use_schedule_free_optimizer:
+        lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
     else:
-        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     if args.gradient_checkpointing:
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.train()
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
-
     else:
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.eval()
         unet.eval()
 
     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
@@ -398,8 +402,7 @@ def train(args):
         current_epoch.value = epoch + 1
         for step, batch in enumerate(train_dataloader):
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.train()
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(unet):
                 with torch.no_grad():
@@ -449,7 +452,9 @@ def train(args):
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
                 noisy_latents = noisy_latents.to(weight_dtype)  # TODO check why noisy_latents is not weight_dtype
@@ -468,7 +473,9 @@ def train(args):
                 else:
                     target = noise
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 loss = loss.mean([1, 2, 3])
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
@@ -491,12 +498,10 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
-                if not args.optimizer_type.lower().endswith("schedulefree"):
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.eval()
+            optimizer_eval_if_needed()
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index ee0ef930..8585df13 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -254,24 +254,27 @@ def train(args):
         network.to(weight_dtype)
     # acceleratorがなんかよろしくやってくれるらしい
-    if args.optimizer_type.lower().endswith("schedulefree"):
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
     unet, network, optimizer, train_dataloader = accelerator.prepare(
         unet, network, optimizer, train_dataloader
     )
+    if not use_schedule_free_optimizer:
+        lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
     else:
-        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet, network, optimizer, train_dataloader, lr_scheduler
-        )
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
+
     network: control_net_lllite.ControlNetLLLite
     if args.gradient_checkpointing:
         unet.train()  # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.train()
     else:
         unet.eval()
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.eval()
 
     network.prepare_grad_etc()
@@ -366,8 +369,7 @@ def train(args):
         network.on_epoch_start()  # train()
         for step, batch in enumerate(train_dataloader):
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.train()
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(network):
                 with torch.no_grad():
@@ -460,12 +462,10 @@ def train(args):
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
-                if not args.optimizer_type.lower().endswith("schedulefree"):
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.eval()
+            optimizer_eval_if_needed()
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/train_controlnet.py b/train_controlnet.py
index 38cfb7f2..1785607b 100644
--- a/train_controlnet.py
+++ b/train_controlnet.py
@@ -13,6 +13,7 @@ from tqdm import tqdm
 import torch
 from library import deepspeed_utils
 from library.device_utils import init_ipex, clean_memory_on_device
+
 init_ipex()
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -226,7 +227,7 @@ def train(args):
         )
         vae.to("cpu")
         clean_memory_on_device(accelerator.device)
-
+
     accelerator.wait_for_everyone()
     if args.gradient_checkpointing:
@@ -276,14 +277,18 @@ def train(args):
         controlnet.to(weight_dtype)
     # acceleratorがなんかよろしくやってくれるらしい
-    if args.optimizer_type.lower().endswith("schedulefree"):
-        controlnet, optimizer, train_dataloader = accelerator.prepare(
-            controlnet, optimizer, train_dataloader
-        )
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
+    controlnet, optimizer, train_dataloader = accelerator.prepare(controlnet, optimizer, train_dataloader)
+    if not use_schedule_free_optimizer:
+        lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
     else:
-        controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            controlnet, optimizer, train_dataloader, lr_scheduler
-        )
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     unet.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -398,8 +403,7 @@ def train(args):
         current_epoch.value = epoch + 1
         for step, batch in enumerate(train_dataloader):
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.train()
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(controlnet):
                 with torch.no_grad():
@@ -427,7 +431,9 @@ def train(args):
                 )
                 # Sample a random timestep for each image
-                timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device)
+                timesteps, huber_c = train_util.get_timesteps_and_huber_c(
+                    args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device
+                )
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -459,7 +465,9 @@ def train(args):
                 else:
                     target = noise
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 loss = loss.mean([1, 2, 3])
                 loss_weights = batch["loss_weights"]  # 各sampleごとのweight
@@ -479,8 +487,7 @@ def train(args):
                 lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.eval()
+            optimizer_eval_if_needed()
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/train_db.py b/train_db.py
index 6a946aaa..36ed867a 100644
--- a/train_db.py
+++ b/train_db.py
@@ -224,38 +224,36 @@ def train(args):
         text_encoder.to(weight_dtype)
 
     # acceleratorがなんかよろしくやってくれるらしい
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
     if args.deepspeed:
         if args.train_text_encoder:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
         else:
             ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
-        if args.optimizer_type.lower().endswith("schedulefree"):
-            ds_model, optimizer, train_dataloader = accelerator.prepare(
-                ds_model, optimizer, train_dataloader
-            )
-        else:
-            ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                ds_model, optimizer, train_dataloader, lr_scheduler
-            )
+        ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
         training_models = [ds_model]
     else:
         if train_text_encoder:
-            if args.optimizer_type.lower().endswith("schedulefree"):
-                unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
-                    unet, text_encoder, optimizer, train_dataloader
-                )
-            else:
-                unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    unet, text_encoder, optimizer, train_dataloader, lr_scheduler
-                )
+            unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
+                unet, text_encoder, optimizer, train_dataloader
+            )
             training_models = [unet, text_encoder]
         else:
-            if args.optimizer_type.lower().endswith("schedulefree"):
-                unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
-            else:
-                unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+            unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
             training_models = [unet]
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
+    else:
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     if not train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)  # to avoid 'cpu' vs 'cuda' error
@@ -320,8 +318,7 @@ def train(args):
             text_encoder.train()
 
         for step, batch in enumerate(train_dataloader):
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.train()
+            optimizer_train_if_needed()
             current_step.value = global_step
             # 指定したステップ数でText Encoderの学習を止める
             if global_step == args.stop_text_encoder_training:
@@ -361,7 +358,9 @@ def train(args):
                 # Sample noise, sample a random timestep for each image, and add noise to the latents,
                 # with noise offset and/or multires noise if specified
-                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
+                noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
+                    args, noise_scheduler, latents
+                )
                 # Predict the noise residual
                 with accelerator.autocast():
@@ -373,7 +372,9 @@ def train(args):
                 else:
                     target = noise
-                loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                loss = train_util.conditional_loss(
+                    noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                )
                 if args.masked_loss:
                     loss = apply_masked_loss(loss, batch)
                 loss = loss.mean([1, 2, 3])
@@ -399,12 +400,10 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
-                if not args.optimizer_type.lower().endswith("schedulefree"):
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.eval()
+            optimizer_eval_if_needed()
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
diff --git a/train_network.py b/train_network.py
index 22b6509c..68341378 100644
--- a/train_network.py
+++ b/train_network.py
@@ -412,6 +412,7 @@ class NetworkTrainer:
             t_enc.text_model.embeddings.to(dtype=(weight_dtype if te_weight_dtype != weight_dtype else te_weight_dtype))
 
         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
+        use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
         if args.deepspeed:
             ds_model = deepspeed_utils.prepare_deepspeed_model(
                 args,
@@ -420,14 +421,9 @@
                 unet=unet if train_unet else None,
                 text_encoder1=text_encoders[0] if train_text_encoder else None,
                 text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
                 network=network,
             )
-            if args.optimizer_type.lower().endswith("schedulefree"):
-                ds_model, optimizer, train_dataloader = accelerator.prepare(
-                    ds_model, optimizer, train_dataloader
-                )
-            else:
-                ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    ds_model, optimizer, train_dataloader, lr_scheduler
-                )
+            ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
+            if not use_schedule_free_optimizer:
+                lr_scheduler = accelerator.prepare(lr_scheduler)
             training_model = ds_model
         else:
             if train_unet:
@@ -442,21 +438,22 @@ class NetworkTrainer:
                     text_encoders = [text_encoder]
             else:
                 pass  # if text_encoder is not trained, no need to prepare. and device and dtype are already set
-
-            if args.optimizer_type.lower().endswith("schedulefree"):
-                network, optimizer, train_dataloader = accelerator.prepare(
-                    network, optimizer, train_dataloader
-                )
-            else:
-                network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    network, optimizer, train_dataloader, lr_scheduler
-                )
+
+            network, optimizer, train_dataloader = accelerator.prepare(network, optimizer, train_dataloader)
+            if not use_schedule_free_optimizer:
+                lr_scheduler = accelerator.prepare(lr_scheduler)
             training_model = network
 
+        # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+        if use_schedule_free_optimizer:
+            optimizer_train_if_needed = lambda: optimizer.train()
+            optimizer_eval_if_needed = lambda: optimizer.eval()
+        else:
+            optimizer_train_if_needed = lambda: None
+            optimizer_eval_if_needed = lambda: None
+
         if args.gradient_checkpointing:
             # according to TI example in Diffusers, train is required
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.train()
             unet.train()
 
             for t_enc in text_encoders:
@@ -467,8 +464,6 @@ class NetworkTrainer:
                     t_enc.text_model.embeddings.requires_grad_(True)
         else:
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.eval()
             unet.eval()
             for t_enc in text_encoders:
                 t_enc.eval()
@@ -819,8 +814,7 @@ class NetworkTrainer:
             accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet)
 
             for step, batch in enumerate(train_dataloader):
-                if (args.optimizer_type.lower().endswith("schedulefree")):
-                    optimizer.train()
+                optimizer_train_if_needed()
                 current_step.value = global_step
                 with accelerator.accumulate(training_model):
                     on_step_start(text_encoder, unet)
@@ -926,8 +920,7 @@ class NetworkTrainer:
                         accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                     optimizer.step()
-                    if not args.optimizer_type.lower().endswith("schedulefree"):
-                        lr_scheduler.step()
+                    lr_scheduler.step()
                     optimizer.zero_grad(set_to_none=True)
                 if args.scale_weight_norms:
@@ -938,8 +931,7 @@ class NetworkTrainer:
                 else:
                     keys_scaled, mean_norm, maximum_norm = None, None, None
-                if (args.optimizer_type.lower().endswith("schedulefree")):
-                    optimizer.eval()
+                optimizer_eval_if_needed()
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index fa1c24e0..de93273b 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -415,30 +415,28 @@ class TextualInversionTrainer:
         lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
         # acceleratorがなんかよろしくやってくれるらしい
+        use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
         if len(text_encoders) == 1:
-            if args.optimizer_type.lower().endswith("schedulefree"):
-                text_encoder_or_list, optimizer, train_dataloader = accelerator.preparet(
-                    text_encoder_or_list, optimizer, train_dataloader
-                )
-            else:
-                text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.preparet(
-                    text_encoder_or_list, optimizer, train_dataloader, lr_scheduler
-                )
-
+            text_encoder_or_list, optimizer, train_dataloader = accelerator.prepare(
+                text_encoder_or_list, optimizer, train_dataloader
+            )
         elif len(text_encoders) == 2:
-            if args.optimizer_type.lower().endswith("schedulefree"):
-                text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
-                    text_encoders[0], text_encoders[1], optimizer, train_dataloader
-                )
-            else:
-                text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-                    text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler
-                )
-
+            text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare(
+                text_encoders[0], text_encoders[1], optimizer, train_dataloader
+            )
             text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2]
-
         else:
             raise NotImplementedError()
 
+        if not use_schedule_free_optimizer:
+            lr_scheduler = accelerator.prepare(lr_scheduler)
+
+        # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+        if use_schedule_free_optimizer:
+            optimizer_train_if_needed = lambda: optimizer.train()
+            optimizer_eval_if_needed = lambda: optimizer.eval()
+        else:
+            optimizer_train_if_needed = lambda: None
+            optimizer_eval_if_needed = lambda: None
 
         index_no_updates_list = []
         orig_embeds_params_list = []
@@ -462,12 +460,8 @@ class TextualInversionTrainer:
         unet.to(accelerator.device, dtype=weight_dtype)
         if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
             # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.train()
             unet.train()
         else:
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.eval()
             unet.eval()
 
         if not cache_latents:  # キャッシュしない場合はVAEを使うのでVAEを準備する
@@ -571,8 +565,7 @@ class TextualInversionTrainer:
             loss_total = 0
 
             for step, batch in enumerate(train_dataloader):
-                if (args.optimizer_type.lower().endswith("schedulefree")):
-                    optimizer.train()
+                optimizer_train_if_needed()
                 current_step.value = global_step
                 with accelerator.accumulate(text_encoders[0]):
                     with torch.no_grad():
@@ -604,7 +597,9 @@ class TextualInversionTrainer:
                     else:
                         target = noise
-                    loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
+                    loss = train_util.conditional_loss(
+                        noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
+                    )
                     if args.masked_loss:
                         loss = apply_masked_loss(loss, batch)
                     loss = loss.mean([1, 2, 3])
@@ -643,8 +638,7 @@ class TextualInversionTrainer:
                         index_no_updates
                     ]
-                if (args.optimizer_type.lower().endswith("schedulefree")):
-                    optimizer.eval()
+                optimizer_eval_if_needed()
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 4e95197f..cb38e798 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -335,14 +335,20 @@ def train(args):
     lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
 
     # acceleratorがなんかよろしくやってくれるらしい
-    if args.optimizer_type.lower().endswith("schedulefree"):
-        text_encoder, optimizer, train_dataloader = accelerator.prepare(
-            text_encoder, optimizer, train_dataloader
-        )
+    use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
+    text_encoder, optimizer, train_dataloader = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader
+    )
+    if not use_schedule_free_optimizer:
+        lr_scheduler = accelerator.prepare(lr_scheduler)
+
+    # make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
+    if use_schedule_free_optimizer:
+        optimizer_train_if_needed = lambda: optimizer.train()
+        optimizer_eval_if_needed = lambda: optimizer.eval()
     else:
-        text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            text_encoder, optimizer, train_dataloader, lr_scheduler
-        )
+        optimizer_train_if_needed = lambda: None
+        optimizer_eval_if_needed = lambda: None
 
     index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0]
     # logger.info(len(index_no_updates), torch.sum(index_no_updates))
@@ -359,12 +365,8 @@ def train(args):
     unet.to(accelerator.device, dtype=weight_dtype)
     if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
         unet.train()
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.train()
     else:
         unet.eval()
-        if (args.optimizer_type.lower().endswith("schedulefree")):
-            optimizer.eval()
 
     if not cache_latents:
         vae.requires_grad_(False)
@@ -447,8 +449,7 @@ def train(args):
         loss_total = 0
 
         for step, batch in enumerate(train_dataloader):
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.train()
+            optimizer_train_if_needed()
             current_step.value = global_step
             with accelerator.accumulate(text_encoder):
                 with torch.no_grad():
@@ -507,8 +508,7 @@ def train(args):
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
-                if not args.optimizer_type.lower().endswith("schedulefree"):
-                    lr_scheduler.step()
+                lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
                 # Let's make sure we don't update any embedding weights besides the newly added token
@@ -517,8 +517,7 @@ def train(args):
                         index_no_updates
                     ]
-            if (args.optimizer_type.lower().endswith("schedulefree")):
-                optimizer.eval()
+            optimizer_eval_if_needed()
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
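Note (illustrative addition, not part of the patch): every file above applies the same pattern — detect a schedule-free optimizer once, skip preparing/stepping the LR scheduler for it (or hand back a dummy scheduler whose step() is a no-op), and route optimizer.train()/optimizer.eval() through optimizer_train_if_needed()/optimizer_eval_if_needed(). A minimal sketch of that calling pattern, assuming the schedulefree package and a toy model/data (the loop below is hypothetical, not kohya-ss code):

# toy sketch of the schedule-free calling convention used by the patch
import torch
import schedulefree

model = torch.nn.Linear(4, 1)
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
data = [torch.randn(8, 4) for _ in range(10)]

for batch in data:
    optimizer.train()   # what optimizer_train_if_needed() does before each training step
    loss = model(batch).pow(2).mean()
    loss.backward()
    optimizer.step()    # no lr_scheduler.step(): the schedule is handled inside the optimizer
    optimizer.zero_grad(set_to_none=True)
    optimizer.eval()    # what optimizer_eval_if_needed() does before sampling/validation/saving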