diff --git a/fine_tune.py b/fine_tune.py index c7e6bbd2..08afedc2 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -255,18 +255,31 @@ def train(args): ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) else: ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - ds_model, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + ds_model, optimizer, train_dataloader = accelerator.prepare( + ds_model, optimizer, train_dataloader + ) + else: + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) training_models = [ds_model] else: # acceleratorがなんかよろしくやってくれるらしい if args.train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + unet, text_encoder, optimizer, train_dataloader = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader + ) + else: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if args.optimizer_type.lower().endswith("schedulefree"): + unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする if args.full_fp16: @@ -324,6 +337,8 @@ def train(args): m.train() for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step with accelerator.accumulate(*training_models): with torch.no_grad(): @@ -390,9 +405,13 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - lr_scheduler.step() + if not args.optimizer_type.lower().endswith("schedulefree"): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3c..1c04fbfe 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3087,7 +3087,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( - "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" ) parser.add_argument( "--gradient_accumulation_steps", @@ -4087,6 +4087,21 @@ def get_optimizer(args, trainable_params): logger.info(f"use AdamW optimizer | {optimizer_kwargs}") optimizer_class = torch.optim.AdamW optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + + elif optimizer_type.endswith("schedulefree".lower()): + try: + import schedulefree as sf + except ImportError: + raise ImportError("No schedulefree / schedulefreeがインストールされていないようです") + if optimizer_type == "AdamWScheduleFree".lower(): + optimizer_class = sf.AdamWScheduleFree + logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}") + elif optimizer_type == "SGDScheduleFree".lower(): + optimizer_class = sf.SGDScheduleFree + logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}") + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) if optimizer is None: # 任意のoptimizerを使う diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860b..a78687fd 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -415,9 +415,14 @@ def train(args): text_encoder2=text_encoder2 if train_text_encoder2 else None, ) # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007 - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - ds_model, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + ds_model, optimizer, train_dataloader = accelerator.prepare( + ds_model, optimizer, train_dataloader + ) + else: + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) training_models = [ds_model] else: @@ -428,7 +433,10 @@ def train(args): text_encoder1 = accelerator.prepare(text_encoder1) if train_text_encoder2: text_encoder2 = accelerator.prepare(text_encoder2) - optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + if args.optimizer_type.lower().endswith("schedulefree"): + optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader) + else: + optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: @@ -503,6 +511,8 @@ def train(args): m.train() for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: @@ -626,9 +636,13 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - lr_scheduler.step() + if not args.optimizer_type.lower().endswith("schedulefree"): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f89c3628..d2b578c9 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -286,11 +286,19 @@ def train(args): unet.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if args.optimizer_type.lower().endswith("schedulefree"): + unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) if args.gradient_checkpointing: + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる + else: + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() unet.eval() # TextEncoderの出力をキャッシュするときにはCPUへ移動する @@ -390,6 +398,8 @@ def train(args): current_epoch.value = epoch + 1 for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step with accelerator.accumulate(unet): with torch.no_grad(): @@ -481,9 +491,13 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - lr_scheduler.step() + if not args.optimizer_type.lower().endswith("schedulefree"): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index e85e978c..ee0ef930 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -254,15 +254,24 @@ def train(args): network.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, network, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + unet, network, optimizer, train_dataloader = accelerator.prepare( + unet, network, optimizer, train_dataloader + ) + else: + unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, network, optimizer, train_dataloader, lr_scheduler + ) network: control_net_lllite.ControlNetLLLite if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() else: unet.eval() + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() network.prepare_grad_etc() @@ -357,6 +366,8 @@ def train(args): network.on_epoch_start() # train() for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step with accelerator.accumulate(network): with torch.no_grad(): @@ -449,9 +460,13 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - lr_scheduler.step() + if not args.optimizer_type.lower().endswith("schedulefree"): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) diff --git a/train_controlnet.py b/train_controlnet.py index f4c94e8d..38cfb7f2 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -276,9 +276,14 @@ def train(args): controlnet.to(weight_dtype) # acceleratorがなんかよろしくやってくれるらしい - controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - controlnet, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + controlnet, optimizer, train_dataloader = accelerator.prepare( + controlnet, optimizer, train_dataloader + ) + else: + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) unet.requires_grad_(False) text_encoder.requires_grad_(False) @@ -393,6 +398,8 @@ def train(args): current_epoch.value = epoch + 1 for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step with accelerator.accumulate(controlnet): with torch.no_grad(): @@ -472,6 +479,9 @@ def train(args): lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) diff --git a/train_db.py b/train_db.py index 1de504ed..6a946aaa 100644 --- a/train_db.py +++ b/train_db.py @@ -229,19 +229,32 @@ def train(args): ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder) else: ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - ds_model, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + ds_model, optimizer, train_dataloader = accelerator.prepare( + ds_model, optimizer, train_dataloader + ) + else: + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) training_models = [ds_model] else: if train_text_encoder: - unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, text_encoder, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + unet, text_encoder, optimizer, train_dataloader = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader + ) + else: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) training_models = [unet, text_encoder] else: - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if args.optimizer_type.lower().endswith("schedulefree"): + unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) training_models = [unet] if not train_text_encoder: @@ -307,6 +320,8 @@ def train(args): text_encoder.train() for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step # 指定したステップ数でText Encoderの学習を止める if global_step == args.stop_text_encoder_training: @@ -384,9 +399,13 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - lr_scheduler.step() + if not args.optimizer_type.lower().endswith("schedulefree"): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) diff --git a/train_network.py b/train_network.py index aad5a719..22b6509c 100644 --- a/train_network.py +++ b/train_network.py @@ -420,9 +420,14 @@ class NetworkTrainer: text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None, network=network, ) - ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - ds_model, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + ds_model, optimizer, train_dataloader = accelerator.prepare( + ds_model, optimizer, train_dataloader + ) + else: + ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + ds_model, optimizer, train_dataloader, lr_scheduler + ) training_model = ds_model else: if train_unet: @@ -437,15 +442,23 @@ class NetworkTrainer: text_encoders = [text_encoder] else: pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set - - network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - network, optimizer, train_dataloader, lr_scheduler - ) + + if args.optimizer_type.lower().endswith("schedulefree"): + network, optimizer, train_dataloader = accelerator.prepare( + network, optimizer, train_dataloader + ) + else: + network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + network, optimizer, train_dataloader, lr_scheduler + ) training_model = network if args.gradient_checkpointing: # according to TI example in Diffusers, train is required + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() unet.train() + for t_enc in text_encoders: t_enc.train() @@ -454,6 +467,8 @@ class NetworkTrainer: t_enc.text_model.embeddings.requires_grad_(True) else: + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() unet.eval() for t_enc in text_encoders: t_enc.eval() @@ -804,6 +819,8 @@ class NetworkTrainer: accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) @@ -909,7 +926,8 @@ class NetworkTrainer: accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - lr_scheduler.step() + if not args.optimizer_type.lower().endswith("schedulefree"): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) if args.scale_weight_norms: @@ -920,6 +938,9 @@ class NetworkTrainer: else: keys_scaled, mean_norm, maximum_norm = None, None, None + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 10fce267..fa1c24e0 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -416,14 +416,24 @@ class TextualInversionTrainer: # acceleratorがなんかよろしくやってくれるらしい if len(text_encoders) == 1: - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder_or_list, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + text_encoder_or_list, optimizer, train_dataloader = accelerator.preparet( + text_encoder_or_list, optimizer, train_dataloader + ) + else: + text_encoder_or_list, optimizer, train_dataloader, lr_scheduler = accelerator.preparet( + text_encoder_or_list, optimizer, train_dataloader, lr_scheduler + ) elif len(text_encoders) == 2: - text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + text_encoder1, text_encoder2, optimizer, train_dataloader = accelerator.prepare( + text_encoders[0], text_encoders[1], optimizer, train_dataloader + ) + else: + text_encoder1, text_encoder2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoders[0], text_encoders[1], optimizer, train_dataloader, lr_scheduler + ) text_encoder_or_list = text_encoders = [text_encoder1, text_encoder2] @@ -452,8 +462,12 @@ class TextualInversionTrainer: unet.to(accelerator.device, dtype=weight_dtype) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required # TODO U-Netをオリジナルに置き換えたのでいらないはずなので、後で確認して消す + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() unet.train() else: + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() unet.eval() if not cache_latents: # キャッシュしない場合はVAEを使うのでVAEを準備する @@ -557,6 +571,8 @@ class TextualInversionTrainer: loss_total = 0 for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step with accelerator.accumulate(text_encoders[0]): with torch.no_grad(): @@ -627,6 +643,9 @@ class TextualInversionTrainer: index_no_updates ] + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index ddd03d53..4e95197f 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -335,9 +335,14 @@ def train(args): lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # acceleratorがなんかよろしくやってくれるらしい - text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, lr_scheduler - ) + if args.optimizer_type.lower().endswith("schedulefree"): + text_encoder, optimizer, train_dataloader = accelerator.prepare( + text_encoder, optimizer, train_dataloader + ) + else: + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) index_no_updates = torch.arange(len(tokenizer)) < token_ids_XTI[0] # logger.info(len(index_no_updates), torch.sum(index_no_updates)) @@ -354,8 +359,12 @@ def train(args): unet.to(accelerator.device, dtype=weight_dtype) if args.gradient_checkpointing: # according to TI example in Diffusers, train is required unet.train() + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() else: unet.eval() + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() if not cache_latents: vae.requires_grad_(False) @@ -438,6 +447,8 @@ def train(args): loss_total = 0 for step, batch in enumerate(train_dataloader): + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.train() current_step.value = global_step with accelerator.accumulate(text_encoder): with torch.no_grad(): @@ -496,7 +507,8 @@ def train(args): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - lr_scheduler.step() + if not args.optimizer_type.lower().endswith("schedulefree"): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) # Let's make sure we don't update any embedding weights besides the newly added token @@ -505,6 +517,9 @@ def train(args): index_no_updates ] + if (args.optimizer_type.lower().endswith("schedulefree")): + optimizer.eval() + # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1)