mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
change name of arg
This commit is contained in:
@@ -100,9 +100,6 @@ def get_scheduler_fix(
|
|||||||
|
|
||||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||||
|
|
||||||
diffusers.optimization.get_scheduler = get_scheduler_fix
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
session_id = random.randint(0, 2**32)
|
session_id = random.randint(0, 2**32)
|
||||||
@@ -225,10 +222,11 @@ def train(args):
|
|||||||
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")
|
||||||
|
|
||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
lr_scheduler = diffusers.optimization.get_scheduler(
|
# lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
|
lr_scheduler = get_scheduler_fix(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||||
num_cycles = args.num_cycles, power = args.power)
|
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||||
|
|
||||||
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
# 実験的機能:勾配も含めたfp16学習を行う モデル全体をfp16にする
|
||||||
if args.full_fp16:
|
if args.full_fp16:
|
||||||
@@ -516,6 +514,10 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
|
||||||
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
|
||||||
|
parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
|
||||||
|
help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
|
||||||
|
parser.add_argument("--lr_scheduler_power", type=float, default=1,
|
||||||
|
help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")
|
||||||
|
|
||||||
parser.add_argument("--network_weights", type=str, default=None,
|
parser.add_argument("--network_weights", type=str, default=None,
|
||||||
help="pretrained weights for network / 学習するネットワークの初期重み")
|
help="pretrained weights for network / 学習するネットワークの初期重み")
|
||||||
@@ -531,10 +533,6 @@ if __name__ == '__main__':
|
|||||||
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
|
||||||
parser.add_argument("--training_comment", type=str, default=None,
|
parser.add_argument("--training_comment", type=str, default=None,
|
||||||
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")
|
||||||
parser.add_argument("--num_cycles", type=int, default=1,
|
|
||||||
help="Number of restarts for cosine scheduler with restarts")
|
|
||||||
parser.add_argument("--power", type=float, default=1,
|
|
||||||
help="Polynomial power for polynomial scheduler")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
train(args)
|
train(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user