From f4840ef29ef67878d7c7ccec92bdce89c3b61c6d Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 6 Jan 2025 10:52:07 -0500 Subject: [PATCH] Revert train_db.py --- train_db.py | 121 ++-------------------------------------------------- 1 file changed, 3 insertions(+), 118 deletions(-) diff --git a/train_db.py b/train_db.py index 398489ff..ad21f8d1 100644 --- a/train_db.py +++ b/train_db.py @@ -2,6 +2,7 @@ # XXX dropped option: fine_tune import argparse +import itertools import math import os from multiprocessing import Value @@ -41,73 +42,11 @@ import library.strategy_sd as strategy_sd setup_logging() import logging -import itertools logger = logging.getLogger(__name__) # perlin_noise, -def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args): - total_loss = 0.0 - timesteps_list = [10, 350, 500, 650, 990] - - with accelerator.accumulate(*training_models): - with torch.no_grad(): - # latentに変換 - if cache_latents: - latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) - else: - latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * 0.18215 - b_size = latents.shape[0] - with torch.set_grad_enabled(False), accelerator.autocast(): - if args.weighted_captions: - encoder_hidden_states = get_weighted_text_embeddings( - tokenizer, - text_encoder, - batch["captions"], - accelerator.device, - args.max_token_length // 75 if args.max_token_length else 1, - clip_skip=args.clip_skip, - ) - else: - input_ids = batch["input_ids"].to(accelerator.device) - encoder_hidden_states = train_util.get_hidden_states( - args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype - ) - - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - - for fixed_timesteps in timesteps_list: - with torch.set_grad_enabled(False), accelerator.autocast(): - # Sample noise, sample a random timestep for each image, and add noise to the latents, - # with noise offset and/or multires noise if specified - noise = torch.randn_like(latents, device=latents.device) - b_size = latents.shape[0] - timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Predict the noise residual - with accelerator.autocast(): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - - if args.v_parameterization: - # v-parameterization training - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - target = noise - - loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none") - if args.masked_loss: - loss = apply_masked_loss(loss, batch) - loss = loss.mean([1, 2, 3]) - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) - loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし - total_loss += loss - - average_loss = total_loss / len(timesteps_list) - return average_loss def train(args): train_util.verify_training_args(args) @@ -150,10 +89,9 @@ def train(args): } blueprint = blueprint_generator.generate(user_config, args) - train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) + train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group) else: train_dataset_group = train_util.load_arbitrary_dataset(args) - val_dataset_group = None current_epoch = Value("i", 0) current_step = Value("i", 0) @@ -274,15 +212,6 @@ def train(args): num_workers=n_workers, persistent_workers=args.persistent_data_loader_workers, ) - val_dataloader = torch.utils.data.DataLoader( - val_dataset_group if val_dataset_group is not None else [], - shuffle=False, - batch_size=1, - collate_fn=collator, - num_workers=n_workers, - persistent_workers=args.persistent_data_loader_workers, - ) - cyclic_val_dataloader = itertools.cycle(val_dataloader) # 学習ステップ数を計算する if args.max_train_epochs is not None: @@ -393,8 +322,6 @@ def train(args): accelerator.log({}, step=0) loss_recorder = train_util.LossRecorder() - val_loss_recorder = train_util.LossRecorder() - for epoch in range(num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -525,25 +452,6 @@ def train(args): avr_loss: float = loss_recorder.moving_average logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) - if len(val_dataloader) > 0: - if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps: - accelerator.print("Validating バリデーション処理...") - total_loss = 0.0 - with torch.no_grad(): - validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader) - for val_step in tqdm(range(validation_steps), desc='Validation Steps'): - batch = next(cyclic_val_dataloader) - loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args) - total_loss += loss.detach().item() - current_loss = total_loss / validation_steps - val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss) - - if args.logging_dir is not None: - logs = {"loss/current_val_loss": current_loss} - accelerator.log(logs, step=global_step) - avr_loss: float = val_loss_recorder.moving_average - logs = {"loss/average_val_loss": avr_loss} - accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break @@ -634,30 +542,7 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) - parser.add_argument( - "--validation_seed", - type=int, - default=None, - help="Validation seed" - ) - parser.add_argument( - "--validation_split", - type=float, - default=0.0, - help="Split for validation images out of the training dataset" - ) - parser.add_argument( - "--validation_every_n_step", - type=int, - default=None, - help="Number of train steps for counting validation loss. By default, validation per train epoch is performed" - ) - parser.add_argument( - "--max_validation_steps", - type=int, - default=None, - help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset" - ) + return parser