mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Remove unused train_util code, fix accelerate.log for wandb, add init_trackers library code
This commit is contained in:
@@ -5900,51 +5900,9 @@ def save_sd_model_on_train_end_common(
|
|||||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||||
|
|
||||||
|
|
||||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.IntTensor:
|
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
|
||||||
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
|
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
|
||||||
return timesteps
|
timesteps = timesteps.long().to(device)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor:
|
|
||||||
"""
|
|
||||||
Apply noise modifications like noise offset and multires noise
|
|
||||||
"""
|
|
||||||
if args.noise_offset:
|
|
||||||
if args.noise_offset_random_strength:
|
|
||||||
noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
|
|
||||||
else:
|
|
||||||
noise_offset = args.noise_offset
|
|
||||||
noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)
|
|
||||||
if args.multires_noise_iterations:
|
|
||||||
noise = custom_train_functions.pyramid_noise_like(
|
|
||||||
noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
|
|
||||||
)
|
|
||||||
return noise
|
|
||||||
|
|
||||||
|
|
||||||
def make_noise(args, latents: torch.Tensor) -> torch.FloatTensor:
|
|
||||||
"""
|
|
||||||
Make a noise tensor to denoise and apply noise modifications (noise offset, multires noise). See `modify_noise`
|
|
||||||
"""
|
|
||||||
# Sample noise that we'll add to the latents
|
|
||||||
noise = torch.randn_like(latents, device=latents.device)
|
|
||||||
noise = modify_noise(args, noise, latents)
|
|
||||||
|
|
||||||
return typing.cast(torch.FloatTensor, noise)
|
|
||||||
|
|
||||||
|
|
||||||
def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, device: torch.device) -> torch.IntTensor:
|
|
||||||
"""
|
|
||||||
From args, produce random timesteps for each image in the batch
|
|
||||||
"""
|
|
||||||
min_timestep = 0 if args.min_timestep is None else args.min_timestep
|
|
||||||
max_timestep = noise_scheduler.config.get('num_train_timesteps', 1000) if args.max_timestep is None else args.max_timestep
|
|
||||||
|
|
||||||
# Sample a random timestep for each image
|
|
||||||
timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device)
|
|
||||||
|
|
||||||
return timesteps
|
return timesteps
|
||||||
|
|
||||||
|
|
||||||
@@ -6457,6 +6415,30 @@ def sample_image_inference(
|
|||||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||||
|
|
||||||
|
|
||||||
|
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
|
||||||
|
"""
|
||||||
|
Initialize experiment trackers with tracker specific behaviors
|
||||||
|
"""
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
init_kwargs = {}
|
||||||
|
if args.wandb_run_name:
|
||||||
|
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
||||||
|
if args.log_tracker_config is not None:
|
||||||
|
init_kwargs = toml.load(args.log_tracker_config)
|
||||||
|
accelerator.init_trackers(
|
||||||
|
default_tracker_name if args.log_tracker_name is None else args.log_tracker_name,
|
||||||
|
config=get_sanitized_config_or_none(args),
|
||||||
|
init_kwargs=init_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
|
||||||
|
import wandb
|
||||||
|
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
|
||||||
|
|
||||||
|
# Define specific metrics to handle validation and epochs "steps"
|
||||||
|
wandb_tracker.define_metric("epoch", hidden=True)
|
||||||
|
wandb_tracker.define_metric("val_step", hidden=True)
|
||||||
|
|
||||||
# endregion
|
# endregion
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -327,8 +327,8 @@ class NetworkTrainer:
|
|||||||
weight_dtype,
|
weight_dtype,
|
||||||
accelerator,
|
accelerator,
|
||||||
args,
|
args,
|
||||||
text_encoding_strategy: strategy_sd.SdTextEncodingStrategy,
|
text_encoding_strategy: strategy_base.TextEncodingStrategy,
|
||||||
tokenize_strategy: strategy_sd.SdTokenizeStrategy,
|
tokenize_strategy: strategy_base.TokenizeStrategy,
|
||||||
is_train=True,
|
is_train=True,
|
||||||
train_text_encoder=True,
|
train_text_encoder=True,
|
||||||
train_unet=True
|
train_unet=True
|
||||||
@@ -1183,17 +1183,7 @@ class NetworkTrainer:
|
|||||||
|
|
||||||
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
|
noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
|
||||||
|
|
||||||
if accelerator.is_main_process:
|
train_util.init_trackers(accelerator, args, "network_train")
|
||||||
init_kwargs = {}
|
|
||||||
if args.wandb_run_name:
|
|
||||||
init_kwargs["wandb"] = {"name": args.wandb_run_name}
|
|
||||||
if args.log_tracker_config is not None:
|
|
||||||
init_kwargs = toml.load(args.log_tracker_config)
|
|
||||||
accelerator.init_trackers(
|
|
||||||
"network_train" if args.log_tracker_name is None else args.log_tracker_name,
|
|
||||||
config=train_util.get_sanitized_config_or_none(args),
|
|
||||||
init_kwargs=init_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
loss_recorder = train_util.LossRecorder()
|
loss_recorder = train_util.LossRecorder()
|
||||||
val_step_loss_recorder = train_util.LossRecorder()
|
val_step_loss_recorder = train_util.LossRecorder()
|
||||||
@@ -1386,15 +1376,14 @@ class NetworkTrainer:
|
|||||||
mean_norm,
|
mean_norm,
|
||||||
maximum_norm
|
maximum_norm
|
||||||
)
|
)
|
||||||
# accelerator.log(logs, step=global_step)
|
accelerator.log(logs, step=global_step)
|
||||||
accelerator.log(logs)
|
|
||||||
|
|
||||||
# VALIDATION PER STEP
|
# VALIDATION PER STEP
|
||||||
should_validate_epoch = (
|
should_validate_step = (
|
||||||
args.validate_every_n_steps is not None
|
args.validate_every_n_steps is not None
|
||||||
and global_step % args.validate_every_n_steps == 0
|
and global_step % args.validate_every_n_steps == 0
|
||||||
)
|
)
|
||||||
if validation_steps > 0 and should_validate_epoch:
|
if validation_steps > 0 and should_validate_step:
|
||||||
accelerator.print("Validating バリデーション処理...")
|
accelerator.print("Validating バリデーション処理...")
|
||||||
|
|
||||||
val_progress_bar = tqdm(
|
val_progress_bar = tqdm(
|
||||||
@@ -1406,6 +1395,9 @@ class NetworkTrainer:
|
|||||||
if val_step >= validation_steps:
|
if val_step >= validation_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# temporary, for batch processing
|
||||||
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
|
|
||||||
loss = self.process_batch(
|
loss = self.process_batch(
|
||||||
batch,
|
batch,
|
||||||
text_encoders,
|
text_encoders,
|
||||||
@@ -1428,18 +1420,22 @@ class NetworkTrainer:
|
|||||||
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
val_progress_bar.set_postfix({ "val_avg_loss": val_step_loss_recorder.moving_average })
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {"loss/step_validation_current": current_loss}
|
logs = {
|
||||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
"loss/validation/step/current": current_loss,
|
||||||
accelerator.log(logs)
|
"val_step": (epoch * validation_steps) + val_step,
|
||||||
|
}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
logs = {"loss/step_validation_average": val_step_loss_recorder.moving_average}
|
if is_tracking:
|
||||||
# accelerator.log(logs, step=global_step)
|
logs = {
|
||||||
accelerator.log(logs)
|
"loss/validation/step/average": val_step_loss_recorder.moving_average,
|
||||||
|
}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
# VALIDATION EPOCH
|
# EPOCH VALIDATION
|
||||||
should_validate_epoch = (
|
should_validate_epoch = (
|
||||||
(epoch + 1) % args.validate_every_n_epochs == 0
|
(epoch + 1) % args.validate_every_n_epochs == 0
|
||||||
if args.validate_every_n_epochs is not None
|
if args.validate_every_n_epochs is not None
|
||||||
@@ -1458,6 +1454,9 @@ class NetworkTrainer:
|
|||||||
if val_step >= validation_steps:
|
if val_step >= validation_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# temporary, for batch processing
|
||||||
|
self.on_step_start(args, accelerator, network, text_encoders, unet, batch, weight_dtype)
|
||||||
|
|
||||||
loss = self.process_batch(
|
loss = self.process_batch(
|
||||||
batch,
|
batch,
|
||||||
text_encoders,
|
text_encoders,
|
||||||
@@ -1480,21 +1479,22 @@ class NetworkTrainer:
|
|||||||
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
val_progress_bar.set_postfix({ "val_epoch_avg_loss": val_epoch_loss_recorder.moving_average })
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {"loss/epoch_validation_current": current_loss}
|
logs = {
|
||||||
# accelerator.log(logs, step=(len(val_dataloader) * epoch) + 1 + val_step)
|
"loss/validation/epoch_current": current_loss,
|
||||||
accelerator.log(logs)
|
"epoch": epoch + 1,
|
||||||
|
"val_step": (epoch * validation_steps) + val_step
|
||||||
|
}
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
avr_loss: float = val_epoch_loss_recorder.moving_average
|
avr_loss: float = val_epoch_loss_recorder.moving_average
|
||||||
logs = {"loss/epoch_validation_average": avr_loss}
|
logs = {"loss/validation/epoch_average": avr_loss, "epoch": epoch + 1}
|
||||||
# accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=global_step)
|
||||||
accelerator.log(logs)
|
|
||||||
|
|
||||||
# END OF EPOCH
|
# END OF EPOCH
|
||||||
if is_tracking:
|
if is_tracking:
|
||||||
logs = {"loss/epoch_average": loss_recorder.moving_average}
|
logs = {"loss/epoch_average": loss_recorder.moving_average, "epoch": epoch + 1}
|
||||||
# accelerator.log(logs, step=epoch + 1)
|
accelerator.log(logs, step=global_step)
|
||||||
accelerator.log(logs)
|
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user