diff --git a/train_network.py b/train_network.py index 6bc8bd08..37a10f65 100644 --- a/train_network.py +++ b/train_network.py @@ -100,9 +100,6 @@ def get_scheduler_fix( return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) -diffusers.optimization.get_scheduler = get_scheduler_fix - - def train(args): session_id = random.randint(0, 2**32) @@ -225,10 +222,11 @@ def train(args): print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}") # lr schedulerを用意する - lr_scheduler = diffusers.optimization.get_scheduler( + # lr_scheduler = diffusers.optimization.get_scheduler( + lr_scheduler = get_scheduler_fix( args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - num_cycles = args.num_cycles, power = args.power) + num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) # 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする if args.full_fp16: @@ -516,6 +514,10 @@ if __name__ == '__main__': parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率") parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率") + parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1, + help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数") + parser.add_argument("--lr_scheduler_power", type=float, default=1, + help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power") parser.add_argument("--network_weights", type=str, default=None, help="pretrained weights for network / 学習するネットワークの初期重み") @@ -531,10 +533,6 @@ if __name__ == '__main__': help="only training Text Encoder part / Text Encoder関連部分のみ学習する") parser.add_argument("--training_comment", type=str, default=None, help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列") - parser.add_argument("--num_cycles", type=int, default=1, - help="Number of restarts for cosine scheduler with restarts") - parser.add_argument("--power", type=float, default=1, - help="Polynomial power for polynomial scheduler") args = parser.parse_args() train(args)