Update train_db.py

This commit is contained in:
gesen2egee
2024-08-04 14:53:46 +08:00
committed by GitHub
parent 31507b9901
commit 1db495127f

View File

@@ -2,7 +2,6 @@
# XXX dropped option: fine_tune
import argparse
import itertools
import math
import os
from multiprocessing import Value
@@ -41,11 +40,73 @@ from library.utils import setup_logging, add_logging_arguments
setup_logging()
import logging
import itertools
logger = logging.getLogger(__name__)
# perlin_noise,
def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args):
    """Compute a validation loss for one batch, averaged over a fixed set of timesteps.

    Runs the full noise-prediction pipeline (latents -> add noise -> UNet ->
    MSE against the target) entirely under ``torch.no_grad`` so no gradients
    are accumulated, repeating it at several fixed timesteps so the reported
    validation loss is comparable between validation runs.

    Args:
        *training_models: Models passed through to ``accelerator.accumulate``.
        batch: A dataloader batch; must contain ``"latents"`` when latents are
            cached, otherwise ``"images"``, plus ``"captions"`` / ``"input_ids"``.
        tokenizers: The tokenizer (a bare tokenizer or a one-element list).
        text_encoders: The text encoder (a bare module or a one-element list).
        unet: The UNet noise-prediction model.
        vae: VAE used to encode images when latents are not cached.
        noise_scheduler: Scheduler providing ``add_noise`` / ``get_velocity``.
        vae_dtype: Unused here; kept for interface compatibility with callers.
        weight_dtype: dtype for latents / image tensors.
        accelerator: HF ``accelerate`` Accelerator.
        args: Parsed CLI arguments (reads ``cache_latents``, ``weighted_captions``,
            ``max_token_length``, ``clip_skip``, ``full_fp16``,
            ``v_parameterization``, ``masked_loss``).

    Returns:
        Scalar tensor: the loss averaged over the fixed timesteps.
    """
    total_loss = 0.0
    # Fixed (not random) timesteps spanning the schedule, so consecutive
    # validation runs are directly comparable.
    timesteps_list = [10, 350, 500, 650, 990]

    # train_db.py works with a single tokenizer/text encoder; accept either a
    # bare object or a one-element list so both call styles work.
    tokenizer = tokenizers[0] if isinstance(tokenizers, (list, tuple)) else tokenizers
    text_encoder = text_encoders[0] if isinstance(text_encoders, (list, tuple)) else text_encoders

    with accelerator.accumulate(*training_models):
        with torch.no_grad():
            # Convert images to latents (or reuse cached latents).
            # BUG FIX: the original referenced an undefined bare name
            # `cache_latents`; the flag is an attribute of `args`.
            if args.cache_latents:
                latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
            else:
                latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215  # SD1.x VAE latent scaling factor

            b_size = latents.shape[0]

            # Encode the captions once; the embeddings are reused for every
            # fixed timestep in the loop below.
            with torch.set_grad_enabled(False), accelerator.autocast():
                if args.weighted_captions:
                    encoder_hidden_states = get_weighted_text_embeddings(
                        tokenizer,
                        text_encoder,
                        batch["captions"],
                        accelerator.device,
                        args.max_token_length // 75 if args.max_token_length else 1,
                        clip_skip=args.clip_skip,
                    )
                else:
                    input_ids = batch["input_ids"].to(accelerator.device)
                    encoder_hidden_states = train_util.get_hidden_states(
                        args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
                    )

            for fixed_timesteps in timesteps_list:
                with torch.set_grad_enabled(False), accelerator.autocast():
                    # Sample noise and add it to the latents at this timestep.
                    noise = torch.randn_like(latents, device=latents.device)
                    timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Predict the noise residual.
                    with accelerator.autocast():
                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                    if args.v_parameterization:
                        # v-parameterization training: target is velocity, not noise.
                        target = noise_scheduler.get_velocity(latents, noise, timesteps)
                    else:
                        target = noise

                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                    if args.masked_loss:
                        loss = apply_masked_loss(loss, batch)
                    loss = loss.mean([1, 2, 3])
                    loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
                    loss = loss.mean()  # already a mean, so no division by batch size needed
                    total_loss += loss

    average_loss = total_loss / len(timesteps_list)
    return average_loss
def train(args):
train_util.verify_training_args(args)
@@ -81,9 +142,10 @@ def train(args):
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
val_dataset_group = None
current_epoch = Value("i", 0)
current_step = Value("i", 0)
@@ -148,6 +210,9 @@ def train(args):
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
if val_dataset_group is not None:
print("Cache validation latents...")
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()
@@ -195,6 +260,15 @@ def train(args):
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
batch_size=1,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
cyclic_val_dataloader = itertools.cycle(val_dataloader)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
@@ -296,6 +370,8 @@ def train(args):
train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet)
loss_recorder = train_util.LossRecorder()
val_loss_recorder = train_util.LossRecorder()
for epoch in range(num_train_epochs):
accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
current_epoch.value = epoch + 1
@@ -427,12 +503,33 @@ def train(args):
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if len(val_dataloader) > 0:
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
accelerator.print("Validating バリデーション処理...")
total_loss = 0.0
with torch.no_grad():
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
batch = next(cyclic_val_dataloader)
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
total_loss += loss.detach().item()
current_loss = total_loss / validation_steps
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
if args.logging_dir is not None:
logs = {"loss/current_val_loss": current_loss}
accelerator.log(logs, step=global_step)
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/average_val_loss": avr_loss}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
@@ -515,7 +612,30 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--validation_seed",
type=int,
default=None,
help="Validation seed"
)
parser.add_argument(
"--validation_split",
type=float,
default=0.0,
help="Split for validation images out of the training dataset"
)
parser.add_argument(
"--validation_every_n_step",
type=int,
default=None,
help="Number of train steps for counting validation loss. By default, validation per train epoch is performed"
)
parser.add_argument(
"--max_validation_steps",
type=int,
default=None,
help="Number of max validation steps for counting validation loss. By default, validation will run entire validation dataset"
)
return parser