change name of arg

This commit is contained in:
Kohya S
2023-01-29 20:28:24 +09:00
parent 443ce7a30b
commit 7817e95a86

View File

@@ -100,9 +100,6 @@ def get_scheduler_fix(
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
diffusers.optimization.get_scheduler = get_scheduler_fix
def train(args):
session_id = random.randint(0, 2**32)
@@ -225,10 +222,11 @@ def train(args):
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
# lr schedulerを用意する
lr_scheduler = diffusers.optimization.get_scheduler(
# lr_scheduler = diffusers.optimization.get_scheduler(
lr_scheduler = get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles = args.num_cycles, power = args.power)
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
# 実験的機能勾配も含めたfp16学習を行う モデル全体をfp16にする
if args.full_fp16:
@@ -516,6 +514,10 @@ if __name__ == '__main__':
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
parser.add_argument("--lr_scheduler_power", type=float, default=1,
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
parser.add_argument("--network_weights", type=str, default=None,
help="pretrained weights for network / 学習するネットワークの初期重み")
@@ -531,10 +533,6 @@ if __name__ == '__main__':
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
parser.add_argument("--training_comment", type=str, default=None,
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
parser.add_argument("--num_cycles", type=int, default=1,
help="Number of restarts for cosine scheduler with restarts")
parser.add_argument("--power", type=float, default=1,
help="Polynomial power for polynomial scheduler")
args = parser.parse_args()
train(args)