add t5xxl max token length, support schnell

This commit is contained in:
Kohya S
2024-08-16 17:06:05 +09:00
parent 739a8969bc
commit 3921a4efda
3 changed files with 44 additions and 8 deletions

View File

@@ -44,11 +44,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
args.network_train_unet_only or not args.cache_text_encoder_outputs
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
if args.max_token_length is not None:
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
def get_flux_model_name(self, args):
return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
def load_target_model(self, args, weight_dtype, accelerator):
# currently offload to cpu for some models
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way
name = self.get_flux_model_name(args)
# if we load to cpu, flux.to(fp8) takes a long time
model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
@@ -104,7 +111,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return flux_lower
def get_tokenize_strategy(self, args):
return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
name = self.get_flux_model_name(args)
if args.t5xxl_max_token_length is None:
if name == "schnell":
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
else:
t5xxl_max_token_length = args.t5xxl_max_token_length
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
@@ -239,7 +257,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
def encode_images_to_latents(self, args, accelerator, vae, images):
return vae.encode(images)
def shift_scale_latents(self, args, latents):
return latents
@@ -470,7 +488,13 @@ def setup_parser() -> argparse.ArgumentParser:
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,
default=None,
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
)
# copy from Diffusers
parser.add_argument(
"--weighting_scheme",