Text Encoder cache (WIP)

Commit 3677094256 (parent bdac55ebbc)
Author: Kohya S
Date: 2024-11-27 12:57:04 +09:00
15 changed files with 628 additions and 471 deletions


@@ -10,8 +10,6 @@ from library.device_utils import init_ipex, clean_memory_on_device
 init_ipex()
-from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
-import train_network
 from library.utils import setup_logging
 setup_logging()
@@ -19,6 +17,9 @@ import logging
 logger = logging.getLogger(__name__)
+from library import flux_models, flux_train_utils, flux_utils, sd3_train_utils, strategy_base, strategy_flux, train_util
+import train_network
 class FluxNetworkTrainer(train_network.NetworkTrainer):
     def __init__(self):
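
The first two hunks move the library imports below setup_logging(), presumably so that logging is already configured before those modules run any import-time code that logs or creates loggers. A minimal, self-contained sketch of the pattern; the setup_logging body below is a hypothetical stand-in, not the library.utils implementation:

import logging

def setup_logging() -> None:
    # hypothetical stand-in for library.utils.setup_logging
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.INFO)

setup_logging()

# modules imported from here on emit through the configured handler,
# which is the ordering the reordered imports above preserve
logger = logging.getLogger(__name__)
logger.info("logging configured before dependent imports")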
@@ -174,13 +175,17 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
     def get_text_encoder_outputs_caching_strategy(self, args):
         if args.cache_text_encoder_outputs:
+            fluxTokenizeStrategy: strategy_flux.FluxTokenizeStrategy = strategy_base.TokenizeStrategy.get_strategy()
+            t5xxl_max_token_length = fluxTokenizeStrategy.t5xxl_max_length
             # if the text encoders are trained, we need tokenization, so is_partial is True
             return strategy_flux.FluxTextEncoderOutputsCachingStrategy(
                 args.cache_text_encoder_outputs_to_disk,
                 args.text_encoder_batch_size,
                 args.skip_cache_check,
+                t5xxl_max_token_length,
+                args.apply_t5_attn_mask,
                 is_partial=self.train_clip_l or self.train_t5xxl,
-                apply_t5_attn_mask=args.apply_t5_attn_mask,
             )
         else:
             return None
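
The new fluxTokenizeStrategy lines pull the active tokenize strategy from strategy_base.TokenizeStrategy.get_strategy() rather than receiving it as an argument, which suggests a class-level registry: setup code registers the concrete FluxTokenizeStrategy once, and call sites look it up later. A sketch of that pattern, under the assumption of a simple class-attribute registry (strategy_base itself is not part of this excerpt):

from typing import Optional

class TokenizeStrategy:
    _strategy: Optional["TokenizeStrategy"] = None  # process-wide registry slot

    @classmethod
    def set_strategy(cls, strategy: "TokenizeStrategy") -> None:
        cls._strategy = strategy

    @classmethod
    def get_strategy(cls) -> Optional["TokenizeStrategy"]:
        return cls._strategy

class FluxTokenizeStrategy(TokenizeStrategy):
    # hypothetical: only the attribute the diff actually reads is modeled
    def __init__(self, t5xxl_max_length: int = 512) -> None:
        self.t5xxl_max_length = t5xxl_max_length

# setup code registers the concrete strategy once...
TokenizeStrategy.set_strategy(FluxTokenizeStrategy(t5xxl_max_length=512))

# ...and the factory method above can later read its settings:
t5xxl_max_token_length = TokenizeStrategy.get_strategy().t5xxl_max_length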
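
At the call site, t5xxl_max_token_length and args.apply_t5_attn_mask are now passed positionally before is_partial, replacing the old apply_t5_attn_mask= keyword. The strategy class must therefore expose a matching parameter order; the signature below is inferred from this call site alone and is an assumption, not the strategy_flux code:

class FluxTextEncoderOutputsCachingStrategy:
    # signature inferred from the call site above; the real class lives in
    # library/strategy_flux.py, which this excerpt does not show
    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_cache_check: bool,
        t5xxl_max_token_length: int,  # new positional argument in this commit
        apply_t5_attn_mask: bool,     # previously passed as a keyword
        is_partial: bool = False,     # True when CLIP-L or T5-XXL is trained
    ) -> None:
        self.cache_to_disk = cache_to_disk
        self.batch_size = batch_size
        self.skip_cache_check = skip_cache_check
        self.t5xxl_max_token_length = t5xxl_max_token_length
        self.apply_t5_attn_mask = apply_t5_attn_mask
        self.is_partial = is_partial

Keeping is_partial as a keyword at the call site while the new arguments are positional keeps the diff small, but it does rely on the constructor's positional order matching exactly.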