mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-18 01:30:02 +00:00
Merge branch 'sd3' into new_cache
This commit is contained in:
@@ -6,7 +6,8 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
from library.device_utils import clean_memory_on_device, init_ipex
|
||||
|
||||
init_ipex()
|
||||
|
||||
@@ -177,7 +178,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
if args.cache_text_encoder_outputs:
|
||||
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
|
||||
|
||||
|
||||
# if the text encoders is trained, we need tokenization, so is_partial is True
|
||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
@@ -473,7 +474,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
)
|
||||
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
|
||||
|
||||
return model_pred, target, timesteps, None, weighting
|
||||
return model_pred, target, timesteps, weighting
|
||||
|
||||
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user