Merge branch 'sd3' into new_cache

This commit is contained in:
Kohya S
2024-12-04 20:44:42 +09:00
25 changed files with 1604 additions and 128 deletions

View File

@@ -6,7 +6,8 @@ from typing import Any, Optional
import torch
from accelerate import Accelerator
from library.device_utils import init_ipex, clean_memory_on_device
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
@@ -177,7 +178,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
if args.cache_text_encoder_outputs:
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
@@ -473,7 +474,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
)
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
return model_pred, target, timesteps, None, weighting
return model_pred, target, timesteps, weighting
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss