Add comment

This commit is contained in:
Kohya S
2023-02-03 21:04:03 +09:00
parent 93134cdd15
commit 58a809eaff
2 changed files with 52 additions and 52 deletions

View File

@@ -1845,12 +1845,12 @@ def main(args):
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt) text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
else: else:
print("load Diffusers pretrained models") print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype) loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
text_encoder = pipe.text_encoder text_encoder = loading_pipe.text_encoder
vae = pipe.vae vae = loading_pipe.vae
unet = pipe.unet unet = loading_pipe.unet
tokenizer = pipe.tokenizer tokenizer = loading_pipe.tokenizer
del pipe del loading_pipe
# VAEを読み込む # VAEを読み込む
if args.vae is not None: if args.vae is not None:

View File

@@ -1,3 +1,6 @@
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
from torch.optim import Optimizer
from typing import Optional, Union
import importlib import importlib
import argparse import argparse
import gc import gc
@@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
# Which is a newer release of diffusers than currently packaged with sd-scripts # Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
from typing import Optional, Union
from torch.optim import Optimizer
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
def get_scheduler_fix( def get_scheduler_fix(
name: Union[str, SchedulerType], name: Union[str, SchedulerType],
@@ -52,53 +52,53 @@ def get_scheduler_fix(
num_cycles: int = 1, num_cycles: int = 1,
power: float = 1.0, power: float = 1.0,
): ):
""" """
Unified API to get any scheduler from its name. Unified API to get any scheduler from its name.
Args: Args:
name (`str` or `SchedulerType`): name (`str` or `SchedulerType`):
The name of the scheduler to use. The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`): optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training. The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*): num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it. optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*): num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it. optional), the function will raise an error if it's unset and the scheduler type requires it.
num_cycles (`int`, *optional*): num_cycles (`int`, *optional*):
The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler. The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
power (`float`, *optional*, defaults to 1.0): power (`float`, *optional*, defaults to 1.0):
Power factor. See `POLYNOMIAL` scheduler Power factor. See `POLYNOMIAL` scheduler
last_epoch (`int`, *optional*, defaults to -1): last_epoch (`int`, *optional*, defaults to -1):
The index of the last epoch when resuming training. The index of the last epoch when resuming training.
""" """
name = SchedulerType(name) name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT: if name == SchedulerType.CONSTANT:
return schedule_func(optimizer) return schedule_func(optimizer)
# All other schedulers require `num_warmup_steps` # All other schedulers require `num_warmup_steps`
if num_warmup_steps is None: if num_warmup_steps is None:
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
if name == SchedulerType.CONSTANT_WITH_WARMUP: if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps) return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
# All other schedulers require `num_training_steps` # All other schedulers require `num_training_steps`
if num_training_steps is None: if num_training_steps is None:
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
if name == SchedulerType.COSINE_WITH_RESTARTS: if name == SchedulerType.COSINE_WITH_RESTARTS:
return schedule_func( return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
) )
if name == SchedulerType.POLYNOMIAL: if name == SchedulerType.POLYNOMIAL:
return schedule_func( return schedule_func(
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
) )
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps) return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
def train(args): def train(args):
@@ -135,7 +135,7 @@ def train(args):
train_util.debug_dataset(train_dataset) train_util.debug_dataset(train_dataset)
return return
if len(train_dataset) == 0: if len(train_dataset) == 0:
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください") print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してくださいtrain_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります")
return return
# acceleratorを準備する # acceleratorを準備する
@@ -224,7 +224,7 @@ def train(args):
# lr schedulerを用意する # lr schedulerを用意する
# lr_scheduler = diffusers.optimization.get_scheduler( # lr_scheduler = diffusers.optimization.get_scheduler(
lr_scheduler = get_scheduler_fix( lr_scheduler = get_scheduler_fix(
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power) num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)