From 874353296304c753b452511a412472f8a3e4ba09 Mon Sep 17 00:00:00 2001
From: gesen2egee <79357052+gesen2egee@users.noreply.github.com>
Date: Sun, 10 Mar 2024 04:37:16 +0800
Subject: [PATCH] val

---
 library/config_util.py |  32 +++++++------
 library/train_util.py  |  20 ++++++--
 train_network.py       | 102 ++++++++++++++++++++++++++++-------------
 3 files changed, 102 insertions(+), 52 deletions(-)

diff --git a/library/config_util.py b/library/config_util.py
index 1bf7ed95..cb2c5b68 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -81,23 +81,24 @@ class ControlNetSubsetParams(BaseSubsetParams):
 
 @dataclass
 class BaseDatasetParams:
-    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
-    max_token_length: int = None
-    resolution: Optional[Tuple[int, int]] = None
-    debug_dataset: bool = False
-    validation_seed: Optional[int] = None
-    validation_split: float = 0.0
+    tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
+    max_token_length: int = None
+    resolution: Optional[Tuple[int, int]] = None
+    network_multiplier: float = 1.0
+    debug_dataset: bool = False
+    validation_seed: Optional[int] = None
+    validation_split: float = 0.0
 
 @dataclass
 class DreamBoothDatasetParams(BaseDatasetParams):
-    batch_size: int = 1
-    enable_bucket: bool = False
-    min_bucket_reso: int = 256
-    max_bucket_reso: int = 1024
-    bucket_reso_steps: int = 64
-    bucket_no_upscale: bool = False
-    prior_loss_weight: float = 1.0
-
+    batch_size: int = 1
+    enable_bucket: bool = False
+    min_bucket_reso: int = 256
+    max_bucket_reso: int = 1024
+    bucket_reso_steps: int = 64
+    bucket_no_upscale: bool = False
+    prior_loss_weight: float = 1.0
+
 @dataclass
 class FineTuningDatasetParams(BaseDatasetParams):
     batch_size: int = 1
@@ -203,8 +204,9 @@ class ConfigSanitizer:
         "max_bucket_reso": int,
         "min_bucket_reso": int,
         "validation_seed": int,
-        "validation_split": float,
+        "validation_split": float,
         "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
+        "network_multiplier": float,
     }
 
     # options handled by argparse but not handled by user config
diff --git a/library/train_util.py b/library/train_util.py
index 1979207b..2364d62b 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -122,6 +122,20 @@ IMAGE_TRANSFORMS = transforms.Compose(
 )
 
 TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
 
+def split_train_val(paths, is_train, validation_split, validation_seed):
+    if validation_seed is not None:
+        print(f"Using validation seed: {validation_seed}")
+        prevstate = random.getstate()
+        random.seed(validation_seed)
+        random.shuffle(paths)
+        random.setstate(prevstate)
+    else:
+        random.shuffle(paths)
+
+    if is_train:
+        return paths[0:math.ceil(len(paths) * (1 - validation_split))]
+    else:
+        return paths[math.ceil(len(paths) * (1 - validation_split)):]
 
 
@@ -1352,7 +1366,6 @@ class DreamBoothDataset(BaseDataset):
         self.is_train = is_train
         self.validation_split = validation_split
         self.validation_seed = validation_seed
-        self.batch_size = batch_size
 
         self.size = min(self.width, self.height)  # 短いほう
         self.prior_loss_weight = prior_loss_weight
@@ -1405,10 +1418,9 @@
                 return [], []
 
             img_paths = glob_images(subset.image_dir, "*")
-            if self.validation_split > 0.0:
-                img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
-            print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
+            img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
+            logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
 
             # 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
             captions = []
             missing_captions = []
diff --git a/train_network.py b/train_network.py
index edd3ff94..48885503 100644
--- a/train_network.py
+++ b/train_network.py
@@ -130,7 +130,9 @@ class NetworkTrainer:
     def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
         train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
 
-    def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
+    def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True, timesteps_list=None):
+        total_loss = 0.0
+
         with torch.no_grad():
             if "latents" in batch and batch["latents"] is not None:
                 latents = batch["latents"].to(accelerator.device)
@@ -167,37 +169,40 @@
            noise, noisy_latents, timesteps = train_util.get_noise_noisy_latents_and_timesteps(
                args, noise_scheduler, latents
            )
 
-            # Predict the noise residual
-            with torch.set_grad_enabled(is_train), accelerator.autocast():
-                noise_pred = self.call_unet(
-                    args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
-                )
+            # Use the provided timesteps_list, or fall back to the timesteps sampled above
+            timesteps_list = timesteps_list or [timesteps]
+            for timesteps in timesteps_list:
+                # Predict the noise residual
+                with torch.set_grad_enabled(is_train), accelerator.autocast():
+                    noise_pred = self.call_unet(
+                        args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
+                    )
 
-            if args.v_parameterization:
-                # v-parameterization training
-                target = noise_scheduler.get_velocity(latents, noise, timesteps)
-            else:
-                target = noise
+                if args.v_parameterization:
+                    # v-parameterization training
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    target = noise
 
-            loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
-            loss = loss.mean([1, 2, 3])
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
+                loss = loss.mean([1, 2, 3])
 
-            loss_weights = batch["loss_weights"].to(accelerator.device)  # 各sampleごとのweight
-            loss = loss * loss_weights
+                loss_weights = batch["loss_weights"].to(accelerator.device)  # 各sampleごとのweight
+                loss = loss * loss_weights
 
-            if args.min_snr_gamma:
-                loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
-            if args.scale_v_pred_loss_like_noise_pred:
-                loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
-            if args.v_pred_like_loss:
-                loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
-            if args.debiased_estimation_loss:
-                loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
+                if args.min_snr_gamma:
+                    loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
+                if args.scale_v_pred_loss_like_noise_pred:
+                    loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
+                if args.v_pred_like_loss:
+                    loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
+                if args.debiased_estimation_loss:
+                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
 
-            loss = loss.mean()  # 平均なのでbatch_sizeで割る必要なし
-
-            return loss
+                total_loss += loss.mean()  # 平均なのでbatch_sizeで割る必要なし
+            average_loss = total_loss / len(timesteps_list)
+            return average_loss
 
     def train(self, args):
         session_id = random.randint(0, 2**32)
@@ -283,10 +288,10 @@
                 train_dataset_group.is_latent_cacheable()
             ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
         if val_dataset_group is not None:
-            assert (
-                val_dataset_group.is_latent_cacheable()
-            ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
-
+            assert (
+                val_dataset_group.is_latent_cacheable()
+            ), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
+
         self.assert_extra_args(args, train_dataset_group)
 
         # acceleratorを準備する
@@ -430,6 +435,15 @@
            num_workers=n_workers,
            persistent_workers=args.persistent_data_loader_workers,
        )
+
+        val_dataloader = torch.utils.data.DataLoader(
+            val_dataset_group if val_dataset_group is not None else [],
+            shuffle=False,
+            batch_size=1,
+            collate_fn=collator,
+            num_workers=n_workers,
+            persistent_workers=args.persistent_data_loader_workers,
+        )
 
         # 学習ステップ数を計算する
         if args.max_train_epochs is not None:
@@ -798,7 +812,6 @@
 
         loss_recorder = train_util.LossRecorder()
         val_loss_recorder = train_util.LossRecorder()
-        del train_dataset_group
 
         # callback for step start
         if hasattr(accelerator.unwrap_model(network), "on_step_start"):
@@ -848,7 +861,6 @@
                     on_step_start(text_encoder, unet)
 
                 is_train = True
                 loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder)
-
                 accelerator.backward(loss)
                 if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                     params_to_clip = network.get_trainable_params()
@@ -900,7 +912,25 @@
                 if args.logging_dir is not None:
                     logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
                     accelerator.log(logs, step=global_step)
+
+                if global_step % 25 == 0:
+                    if len(val_dataloader) > 0:
+                        print("Validating... / バリデーション処理...")
+                        with torch.no_grad():
+                            val_dataloader_iter = iter(val_dataloader)
+                            batch = next(val_dataloader_iter)
+                            is_train = False
+                            loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990])
+
+                        current_loss = loss.detach().item()
+                        val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
+
+                        if args.logging_dir is not None:
+                            avr_loss: float = val_loss_recorder.moving_average
+                            logs = {"loss/validation_current": current_loss}
+                            accelerator.log(logs, step=global_step)
+
                 if global_step >= args.max_train_steps:
                     break
 
@@ -912,7 +942,7 @@
                 with torch.no_grad():
                     for val_step, batch in enumerate(val_dataloader):
                         is_train = False
-                        loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
+                        loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990])
 
                         current_loss = loss.detach().item()
                         val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
@@ -933,6 +963,12 @@
                 logs = {"loss/epoch_average": loss_recorder.moving_average}
                 accelerator.log(logs, step=epoch + 1)
 
+            if len(val_dataloader) > 0:
+                if args.logging_dir is not None:
+                    avr_loss: float = val_loss_recorder.moving_average
+                    logs = {"loss/validation_epoch_average": avr_loss}
+                    accelerator.log(logs, step=epoch + 1)
+
             accelerator.wait_for_everyone()
 
             # 指定エポックごとにモデルを保存