Merge branch 'sd3' into new_cache

This commit is contained in:
Kohya S
2024-12-04 20:44:42 +09:00
25 changed files with 1604 additions and 128 deletions

View File

@@ -681,8 +681,8 @@ def train(args):
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
global_step = 0
# noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
# noise_scheduler_copy = copy.deepcopy(noise_scheduler)
# only used to get timesteps, etc. TODO manage timesteps etc. separately
dummy_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
if accelerator.is_main_process:
init_kwargs = {}
@@ -850,9 +850,8 @@ def train(args):
# 1,
# )
# calculate loss
loss = train_util.conditional_loss(
model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
)
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler)
loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])