"""Train a control-net style LoRA network (``networks.control_net_lora``) on hint/target image pairs.

The training data directory uses the fill50k layout: a ``prompt.json`` file (one JSON object per line)
with ``source`` (hint image), ``target`` (image to learn) and ``prompt`` (caption) entries.
"""

import argparse
import gc
import importlib
import json
import math
import os
import random
import time
from typing import Optional, Tuple, Union

import numpy as np
import torch
from tqdm import tqdm
from torch.optim import Optimizer
from accelerate.utils import set_seed
import diffusers
from diffusers import DDPMScheduler
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION

import library.train_util as train_util
from library.train_util import BaseDataset, ImageInfo, glob_images
import networks.control_net_lora as control_net_rola


def collate_fn(examples):
    """Unwrap the single pre-collated batch produced by the dataset.

    The DataLoader in ``train`` uses ``batch_size=1`` and ``ImagesWithHintDataset.__getitem__``
    already returns a whole batch dict, so the list always has exactly one element.
    """
    return examples[0]


def generate_step_logs(args: argparse.Namespace, current_loss, avr_loss, lr_scheduler):
    """Build the per-step logging dict (losses and learning rates of the trained parts).

    When both text encoder and U-Net are trained, the text encoder LR is the first param group's LR
    and the U-Net LR is the last one (they may be the same group).
    """
    logs = {"loss/current": current_loss, "loss/average": avr_loss}

    if args.network_train_unet_only:
        logs["lr/unet"] = lr_scheduler.get_last_lr()[0]
    elif args.network_train_text_encoder_only:
        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
    else:
        logs["lr/textencoder"] = lr_scheduler.get_last_lr()[0]
        logs["lr/unet"] = lr_scheduler.get_last_lr()[-1]  # may be same to textencoder

    return logs


# Monkeypatch newer get_scheduler() function overriding current version of diffusers.optimization.get_scheduler
# code is taken from https://github.com/huggingface/diffusers diffusers.optimization, commit d87cc15977b87160c30abaace3894e802ad9e1e6
# Which is a newer release of diffusers than currently packaged with sd-scripts
# This code can be removed when newer diffusers version (v0.12.1 or greater) is tested and implemented to sd-scripts
def get_scheduler_fix(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler.

    Raises:
        ValueError: if `name` is not a valid scheduler type, or if a step count the chosen
            scheduler needs is missing.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    if name == SchedulerType.COSINE_WITH_RESTARTS:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_cycles=num_cycles
        )

    if name == SchedulerType.POLYNOMIAL:
        return schedule_func(
            optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power
        )

    return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)


class ImagesWithHintDataset(BaseDataset):
    """Dataset yielding batches of (target image, hint image, caption) for control-net training.

    ``train_data_dir`` must contain ``prompt.json``: one JSON object per non-empty line with
    ``"source"`` (hint image path), ``"target"`` (target image path) and ``"prompt"`` (caption),
    paths relative to ``train_data_dir``. Images are registered under their source (hint) path.
    """

    def __init__(self, batch_size, train_data_dir, tokenizer, max_token_length, caption_extension, shuffle_caption,
                 shuffle_keep_tokens, resolution, enable_bucket, min_bucket_reso, max_bucket_reso, bucket_reso_steps,
                 bucket_no_upscale, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset) -> None:
        super().__init__(tokenizer, max_token_length, shuffle_caption, shuffle_keep_tokens,
                         resolution, flip_aug, color_aug, face_crop_aug_range, random_crop, debug_dataset)

        assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

        self.batch_size = batch_size
        self.size = min(self.width, self.height)  # the shorter side
        self.latents_cache = None

        self.enable_bucket = enable_bucket
        if self.enable_bucket:
            assert min(resolution) >= min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
            assert max(resolution) <= max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
            self.min_bucket_reso = min_bucket_reso
            self.max_bucket_reso = max_bucket_reso
            self.bucket_reso_steps = bucket_reso_steps
            self.bucket_no_upscale = bucket_no_upscale
        else:
            self.min_bucket_reso = None
            self.max_bucket_reso = None
            self.bucket_reso_steps = None  # not used
            self.bucket_no_upscale = False

        # Maps each registered image key (the source/hint image path) to its target image path,
        # taken verbatim from prompt.json instead of being guessed from the source path.
        self.hint_to_target_path = {}

        # fill50k
        print("loading fill50k dataset")
        captions = []
        src_paths = []
        trg_paths = []
        with open(os.path.join(train_data_dir, "prompt.json"), encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines in the JSON Lines file
                anno = json.loads(line)
                captions.append(anno["prompt"])
                src_paths.append(os.path.join(train_data_dir, anno["source"]))
                trg_paths.append(os.path.join(train_data_dir, anno["target"]))

        # normpath so a trailing path separator does not produce an empty directory name
        dir_name = os.path.basename(os.path.normpath(train_data_dir))
        self.set_tag_frequency(dir_name, captions)  # record tag frequency
        self.dataset_dirs_info[dir_name] = {"n_repeats": 1, "img_count": len(src_paths)}

        for src_path, trg_path, caption in zip(src_paths, trg_paths, captions):
            info = ImageInfo(src_path, 1, caption, False, src_path)
            self.register_image(info)
            self.hint_to_target_path[src_path] = trg_path

        num_train_images = len(src_paths)
        print(f"{num_train_images} train images with repeating.")
        self.num_train_images = num_train_images
        self.num_reg_images = 0

    def __getitem__(self, index):
        """Return the batch at ``index`` as a dict of stacked target images, hint images and token ids.

        Latent caching is not supported: ``'latents'`` is always None and both images are loaded from disk.
        """
        if index == 0:
            self.shuffle_buckets()

        bucket = self.bucket_manager.buckets[self.buckets_indices[index].bucket_index]
        bucket_batch_size = self.buckets_indices[index].bucket_batch_size
        image_index = self.buckets_indices[index].batch_index * bucket_batch_size
        image_keys = bucket[image_index:image_index + bucket_batch_size]

        loss_weights = []
        captions = []
        input_ids_list = []
        images = []
        hint_images = []

        for image_key in image_keys:
            image_info = self.image_data[image_key]
            loss_weights.append(1.0)

            # load the target and hint images, cropping if necessary
            src_path = image_info.absolute_path
            trg_path = self.hint_to_target_path[image_key]
            img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(trg_path)
            hint_img, _, _, _, _ = self.load_image_with_face_info(src_path)
            assert img.shape[0:2] == hint_img.shape[0:2], f"target and hint image sizes differ: {trg_path}, {src_path}"

            # Stack target and hint along the channel axis so every resize/crop below (some of which
            # are random) is applied identically to both, keeping the pair pixel-aligned and equally sized.
            # NOTE(review): relies on the BaseDataset crop/resize helpers accepting a multi-channel array.
            num_img_channels = img.shape[2]
            pair = np.concatenate([img, hint_img], axis=2)

            im_h, im_w = pair.shape[0:2]
            if self.enable_bucket:
                pair = self.trim_and_resize_if_required(pair, image_info.bucket_reso, image_info.resized_size)
            else:
                if face_cx > 0:  # face position info available
                    pair = self.crop_target(pair, face_cx, face_cy, face_w, face_h)
                elif im_h > self.height or im_w > self.width:
                    assert self.random_crop, f"image too large, but cropping and bucketing are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_crop、またはbucketを有効にしてください: {image_info.absolute_path}"
                    if im_h > self.height:
                        p = random.randint(0, im_h - self.height)
                        pair = pair[p:p + self.height]
                    if im_w > self.width:
                        p = random.randint(0, im_w - self.width)
                        pair = pair[:, p:p + self.width]

                im_h, im_w = pair.shape[0:2]
                assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"

            # channel slices are non-contiguous views; make them contiguous for cv2-based augmentation
            img = np.ascontiguousarray(pair[:, :, :num_img_channels])
            hint_img = np.ascontiguousarray(pair[:, :, num_img_channels:])

            # augmentation
            if self.aug is not None:
                # TODO color aug does not work
                auged = self.aug(image=img, image2=hint_img)
                img = auged['image']
                hint_img = auged['image2']

            images.append(self.image_transforms(img))  # becomes a torch.Tensor in -1.0~1.0
            hint_images.append(self.image_transforms(hint_img))  # becomes a torch.Tensor in -1.0~1.0

            caption = self.process_caption(image_info.caption)
            captions.append(caption)
            if not self.token_padding_disabled:  # this option might be omitted in future
                input_ids_list.append(self.get_input_ids(caption))

        example = {}
        example['loss_weights'] = torch.FloatTensor(loss_weights)

        if self.token_padding_disabled:
            # padding=True means pad in the batch
            example['input_ids'] = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
        else:
            # batch processing seems to be good
            example['input_ids'] = torch.stack(input_ids_list)

        images = torch.stack(images)
        images = images.to(memory_format=torch.contiguous_format).float()
        example['images'] = images

        hint_images = torch.stack(hint_images)
        hint_images = hint_images.to(memory_format=torch.contiguous_format).float()
        example['hint_images'] = hint_images

        example['latents'] = None

        if self.debug_dataset:
            example['image_keys'] = image_keys
            example['captions'] = captions
        return example


def train(args):
    """Train the control-net LoRA network described by ``args`` and save it to ``args.output_dir``."""
    session_id = random.randint(0, 2**32)
    training_started_at = time.time()
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)

    # The dataset never serves cached latents and the VAE must stay on the training device to encode
    # the hint images every step, so latent caching is not supported here.
    cache_latents = False
    if args.cache_latents:
        print("cache_latents is not supported by this script and is ignored")

    if args.seed is not None:
        set_seed(args.seed)

    tokenizer = train_util.load_tokenizer(args)

    # prepare the dataset
    train_dataset = ImagesWithHintDataset(args.train_batch_size, args.train_data_dir,
                                          tokenizer, args.max_token_length, args.caption_extension, args.shuffle_caption, args.keep_tokens,
                                          args.resolution, args.enable_bucket, args.min_bucket_reso, args.max_bucket_reso,
                                          args.bucket_reso_steps, args.bucket_no_upscale,
                                          args.flip_aug, args.color_aug, args.face_crop_aug_range,
                                          args.random_crop, args.debug_dataset)

    # set caption dropout rates
    train_dataset.set_caption_dropout(args.caption_dropout_rate, args.caption_dropout_every_n_epochs, args.caption_tag_dropout_rate)

    train_dataset.make_buckets()

    if args.debug_dataset:
        train_util.debug_dataset(train_dataset)
        return
    if len(train_dataset) == 0:
        print("No data found. Please verify arguments (train_data_dir must be the parent of folders with images) / 画像がありません。引数指定を確認してください(train_data_dirには画像があるフォルダではなく、画像があるフォルダの親フォルダを指定する必要があります)")
        return

    # prepare the accelerator
    print("prepare accelerator")
    accelerator, unwrap_model = train_util.prepare_accelerator(args)

    # dtypes for mixed precision; tensors are cast where needed
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the models
    text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype)

    # install xformers / memory-efficient attention into the U-Net
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers)

    # prepare network
    print("import network module:", args.network_module)
    network_module = importlib.import_module(args.network_module)

    net_kwargs = {}
    if args.network_args is not None:
        for net_arg in args.network_args:
            key, value = net_arg.split('=', 1)  # values may themselves contain '='
            net_kwargs[key] = value

    # if a new network is added in future, add if ~ then blocks for each network (;'∀')
    network: control_net_rola.LoRANetwork = network_module.create_network(
        1.0, args.network_dim, args.network_alpha, vae, text_encoder, unet, **net_kwargs)
    if network is None:
        return

    if args.network_weights is not None:
        print("load network weights from:", args.network_weights)
        network.load_weights(args.network_weights)

    train_unet = not args.network_train_text_encoder_only
    train_text_encoder = not args.network_train_unet_only
    network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()
        network.enable_gradient_checkpointing()  # may have no effect

    # prepare optimizer, data loader etc.
    print("prepare optimizer, data loader etc.")

    # use 8-bit Adam if requested
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as e:
            raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです") from e
        print("use 8-bit Adam optimizer")
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr)

    # beta / weight decay options are omitted for now: diffusers DreamBooth and DreamBooth SD both seem to use the defaults
    optimizer = optimizer_class(trainable_params, lr=args.learning_rate)

    # prepare the dataloader
    # number of DataLoader worker processes: 0 means loading happens in the main process.
    # cpu_count - 1, capped at the requested number (os.cpu_count() may return None).
    n_workers = min(args.max_data_loader_n_workers, (os.cpu_count() or 1) - 1)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers,
        # persistent workers are only valid when there are worker processes
        persistent_workers=args.persistent_data_loader_workers and n_workers > 0)

    # compute the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * len(train_dataloader)
        print(f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}")

    # prepare the lr scheduler
    lr_scheduler = get_scheduler_fix(
        args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_scheduler_num_cycles, power=args.lr_scheduler_power)

    # experimental: full fp16 training including gradients; cast the whole network to fp16
    if args.full_fp16:
        assert args.mixed_precision == "fp16", "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        print("enable full fp16 training.")
        network.to(weight_dtype)

    # let accelerate wrap the trained models, optimizer, dataloader and scheduler
    if train_unet and train_text_encoder:
        unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, network, optimizer, train_dataloader, lr_scheduler)
    elif train_unet:
        unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, network, optimizer, train_dataloader, lr_scheduler)
    elif train_text_encoder:
        text_encoder, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            text_encoder, network, optimizer, train_dataloader, lr_scheduler)
    else:
        network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            network, optimizer, train_dataloader, lr_scheduler)

    unet.requires_grad_(False)
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.requires_grad_(False)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    if args.gradient_checkpointing:  # according to TI example in Diffusers, train is required
        unet.train()
        text_encoder.train()

        # set top parameter requires_grad = True for gradient checkpointing works
        text_encoder.text_model.embeddings.requires_grad_(True)
    else:
        unet.eval()
        text_encoder.eval()

    network.prepare_grad_etc(text_encoder, unet)

    # the VAE encodes both target and hint images every step, so it lives on the training device
    vae.requires_grad_(False)
    vae.eval()
    vae.to(accelerator.device, dtype=weight_dtype)

    # experimental: patch PyTorch so grad scaling works with fp16 gradients
    if args.full_fp16:
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume
    if args.resume is not None:
        print(f"resume training from state: {args.resume}")
        accelerator.load_state(args.resume)

    # compute the number of epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

    # train
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    print("running training / 学習開始")
    print(f"  num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
    print(f"  num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
    print(f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
    print(f"  num epochs / epoch数: {num_train_epochs}")
    print(f"  batch size per device / バッチサイズ: {args.train_batch_size}")
    print(f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
    print(f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
    print(f"  total optimization steps / 学習ステップ数: {args.max_train_steps}")

    metadata = {
        "ss_session_id": session_id,  # random integer indicating which group of epochs the model came from
        "ss_training_started_at": training_started_at,  # unix timestamp
        "ss_output_name": args.output_name,
        "ss_learning_rate": args.learning_rate,
        "ss_text_encoder_lr": args.text_encoder_lr,
        "ss_unet_lr": args.unet_lr,
        "ss_num_train_images": train_dataset.num_train_images,  # includes repeating
        "ss_num_reg_images": train_dataset.num_reg_images,
        "ss_num_batches_per_epoch": len(train_dataloader),
        "ss_num_epochs": num_train_epochs,
        "ss_batch_size_per_device": args.train_batch_size,
        "ss_total_batch_size": total_batch_size,
        "ss_gradient_checkpointing": args.gradient_checkpointing,
        "ss_gradient_accumulation_steps": args.gradient_accumulation_steps,
        "ss_max_train_steps": args.max_train_steps,
        "ss_lr_warmup_steps": args.lr_warmup_steps,
        "ss_lr_scheduler": args.lr_scheduler,
        "ss_network_module": "control_net_" + args.network_module,
        "ss_network_dim": args.network_dim,  # None means default because another network than LoRA may have another default dim
        "ss_network_alpha": args.network_alpha,  # some networks may not use this value
        "ss_mixed_precision": args.mixed_precision,
        "ss_full_fp16": bool(args.full_fp16),
        "ss_v2": bool(args.v2),
        "ss_resolution": args.resolution,
        "ss_clip_skip": args.clip_skip,
        "ss_max_token_length": args.max_token_length,
        "ss_color_aug": bool(args.color_aug),
        "ss_flip_aug": bool(args.flip_aug),
        "ss_random_crop": bool(args.random_crop),
        "ss_shuffle_caption": bool(args.shuffle_caption),
        "ss_cache_latents": bool(cache_latents),  # effective value (caching is not supported here)
        "ss_enable_bucket": bool(train_dataset.enable_bucket),
        "ss_min_bucket_reso": train_dataset.min_bucket_reso,
        "ss_max_bucket_reso": train_dataset.max_bucket_reso,
        "ss_seed": args.seed,
        "ss_keep_tokens": args.keep_tokens,
        "ss_dataset_dirs": json.dumps(train_dataset.dataset_dirs_info),
        "ss_reg_dataset_dirs": json.dumps(train_dataset.reg_dataset_dirs_info),
        "ss_tag_frequency": json.dumps(train_dataset.tag_frequency),
        "ss_bucket_info": json.dumps(train_dataset.bucket_info),
        "ss_training_comment": args.training_comment  # will not be updated after training
    }

    # uncomment if another network is added
    # for key, value in net_kwargs.items():
    #   metadata["ss_arg_" + key] = value

    if args.pretrained_model_name_or_path is not None:
        sd_model_name = args.pretrained_model_name_or_path
        if os.path.exists(sd_model_name):
            metadata["ss_sd_model_hash"] = train_util.model_hash(sd_model_name)
            metadata["ss_new_sd_model_hash"] = train_util.calculate_sha256(sd_model_name)
            sd_model_name = os.path.basename(sd_model_name)
        metadata["ss_sd_model_name"] = sd_model_name

    if args.vae is not None:
        vae_name = args.vae
        if os.path.exists(vae_name):
            metadata["ss_vae_hash"] = train_util.model_hash(vae_name)
            metadata["ss_new_vae_hash"] = train_util.calculate_sha256(vae_name)
            vae_name = os.path.basename(vae_name)
        metadata["ss_vae_name"] = vae_name

    metadata = {k: str(v) for k, v in metadata.items()}

    progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
    global_step = 0

    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                                    num_train_timesteps=1000, clip_sample=False)

    if accelerator.is_main_process:
        accelerator.init_trackers("network_train")

    for epoch in range(num_train_epochs):
        print(f"epoch {epoch+1}/{num_train_epochs}")
        train_dataset.set_current_epoch(epoch + 1)

        metadata["ss_epoch"] = str(epoch+1)

        network.on_epoch_start(text_encoder, unet)

        loss_total = 0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(network):
                with torch.no_grad():
                    if "latents" in batch and batch["latents"] is not None:
                        latents = batch["latents"].to(accelerator.device)
                    else:
                        # encode into latents
                        latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                    # hint latents are never provided by the dataset, so always encode them
                    hint_latents = vae.encode(batch["hint_images"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * 0.18215
                    hint_latents = hint_latents * 0.18215
                b_size = latents.shape[0]

                with torch.set_grad_enabled(train_text_encoder):
                    # Get the text embedding for conditioning
                    input_ids = batch["input_ids"].to(accelerator.device)
                    encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents, device=latents.device)

                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Run the U-Net once on the hint latents with the network in control-path mode, then
                # predict the noise residual in normal mode. Presumably the control pass stores signals
                # consumed by the second pass — confirm in networks.control_net_lora.
                network.set_as_control_path(True)
                unet(hint_latents, timesteps, encoder_hidden_states)
                network.set_as_control_path(False)
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                if args.v_parameterization:
                    # v-parameterization training
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    target = noise

                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # per-sample weights
                loss = loss * loss_weights

                loss = loss.mean()  # a mean, so no need to divide by batch size

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = network.get_trainable_params()
                    accelerator.clip_grad_norm_(params_to_clip, 1.0)  # args.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

            current_loss = loss.detach().item()
            loss_total += current_loss
            avr_loss = loss_total / (step+1)
            logs = {"loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if args.logging_dir is not None:
                logs = generate_step_logs(args, current_loss, avr_loss, lr_scheduler)
                accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if args.logging_dir is not None:
            logs = {"loss/epoch": loss_total / len(train_dataloader)}
            accelerator.log(logs, step=epoch+1)

        accelerator.wait_for_everyone()

        if args.save_every_n_epochs is not None:
            model_name = train_util.DEFAULT_EPOCH_NAME if args.output_name is None else args.output_name

            def save_func():
                ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, epoch + 1) + '.' + args.save_model_as
                ckpt_file = os.path.join(args.output_dir, ckpt_name)
                print(f"saving checkpoint: {ckpt_file}")
                unwrap_model(network).save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)

            def remove_old_func(old_epoch_no):
                old_ckpt_name = train_util.EPOCH_FILE_NAME.format(model_name, old_epoch_no) + '.' + args.save_model_as
                old_ckpt_file = os.path.join(args.output_dir, old_ckpt_name)
                if os.path.exists(old_ckpt_file):
                    print(f"removing old checkpoint: {old_ckpt_file}")
                    os.remove(old_ckpt_file)

            saving = train_util.save_on_epoch_end(args, save_func, remove_old_func, epoch + 1, num_train_epochs)
            if saving and args.save_state:
                train_util.save_state_on_epoch_end(args, accelerator, model_name, epoch + 1)

        # end of epoch

    metadata["ss_epoch"] = str(num_train_epochs)

    is_main_process = accelerator.is_main_process
    if is_main_process:
        network = unwrap_model(network)

    accelerator.end_training()

    if args.save_state:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # free it, since memory is needed afterwards

    if is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

        model_name = train_util.DEFAULT_LAST_OUTPUT_NAME if args.output_name is None else args.output_name
        ckpt_name = model_name + '.' + args.save_model_as
        ckpt_file = os.path.join(args.output_dir, ckpt_name)

        print(f"save trained model to {ckpt_file}")
        network.save_weights(ckpt_file, save_dtype, None if args.no_metadata else metadata)
        print("model saved.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    train_util.add_sd_models_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, True)

    parser.add_argument("--no_metadata", action='store_true', help="do not save metadata in output model / メタデータを出力先モデルに保存しない")
    parser.add_argument("--save_model_as", type=str, default="safetensors", choices=[None, "ckpt", "pt", "safetensors"],
                        help="format to save the model (default is .safetensors) / モデル保存時の形式(デフォルトはsafetensors)")

    parser.add_argument("--unet_lr", type=float, default=None, help="learning rate for U-Net / U-Netの学習率")
    parser.add_argument("--text_encoder_lr", type=float, default=None, help="learning rate for Text Encoder / Text Encoderの学習率")
    parser.add_argument("--lr_scheduler_num_cycles", type=int, default=1,
                        help="Number of restarts for cosine scheduler with restarts / cosine with restartsスケジューラでのリスタート回数")
    parser.add_argument("--lr_scheduler_power", type=float, default=1,
                        help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power")

    parser.add_argument("--network_weights", type=str, default=None,
                        help="pretrained weights for network / 学習するネットワークの初期重み")
    parser.add_argument("--network_module", type=str, default=None, help='network module to train / 学習対象のネットワークのモジュール')
    parser.add_argument("--network_dim", type=int, default=None,
                        help='network dimensions (depends on each network) / モジュールの次元数(ネットワークにより定義は異なります)')
    parser.add_argument("--network_alpha", type=float, default=1,
                        help='alpha for LoRA weight scaling, default 1 (same as network_dim for same behavior as old version) / LoRaの重み調整のalpha値、デフォルト1(旧バージョンと同じ動作をするにはnetwork_dimと同じ値を指定)')
    parser.add_argument("--network_args", type=str, default=None, nargs='*',
                        help='additional argmuments for network (key=value) / ネットワークへの追加の引数')
    parser.add_argument("--network_train_unet_only", action="store_true", help="only training U-Net part / U-Net関連部分のみ学習する")
    parser.add_argument("--network_train_text_encoder_only", action="store_true",
                        help="only training Text Encoder part / Text Encoder関連部分のみ学習する")
    parser.add_argument("--training_comment", type=str, default=None,
                        help="arbitrary comment string stored in metadata / メタデータに記録する任意のコメント文字列")

    args = parser.parse_args()
    train(args)