mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Merge branch 'sd3' into new_cache
This commit is contained in:
@@ -681,8 +681,8 @@ def train(args):
|
||||
progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
|
||||
global_step = 0
|
||||
|
||||
-# noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
|
||||
-# noise_scheduler_copy = copy.deepcopy(noise_scheduler)
|
||||
+# only used to get timesteps, etc. TODO manage timesteps etc. separately
|
||||
+dummy_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
init_kwargs = {}
|
||||
@@ -850,9 +850,8 @@ def train(args):
|
||||
# 1,
|
||||
# )
|
||||
# calculate loss
|
||||
-loss = train_util.conditional_loss(
|
||||
-model_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=None
|
||||
-)
|
||||
+huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, dummy_scheduler)
|
||||
+loss = train_util.conditional_loss(model_pred.float(), target.float(), args.loss_type, "none", huber_c)
|
||||
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
|
||||
loss = apply_masked_loss(loss, batch)
|
||||
loss = loss.mean([1, 2, 3])
|
||||
|
||||
Reference in New Issue
Block a user