diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py
index 8b29e78e..25a5b2d9 100644
--- a/gen_img_diffusers.py
+++ b/gen_img_diffusers.py
@@ -1845,12 +1845,12 @@ def main(args):
     text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.ckpt)
   else:
     print("load Diffusers pretrained models")
-    pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
-    text_encoder = pipe.text_encoder
-    vae = pipe.vae
-    unet = pipe.unet
-    tokenizer = pipe.tokenizer
-    del pipe
+    loading_pipe = StableDiffusionPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
+    text_encoder = loading_pipe.text_encoder
+    vae = loading_pipe.vae
+    unet = loading_pipe.unet
+    tokenizer = loading_pipe.tokenizer
+    del loading_pipe
 
   # VAEを読み込む
   if args.vae is not None:
diff --git a/train_network.py b/train_network.py
index aebc4a40..88405221 100644
--- a/train_network.py
+++ b/train_network.py
@@ -1,3 +1,6 @@
+from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
+from torch.optim import Optimizer
+from typing import Optional, Union
 import importlib
 import argparse
 import gc
@@ -40,9 +43,6 @@ def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_sche
 # Which is a newer release of diffusers than currently packaged with sd-scripts
 # This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
 
-from typing import Optional, Union
-from torch.optim import Optimizer
-from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
 
 def get_scheduler_fix(
     name: Union[str, SchedulerType],
@@ -52,53 +52,53 @@ def get_scheduler_fix(
     num_cycles: int = 1,
     power: float = 1.0,
 ):
-    """
-    Unified API to get any scheduler from its name.
-    Args:
-        name (`str` or `SchedulerType`):
-            The name of the scheduler to use.
-        optimizer (`torch.optim.Optimizer`):
-            The optimizer that will be used during training.
-        num_warmup_steps (`int`, *optional*):
-            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_training_steps (`int``, *optional*):
-            The number of training steps to do. This is not required by all schedulers (hence the argument being
-            optional), the function will raise an error if it's unset and the scheduler type requires it.
-        num_cycles (`int`, *optional*):
-            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
-        power (`float`, *optional*, defaults to 1.0):
-            Power factor. See `POLYNOMIAL` scheduler
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
-    """
-    name = SchedulerType(name)
-    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
-    if name == SchedulerType.CONSTANT:
-        return schedule_func(optimizer)
+  """
+  Unified API to get any scheduler from its name.
+  Args:
+    name (`str` or `SchedulerType`):
+      The name of the scheduler to use.
+    optimizer (`torch.optim.Optimizer`):
+      The optimizer that will be used during training.
+    num_warmup_steps (`int`, *optional*):
+      The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+      optional), the function will raise an error if it's unset and the scheduler type requires it.
+    num_training_steps (`int`, *optional*):
+      The number of training steps to do. This is not required by all schedulers (hence the argument being
+      optional), the function will raise an error if it's unset and the scheduler type requires it.
+    num_cycles (`int`, *optional*):
+      The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
+    power (`float`, *optional*, defaults to 1.0):
+      Power factor. See `POLYNOMIAL` scheduler
+    last_epoch (`int`, *optional*, defaults to -1):
+      The index of the last epoch when resuming training.
+  """
+  name = SchedulerType(name)
+  schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+  if name == SchedulerType.CONSTANT:
+    return schedule_func(optimizer)
 
-    # All other schedulers require `num_warmup_steps`
-    if num_warmup_steps is None:
-        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
+  # All other schedulers require `num_warmup_steps`
+  if num_warmup_steps is None:
+    raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
 
-    if name == SchedulerType.CONSTANT_WITH_WARMUP:
-        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
+  if name == SchedulerType.CONSTANT_WITH_WARMUP:
+    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
 
-    # All other schedulers require `num_training_steps`
-    if num_training_steps is None:
-        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
+  # All other schedulers require `num_training_steps`
+  if num_training_steps is None:
+    raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
 
-    if name == SchedulerType.COSINE_WITH_RESTARTS:
-        return schedule_func(
-            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
-        )
+  if name == SchedulerType.COSINE_WITH_RESTARTS:
+    return schedule_func(
+      optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
+    )
 
-    if name == SchedulerType.POLYNOMIAL:
-        return schedule_func(
-            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
-        )
+  if name == SchedulerType.POLYNOMIAL:
+    return schedule_func(
+      optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
+    )
 
-    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
+  return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
 
 
 def train(args):
@@ -135,7 +135,7 @@ def train(args):
     train_util.debug_dataset(train_dataset)
     return
   if len(train_dataset) == 0:
-    print("No data found. Please verify arguments / 画像がありません。引数指定を確認してください")
+    print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
     return
 
   # acceleratorを準備する
@@ -224,7 +224,7 @@ def train(args):
 
   # lr schedulerを用意する
   # lr_scheduler = diffusers.optimization.get_scheduler(
   lr_scheduler = get_scheduler_fix(
-    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, 
+    args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
     num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)
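
Note, not part of the patch: a minimal sketch of how the backported get_scheduler_fix above is exercised once the imports sit at the top of train_network.py. The optimizer, warmup, and step counts below are placeholder values chosen for illustration, not taken from the diff.

# sketch: assumes it is run from the sd-scripts root so train_network is importable
import torch
from train_network import get_scheduler_fix

# placeholder parameter standing in for the trained network's weights
params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.AdamW(params, lr=1e-4)

# the backport accepts num_cycles (COSINE_WITH_RESTARTS) and power (POLYNOMIAL),
# which the diffusers version bundled with sd-scripts did not expose
lr_scheduler = get_scheduler_fix(
    "cosine_with_restarts",
    optimizer,
    num_warmup_steps=100,       # cf. --lr_warmup_steps
    num_training_steps=1600,    # max_train_steps * gradient_accumulation_steps
    num_cycles=4,               # cf. --lr_scheduler_num_cycles
)

for step in range(1600):
    optimizer.step()
    lr_scheduler.step()
    if step % 400 == 0:
        print(step, lr_scheduler.get_last_lr())  # LR resets at each hard restart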