mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-17 17:24:21 +00:00
Text Encoder cache (WIP)
This commit is contained in:
@@ -10,8 +10,6 @@ from library.device_utils import init_ipex, clean_memory_on_device
|
||||
|
||||
init_ipex()
|
||||
|
||||
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
||||
import train_network
|
||||
from library.utils import setup_logging
|
||||
|
||||
setup_logging()
|
||||
@@ -19,6 +17,9 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
|
||||
import train_network
|
||||
|
||||
|
||||
class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
def __init__(self):
|
||||
@@ -174,13 +175,17 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
def get_text_encoder_outputs_caching_strategy(self, args):
|
||||
if args.cache_text_encoder_outputs:
|
||||
fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
|
||||
t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
|
||||
|
||||
# if the text encoders is trained, we need tokenization, so is_partial is True
|
||||
return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
|
||||
args.cache_text_encoder_outputs_to_disk,
|
||||
args.text_encoder_batch_size,
|
||||
args.skip_cache_check,
|
||||
t5xxl_max_token_length,
|
||||
args.apply_t5_attn_mask,
|
||||
is_partial=self.train_clip_l or self.train_t5xxl,
|
||||
apply_t5_attn_mask=args.apply_t5_attn_mask,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user