mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-16 17:02:45 +00:00
val
This commit is contained in:
@@ -81,23 +81,24 @@ class ControlNetSubsetParams(BaseSubsetParams):
|
||||
|
||||
@dataclass
|
||||
class BaseDatasetParams:
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
||||
max_token_length: int = None
|
||||
resolution: Optional[Tuple[int, int]] = None
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
tokenizer: Union[CLIPTokenizer, List[CLIPTokenizer]] = None
|
||||
max_token_length: int = None
|
||||
resolution: Optional[Tuple[int, int]] = None
|
||||
network_multiplier: float = 1.0
|
||||
debug_dataset: bool = False
|
||||
validation_seed: Optional[int] = None
|
||||
validation_split: float = 0.0
|
||||
|
||||
@dataclass
|
||||
class DreamBoothDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
enable_bucket: bool = False
|
||||
min_bucket_reso: int = 256
|
||||
max_bucket_reso: int = 1024
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
prior_loss_weight: float = 1.0
|
||||
|
||||
batch_size: int = 1
|
||||
enable_bucket: bool = False
|
||||
min_bucket_reso: int = 256
|
||||
max_bucket_reso: int = 1024
|
||||
bucket_reso_steps: int = 64
|
||||
bucket_no_upscale: bool = False
|
||||
prior_loss_weight: float = 1.0
|
||||
|
||||
@dataclass
|
||||
class FineTuningDatasetParams(BaseDatasetParams):
|
||||
batch_size: int = 1
|
||||
@@ -203,8 +204,9 @@ class ConfigSanitizer:
|
||||
"max_bucket_reso": int,
|
||||
"min_bucket_reso": int,
|
||||
"validation_seed": int,
|
||||
"validation_split": float,
|
||||
"validation_split": float,
|
||||
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
|
||||
"network_multiplier": float,
|
||||
}
|
||||
|
||||
# options handled by argparse but not handled by user config
|
||||
|
||||
@@ -122,6 +122,20 @@ IMAGE_TRANSFORMS = transforms.Compose(
|
||||
|
||||
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
|
||||
|
||||
def split_train_val(paths, is_train, validation_split, validation_seed):
|
||||
if validation_seed is not None:
|
||||
print(f"Using validation seed: {validation_seed}")
|
||||
prevstate = random.getstate()
|
||||
random.seed(validation_seed)
|
||||
random.shuffle(paths)
|
||||
random.setstate(prevstate)
|
||||
else:
|
||||
random.shuffle(paths)
|
||||
|
||||
if is_train:
|
||||
return paths[0:math.ceil(len(paths) * (1 - validation_split))]
|
||||
else:
|
||||
return paths[len(paths) - round(len(paths) * validation_split):]
|
||||
|
||||
def split_train_val(paths, is_train, validation_split, validation_seed):
|
||||
if validation_seed is not None:
|
||||
@@ -1352,7 +1366,6 @@ class DreamBoothDataset(BaseDataset):
|
||||
self.is_train = is_train
|
||||
self.validation_split = validation_split
|
||||
self.validation_seed = validation_seed
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.size = min(self.width, self.height) # 短いほう
|
||||
self.prior_loss_weight = prior_loss_weight
|
||||
@@ -1405,10 +1418,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
return [], []
|
||||
|
||||
img_paths = glob_images(subset.image_dir, "*")
|
||||
|
||||
if self.validation_split > 0.0:
|
||||
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
|
||||
print(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
img_paths = split_train_val(img_paths, self.is_train, self.validation_split, self.validation_seed)
|
||||
logger.info(f"found directory {subset.image_dir} contains {len(img_paths)} image files")
|
||||
|
||||
# 画像ファイルごとにプロンプトを読み込み、もしあればそちらを使う
|
||||
captions = []
|
||||
|
||||
102
train_network.py
102
train_network.py
@@ -130,7 +130,9 @@ class NetworkTrainer:
|
||||
def sample_images(self, accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet):
|
||||
train_util.sample_images(accelerator, args, epoch, global_step, device, vae, tokenizer, text_encoder, unet)
|
||||
|
||||
def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True):
|
||||
def process_batch(self, batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=True, timesteps_list=None):
|
||||
total_loss = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
if "latents" in batch and batch["latents"] is not None:
|
||||
latents = batch["latents"].to(accelerator.device)
|
||||
@@ -167,37 +169,40 @@ class NetworkTrainer:
|
||||
args, noise_scheduler, latents
|
||||
)
|
||||
|
||||
# Predict the noise residual
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
)
|
||||
# Use input timesteps_list or use described timesteps above
|
||||
timesteps_list = timesteps_list or [timesteps]
|
||||
for timesteps in timesteps_list:
|
||||
# Predict the noise residual
|
||||
with torch.set_grad_enabled(is_train), accelerator.autocast():
|
||||
noise_pred = self.call_unet(
|
||||
args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
|
||||
)
|
||||
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
if args.v_parameterization:
|
||||
# v-parameterization training
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
target = noise
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
loss_weights = batch["loss_weights"].to(accelerator.device) # 各sampleごとのweight
|
||||
loss = loss * loss_weights
|
||||
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
if args.min_snr_gamma:
|
||||
loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma)
|
||||
if args.scale_v_pred_loss_like_noise_pred:
|
||||
loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler)
|
||||
if args.v_pred_like_loss:
|
||||
loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss)
|
||||
if args.debiased_estimation_loss:
|
||||
loss = apply_debiased_estimation(loss, timesteps, noise_scheduler)
|
||||
|
||||
loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
return loss
|
||||
total_loss += loss.mean() # 平均なのでbatch_sizeで割る必要なし
|
||||
|
||||
average_loss = total_loss / len(timesteps_list)
|
||||
return average_loss
|
||||
|
||||
def train(self, args):
|
||||
session_id = random.randint(0, 2**32)
|
||||
@@ -283,10 +288,10 @@ class NetworkTrainer:
|
||||
train_dataset_group.is_latent_cacheable()
|
||||
), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
if val_dataset_group is not None:
|
||||
assert (
|
||||
val_dataset_group.is_latent_cacheable()
|
||||
), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
assert (
|
||||
val_dataset_group.is_latent_cacheable()
|
||||
), "when caching validation latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"
|
||||
|
||||
self.assert_extra_args(args, train_dataset_group)
|
||||
|
||||
# acceleratorを準備する
|
||||
@@ -430,6 +435,15 @@ class NetworkTrainer:
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset_group if val_dataset_group is not None else [],
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
collate_fn=collator,
|
||||
num_workers=n_workers,
|
||||
persistent_workers=args.persistent_data_loader_workers,
|
||||
)
|
||||
|
||||
val_dataloader = torch.utils.data.DataLoader(
|
||||
val_dataset_group if val_dataset_group is not None else [],
|
||||
@@ -798,7 +812,6 @@ class NetworkTrainer:
|
||||
|
||||
loss_recorder = train_util.LossRecorder()
|
||||
val_loss_recorder = train_util.LossRecorder()
|
||||
|
||||
del train_dataset_group
|
||||
|
||||
# callback for step start
|
||||
@@ -848,7 +861,6 @@ class NetworkTrainer:
|
||||
on_step_start(text_encoder, unet)
|
||||
is_train = True
|
||||
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, train_text_encoder=train_text_encoder)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
|
||||
params_to_clip = network.get_trainable_params()
|
||||
@@ -900,7 +912,25 @@ class NetworkTrainer:
|
||||
if args.logging_dir is not None:
|
||||
logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step % 25 == 0:
|
||||
if len(val_dataloader) > 0:
|
||||
print("Validating バリデーション処理...")
|
||||
|
||||
with torch.no_grad():
|
||||
val_dataloader_iter = iter(val_dataloader)
|
||||
batch = next(val_dataloader_iter)
|
||||
is_train = False
|
||||
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990])
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=global_step, loss=current_loss)
|
||||
|
||||
if args.logging_dir is not None:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_current": current_loss}
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
@@ -912,7 +942,7 @@ class NetworkTrainer:
|
||||
with torch.no_grad():
|
||||
for val_step, batch in enumerate(val_dataloader):
|
||||
is_train = False
|
||||
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args)
|
||||
loss = self.process_batch(batch, is_train, tokenizers, text_encoders, unet, vae, noise_scheduler, vae_dtype, weight_dtype, accelerator, args, timesteps_list=[10, 350, 500, 650, 990])
|
||||
|
||||
current_loss = loss.detach().item()
|
||||
val_loss_recorder.add(epoch=epoch, step=val_step, loss=current_loss)
|
||||
@@ -933,6 +963,12 @@ class NetworkTrainer:
|
||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
if len(val_dataloader) > 0:
|
||||
if args.logging_dir is not None:
|
||||
avr_loss: float = val_loss_recorder.moving_average
|
||||
logs = {"loss/validation_epoch_average": avr_loss}
|
||||
accelerator.log(logs, step=epoch + 1)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# 指定エポックごとにモデルを保存
|
||||
|
||||
Reference in New Issue
Block a user