mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
Update train_db.py
This commit is contained in:
148
train_db.py
148
train_db.py
@@ -46,67 +46,67 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# perlin_noise,
|
||||
def process_val_batch(*training_models, batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args):
|
||||
total_loss = 0.0
|
||||
timesteps_list = [10, 350, 500, 650, 990]
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
if cache_latents:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
total_loss = 0.0
|
||||
timesteps_list = [10, 350, 500, 650, 990]
|
||||
|
||||
with accelerator.accumulate(*training_models):
|
||||
with torch.no_grad():
|
||||
# latentに変換
|
||||
if cache_latents:
|
||||
latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype)
|
||||
else:
|
||||
latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
b_size = latents.shape[0]
|
||||
|
||||
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||
if args.weighted_captions:
|
||||
encoder_hidden_states = get_weighted_text_embeddings(
|
||||
tokenizer,
|
||||
text_encoder,
|
||||
batch["captions"],
|
||||
accelerator.device,
|
||||
args.max_token_length // 75 if args.max_token_length else 1,
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
else:
|
||||
input_ids = batch["input_ids"].to(accelerator.device)
|
||||
encoder_hidden_states = train_util.get_hidden_states(
|
||||
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
|
||||
)
|
||||
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
|
||||
for fixed_timesteps in timesteps_list:
|
||||
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
b_size = latents.shape[0]
|
||||
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
for fixed_timesteps in timesteps_list:
|
||||
with torch.set_grad_enabled(False), accelerator.autocast():
|
||||
# Sample noise, sample a random timestep for each image, and add noise to the latents,
|
||||
# with noise offset and/or multires noise if specified
|
||||
noise = torch.randn_like(latents, device=latents.device)
|
||||
b_size = latents.shape[0]
|
||||
timesteps = torch.full((b_size,), fixed_timesteps, dtype=torch.long, device=latents.device)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
# Predict the noise residual
|
||||
with accelerator.autocast():
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
total_loss += loss
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
if args.masked_loss:
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
total_loss += loss
|
||||
|
||||
average_loss = total_loss / len(timesteps_list)
|
||||
return average_loss
|
||||
average_loss = total_loss / len(timesteps_list)
|
||||
return average_loss
|
||||
|
||||
def train(args):
|
||||
train_util.verify_training_args(args)
|
||||
@@ -210,8 +210,8 @@ def train(args):
|
||||
with torch.no_grad():
|
||||
train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
if val_dataset_group is not None:
|
||||
print("Cache validation latents...")
|
||||
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
print("Cache validation latents...")
|
||||
val_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
|
||||
vae.to("cpu")
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
@@ -503,25 +503,25 @@ def train(args):
|
||||
avr_loss: float = loss_recorder.moving_average
|
||||
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
if len(val_dataloader) > 0:
|
||||
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
|
||||
accelerator.print("Validating バリデーション処理...")
|
||||
total_loss = 0.0
|
||||
with torch.no_grad():
|
||||
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||
batch = next(cyclic_val_dataloader)
|
||||
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
total_loss += loss.detach().item()
|
||||
current_loss = total_loss / validation_steps
|
||||
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
|
||||
if len(val_dataloader) > 0:
|
||||
if (args.validation_every_n_step is not None and global_step % args.validation_every_n_step == 0) or (args.validation_every_n_step is None and step == len(train_dataloader) - 1) or global_step >= args.max_train_steps:
|
||||
accelerator.print("Validating バリデーション処理...")
|
||||
total_loss = 0.0
|
||||
with torch.no_grad():
|
||||
validation_steps = min(args.max_validation_steps, len(val_dataloader)) if args.max_validation_steps is not None else len(val_dataloader)
|
||||
for val_step in tqdm(range(validation_steps), desc='Validation Steps'):
|
||||
batch = next(cyclic_val_dataloader)
|
||||
loss = self.process_val_batch(batch, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
total_loss += loss.detach().item()
|
||||
current_loss = total_loss / validation_steps
|
||||
val_loss_recorder.add(epoch=0, step=global_step, loss=current_loss)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/current_val_loss": current_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/average_val_loss": avr_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
if args.logging_dir is not None:
|
||||
logs = {"loss/current_val_loss": current_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/average_val_loss": avr_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user