mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Add comment
This commit is contained in:
@@ -1845,12 +1845,12 @@ def main(args):
|
|||||||
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
|
text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
|
||||||
else:
|
else:
|
||||||
print("load Diffusers pretrained models")
|
print("load Diffusers pretrained models")
|
||||||
pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
|
||||||
text_encoder = pipe.text_encoder
|
text_encoder = loading_pipe.text_encoder
|
||||||
vae = pipe.vae
|
vae = loading_pipe.vae
|
||||||
unet = pipe.unet
|
unet = loading_pipe.unet
|
||||||
tokenizer = pipe.tokenizer
|
tokenizer = loading_pipe.tokenizer
|
||||||
del pipe
|
del loading_pipe
|
||||||
|
|
||||||
# VAEを読み込む
|
# VAEを読み込む
|
||||||
if args.vae is not None:
|
if args.vae is not None:
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from typing import Optional, Union
|
||||||
import importlib
|
import importlib
|
||||||
import argparse
|
import argparse
|
||||||
import gc
|
import gc
|
||||||
@@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
|
|||||||
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
# Which is a newer release of diffusers than currently packaged with sd-scripts
|
||||||
# This code can be removed once a newer diffusers version (v0.12.1 or greater) has been tested and integrated into sd-scripts
|
# This code can be removed once a newer diffusers version (v0.12.1 or greater) has been tested and integrated into sd-scripts
|
||||||
|
|
||||||
from typing import Optional, Union
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
|
|
||||||
|
|
||||||
def get_scheduler_fix(
|
def get_scheduler_fix(
|
||||||
name: Union[str, SchedulerType],
|
name: Union[str, SchedulerType],
|
||||||
@@ -52,53 +52,53 @@ def get_scheduler_fix(
|
|||||||
num_cycles: int = 1,
|
num_cycles: int = 1,
|
||||||
power: float = 1.0,
|
power: float = 1.0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Unified API to get any scheduler from its name.
|
Unified API to get any scheduler from its name.
|
||||||
Args:
|
Args:
|
||||||
name (`str` or `SchedulerType`):
|
name (`str` or `SchedulerType`):
|
||||||
The name of the scheduler to use.
|
The name of the scheduler to use.
|
||||||
optimizer (`torch.optim.Optimizer`):
|
optimizer (`torch.optim.Optimizer`):
|
||||||
The optimizer that will be used during training.
|
The optimizer that will be used during training.
|
||||||
num_warmup_steps (`int`, *optional*):
|
num_warmup_steps (`int`, *optional*):
|
||||||
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
|
||||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||||
num_training_steps (`int`, *optional*):
|
num_training_steps (`int`, *optional*):
|
||||||
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
The number of training steps to do. This is not required by all schedulers (hence the argument being
|
||||||
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
optional), the function will raise an error if it's unset and the scheduler type requires it.
|
||||||
num_cycles (`int`, *optional*):
|
num_cycles (`int`, *optional*):
|
||||||
The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
|
The number of hard restarts used in the `COSINE_WITH_RESTARTS` scheduler.
|
||||||
power (`float`, *optional*, defaults to 1.0):
|
power (`float`, *optional*, defaults to 1.0):
|
||||||
Power factor. See `POLYNOMIAL` scheduler
|
Power factor. See `POLYNOMIAL` scheduler
|
||||||
last_epoch (`int`, *optional*, defaults to -1):
|
last_epoch (`int`, *optional*, defaults to -1):
|
||||||
The index of the last epoch when resuming training.
|
The index of the last epoch when resuming training.
|
||||||
"""
|
"""
|
||||||
name = SchedulerType(name)
|
name = SchedulerType(name)
|
||||||
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
|
||||||
if name == SchedulerType.CONSTANT:
|
if name == SchedulerType.CONSTANT:
|
||||||
return schedule_func(optimizer)
|
return schedule_func(optimizer)
|
||||||
|
|
||||||
# All other schedulers require `num_warmup_steps`
|
# All other schedulers require `num_warmup_steps`
|
||||||
if num_warmup_steps is None:
|
if num_warmup_steps is None:
|
||||||
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
|
||||||
|
|
||||||
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
if name == SchedulerType.CONSTANT_WITH_WARMUP:
|
||||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
|
||||||
|
|
||||||
# All other schedulers require `num_training_steps`
|
# All other schedulers require `num_training_steps`
|
||||||
if num_training_steps is None:
|
if num_training_steps is None:
|
||||||
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
|
||||||
|
|
||||||
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
if name == SchedulerType.COSINE_WITH_RESTARTS:
|
||||||
return schedule_func(
|
return schedule_func(
|
||||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
|
||||||
)
|
)
|
||||||
|
|
||||||
if name == SchedulerType.POLYNOMIAL:
|
if name == SchedulerType.POLYNOMIAL:
|
||||||
return schedule_func(
|
return schedule_func(
|
||||||
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
|
||||||
)
|
)
|
||||||
|
|
||||||
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
@@ -135,7 +135,7 @@ def train(args):
|
|||||||
train_util.debug_dataset(train_dataset)
|
train_util.debug_dataset(train_dataset)
|
||||||
return
|
return
|
||||||
if len(train_dataset) == 0:
|
if len(train_dataset) == 0:
|
||||||
print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
|
print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
|
||||||
return
|
return
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
@@ -224,7 +224,7 @@ def train(args):
|
|||||||
# lr schedulerを用意する
|
# lr schedulerを用意する
|
||||||
# lr_scheduler = diffusers.optimization.get_scheduler(
|
# lr_scheduler = diffusers.optimization.get_scheduler(
|
||||||
lr_scheduler = get_scheduler_fix(
|
lr_scheduler = get_scheduler_fix(
|
||||||
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
|
||||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||||
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user