mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Remove unused train_util code, fix accelerate.log for wandb, add init_trackers library code
This commit is contained in:
@@ -5900,51 +5900,9 @@ def save_sd_model_on_train_end_common(
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||
|
||||
|
||||
def get_timesteps(min_timestep: int, max_timestep: int, b_size: int, device: torch.device = torch.device("cpu")) -> torch.Tensor:
    """
    Sample one random integer timestep per batch element.

    Args:
        min_timestep: lower bound of the sampling range (inclusive).
        max_timestep: upper bound of the sampling range (exclusive, as in `torch.randint`).
        b_size: number of timesteps to sample.
        device: device on which the result tensor is created.

    Returns:
        A 1-D integer tensor of length `b_size`.

    Raises:
        RuntimeError: propagated from `torch.randint` when `min_timestep > max_timestep`.
    """
    if min_timestep == max_timestep:
        # torch.randint requires low < high; a degenerate range has exactly one
        # possible value, so return it directly instead of raising.
        return torch.full((b_size,), min_timestep, dtype=torch.long, device=device)

    timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
    return timesteps
|
||||
|
||||
|
||||
|
||||
|
||||
def modify_noise(args, noise: torch.Tensor, latents: torch.Tensor) -> torch.FloatTensor:
    """
    Apply noise modifications like noise offset and multires noise

    Args:
        args: options object; reads `noise_offset`, `noise_offset_random_strength`,
            `adaptive_noise_scale`, `multires_noise_iterations` and `multires_noise_discount`.
        noise: the noise tensor to modify.
        latents: the latents the noise belongs to; passed to the noise helpers and
            used for its device.

    Returns:
        The modified noise, or `noise` itself unchanged when neither option is enabled.
    """
    if args.noise_offset:
        if args.noise_offset_random_strength:
            # Random strength: scale the configured offset by a uniform sample in [0, 1).
            noise_offset = torch.rand(1, device=latents.device) * args.noise_offset
        else:
            noise_offset = args.noise_offset
        noise = custom_train_functions.apply_noise_offset(latents, noise, noise_offset, args.adaptive_noise_scale)

    if args.multires_noise_iterations:
        noise = custom_train_functions.pyramid_noise_like(
            noise, latents.device, args.multires_noise_iterations, args.multires_noise_discount
        )

    return noise
|
||||
|
||||
|
||||
def make_noise(args, latents: torch.Tensor) -> torch.FloatTensor:
    """
    Make a noise tensor to denoise and apply noise modifications (noise offset, multires noise). See `modify_noise`

    Args:
        args: options object forwarded to `modify_noise`.
        latents: latents whose shape, dtype and device the sampled noise matches.

    Returns:
        The noise tensor after `modify_noise` has been applied.
    """
    # Sample noise that we'll add to the latents; randn_like already matches
    # the shape, dtype and device of `latents`.
    noise = torch.randn_like(latents)
    return modify_noise(args, noise, latents)
|
||||
|
||||
|
||||
def make_random_timesteps(args, noise_scheduler: DDPMScheduler, batch_size: int, device: torch.device) -> torch.IntTensor:
    """
    From args, produce random timesteps for each image in the batch

    Args:
        args: options object; `min_timestep` / `max_timestep` override the defaults when not None.
        noise_scheduler: scheduler whose config supplies `num_train_timesteps`
            (falling back to 1000) when `args.max_timestep` is None.
        batch_size: number of timesteps to sample.
        device: device the returned tensor lives on.

    Returns:
        A 1-D long tensor of `batch_size` timesteps on `device`.
    """
    min_timestep = 0 if args.min_timestep is None else args.min_timestep
    if args.max_timestep is None:
        max_timestep = noise_scheduler.config.get("num_train_timesteps", 1000)
    else:
        max_timestep = args.max_timestep

    # Sample a random timestep for each image
    timesteps = get_timesteps(min_timestep, max_timestep, batch_size, device)

    # Normalize dtype/device regardless of what get_timesteps returned.
    return timesteps.long().to(device)
|
||||
|
||||
|
||||
@@ -6457,6 +6415,30 @@ def sample_image_inference(
|
||||
wandb_tracker.log({f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption
|
||||
|
||||
|
||||
def init_trackers(accelerator: Accelerator, args: argparse.Namespace, default_tracker_name: str):
    """
    Initialize experiment trackers with tracker specific behaviors

    Only acts on the main process. Tracker init kwargs are built from
    `args.wandb_run_name` and, if given, the TOML file at `args.log_tracker_config`.
    Both sources are merged per tracker; values from the file win on conflict, so
    a run name passed on its own is no longer discarded when a config file is
    also supplied.

    Args:
        accelerator: the accelerator whose trackers are initialized.
        args: parsed options; reads `wandb_run_name`, `log_tracker_config` and `log_tracker_name`.
        default_tracker_name: project name used when `args.log_tracker_name` is None.
    """
    if not accelerator.is_main_process:
        return

    init_kwargs = {}
    if args.wandb_run_name:
        init_kwargs["wandb"] = {"name": args.wandb_run_name}

    if args.log_tracker_config is not None:
        file_kwargs = toml.load(args.log_tracker_config)
        for tracker_name, tracker_kwargs in file_kwargs.items():
            existing = init_kwargs.get(tracker_name)
            if isinstance(existing, dict) and isinstance(tracker_kwargs, dict):
                # Keep kwargs set above (e.g. the wandb run name) unless the file overrides them.
                init_kwargs[tracker_name] = {**existing, **tracker_kwargs}
            else:
                init_kwargs[tracker_name] = tracker_kwargs

    accelerator.init_trackers(
        default_tracker_name if args.log_tracker_name is None else args.log_tracker_name,
        config=get_sanitized_config_or_none(args),
        init_kwargs=init_kwargs,
    )

    if any(tracker.name == "wandb" for tracker in accelerator.trackers):
        wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)

        # Define specific metrics to handle validation and epochs "steps"
        wandb_tracker.define_metric("epoch", hidden=True)
        wandb_tracker.define_metric("val_step", hidden=True)
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user