This commit is contained in:
gesen2egee
2024-03-10 04:37:16 +08:00
parent 2d7389185c
commit b558a5b73d
3 changed files with 237 additions and 88 deletions

View File

@@ -136,6 +136,67 @@ class NetworkTrainer:
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
def process_val_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
total_loss = 0.0
timesteps_list = [10, 350, 500, 650, 990]
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
else:
# latentに変換
latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae_dtype)).latent_dist.sample()
# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
accelerator.print("NaN found in latents, replacing with zeros")
latents = torch.where(torch.isnan(latents), torch.zeros_like(latents), latents)
latents = latents * self.vae_scale_factor
b_size = latents.shape[0]
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
tokenizers[0],
text_encoders[0],
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
text_encoder_conds = self.get_text_cond(
args, accelerator, batch, tokenizers, text_encoders, weight_dtype
)
# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, _ = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)
for timesteps in timesteps_list:
# Predict the noise residual
with torch.set_grad_enabled(is_train), accelerator.autocast():
noise_pred = self.call_unet(
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
)
if args.v_parameterization:
# v-parameterization training
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
total_loss += loss
average_loss = total_loss / len(timesteps_list)
return average_loss
def train(self, args):
session_id = random.randint(0, 2**32)
training_started_at = time.time()
@@ -196,11 +257,12 @@ class NetworkTrainer:
}
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
else:
# use arbitrary dataset class
train_dataset_group = train_util.load_arbitrary_dataset(args, tokenizer)
val_dataset_group = None # placeholder until validation dataset supported for arbitrary
current_epoch = Value("i", 0)
current_step = Value("i", 0)
ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None
@@ -219,7 +281,11 @@ class NetworkTrainer:
assert (
train_dataset_group.is_latent_cacheable()
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
if val_dataset_group is not None:
assert (
val_dataset_group.is_latent_cacheable()
), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
self.assert_extra_args(args, train_dataset_group)
# acceleratorを準備する
@@ -271,6 +337,9 @@ class NetworkTrainer:
vae.eval()
with torch.no_grad():
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
if val_dataset_group is not None:
print("Cache validation latents...")
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
vae.to("cpu")
clean_memory_on_device(accelerator.device)
@@ -360,6 +429,15 @@ class NetworkTrainer:
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset_group if val_dataset_group is not None else [],
shuffle=False,
batch_size=1,
collate_fn=collator,
num_workers=n_workers,
persistent_workers=args.persistent_data_loader_workers,
)
# 学習ステップ数を計算する
if args.max_train_epochs is not None:
@@ -707,6 +785,8 @@ class NetworkTrainer:
)
loss_recorder = train_util.LossRecorder()
val_loss_recorder = train_util.LossRecorder()
del train_dataset_group
# callback for step start
@@ -755,7 +835,8 @@ class NetworkTrainer:
current_step.value = global_step
with accelerator.accumulate(network):
on_step_start(text_encoder, unet)
is_train = True
with torch.no_grad():
if "latents" in batch and batch["latents"] is not None:
latents = batch["latents"].to(accelerator.device)
@@ -780,7 +861,7 @@ class NetworkTrainer:
# print(f"set multiplier: {multipliers}")
accelerator.unwrap_model(network).set_multiplier(multipliers)
with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
text_encoder_conds = get_weighted_text_embeddings(
@@ -810,7 +891,7 @@ class NetworkTrainer:
t.requires_grad_(True)
# Predict the noise residual
with accelerator.autocast():
with torch.set_grad_enabled(is_train), accelerator.autocast():
noise_pred = self.call_unet(
args,
accelerator,
@@ -844,7 +925,7 @@ class NetworkTrainer:
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
accelerator.backward(loss)
if accelerator.sync_gradients:
self.all_reduce_network(accelerator, network) # sync DDP grad manually
@@ -898,14 +979,38 @@ class NetworkTrainer:
if args.logging_dir is not None:
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
accelerator.log(logs, step=global_step)
if global_step % 25 == 0:
if len(val_dataloader) > 0:
print("Validating バリデーション処理...")
with torch.no_grad():
val_dataloader_iter = iter(val_dataloader)
batch = next(val_dataloader_iter)
is_train = False
loss = self.process_val_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
current_loss = loss.detach().item()
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_current": current_loss}
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.logging_dir is not None:
logs = {"loss/epoch": loss_recorder.moving_average}
logs = {"loss/epoch_average": loss_recorder.moving_average}
accelerator.log(logs, step=epoch + 1)
if len(val_dataloader) > 0:
if args.logging_dir is not None:
avr_loss: float = val_loss_recorder.moving_average
logs = {"loss/validation_epoch_average": avr_loss}
accelerator.log(logs, step=epoch + 1)
accelerator.wait_for_everyone()
# 指定エポックごとにモデルを保存
@@ -1045,6 +1150,18 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う",
)
parser.add_argument(
"--validation_seed",
type=int,
default=None,
help="Validation seed"
)
parser.add_argument(
"--validation_split",
type=float,
default=0.0,
help="Split for validation images out of the training dataset"
)
return parser