From 3921a4efda1cd1d7d873177ea7f51b77c3f15d3d Mon Sep 17 00:00:00 2001
From: Kohya S
Date: Fri, 16 Aug 2024 17:06:05 +0900
Subject: [PATCH] add t5xxl max token length, support schnell

---
 README.md              |  8 ++++++++
 flux_train_network.py  | 32 ++++++++++++++++++++++++++++----
 library/flux_models.py | 12 ++++++++----
 3 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index bdb6bf2e..6fb050df 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
 The command to install PyTorch is as follows:
 `pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
 
+Aug 16, 2024:
+
+FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model.
+
+Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell.
+
+Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training.
+
 Aug 14, 2024:
 Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified.
 Aug 13, 2024:
diff --git a/flux_train_network.py b/flux_train_network.py
index 59b9d84b..b9a29c16 100644
--- a/flux_train_network.py
+++ b/flux_train_network.py
@@ -44,11 +44,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
             args.network_train_unet_only or not args.cache_text_encoder_outputs
         ), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoder outputs cannot be cached while training the Text Encoder network"
 
+        if args.max_token_length is not None:
+            logger.warning("max_token_length is not used in Flux training")
+
         train_dataset_group.verify_bucket_reso_steps(32)  # TODO check this
 
+    def get_flux_model_name(self, args):
+        return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
+
     def load_target_model(self, args, weight_dtype, accelerator):
         # currently offload to cpu for some models
-        name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"  # TODO change this to a more robust way
+        name = self.get_flux_model_name(args)
+
         # if we load to cpu, flux.to(fp8) takes a long time
         model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
@@ -104,7 +111,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
         return flux_lower
 
     def get_tokenize_strategy(self, args):
-        return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
+        name = self.get_flux_model_name(args)
+
+        if args.t5xxl_max_token_length is None:
+            if name == "schnell":
+                t5xxl_max_token_length = 256
+            else:
+                t5xxl_max_token_length = 512
+        else:
+            t5xxl_max_token_length = args.t5xxl_max_token_length
+
+        logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
+        return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
 
     def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
         return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
@@ -239,7 +257,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
 
     def encode_images_to_latents(self, args, accelerator, vae, images):
         return vae.encode(images)
-    
+
     def shift_scale_latents(self, args, latents):
         return latents
 
@@ -470,7 +488,13 @@ def setup_parser() -> argparse.ArgumentParser:
         help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
         + "/[EXPERIMENTAL] use split mode for the Flux model; the network arg `train_blocks=single` is required",
     )
-
+    parser.add_argument(
+        "--t5xxl_max_token_length",
+        type=int,
+        default=None,
+        help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
+        " / maximum token length for T5-XXL; if omitted, 256 for schnell and 512 for dev",
+    )
     # copy from Diffusers
     parser.add_argument(
         "--weighting_scheme",
diff --git a/library/flux_models.py b/library/flux_models.py
index 3c7766b8..ed0bc8c7 100644
--- a/library/flux_models.py
+++ b/library/flux_models.py
@@ -863,7 +863,8 @@ class Flux(nn.Module):
 
         self.time_in.enable_gradient_checkpointing()
         self.vector_in.enable_gradient_checkpointing()
-        self.guidance_in.enable_gradient_checkpointing()
+        if self.guidance_in.__class__ != nn.Identity:
+            self.guidance_in.enable_gradient_checkpointing()
 
         for block in self.double_blocks + self.single_blocks:
             block.enable_gradient_checkpointing()
@@ -875,7 +876,8 @@ class Flux(nn.Module):
 
         self.time_in.disable_gradient_checkpointing()
         self.vector_in.disable_gradient_checkpointing()
-        self.guidance_in.disable_gradient_checkpointing()
+        if self.guidance_in.__class__ != nn.Identity:
+            self.guidance_in.disable_gradient_checkpointing()
 
         for block in self.double_blocks + self.single_blocks:
             block.disable_gradient_checkpointing()
@@ -972,7 +974,8 @@ class FluxUpper(nn.Module):
 
         self.time_in.enable_gradient_checkpointing()
         self.vector_in.enable_gradient_checkpointing()
-        self.guidance_in.enable_gradient_checkpointing()
+        if self.guidance_in.__class__ != nn.Identity:
+            self.guidance_in.enable_gradient_checkpointing()
 
         for block in self.double_blocks:
             block.enable_gradient_checkpointing()
@@ -984,7 +987,8 @@ class FluxUpper(nn.Module):
 
         self.time_in.disable_gradient_checkpointing()
         self.vector_in.disable_gradient_checkpointing()
-        self.guidance_in.disable_gradient_checkpointing()
+        if self.guidance_in.__class__ != nn.Identity:
+            self.guidance_in.disable_gradient_checkpointing()
 
         for block in self.double_blocks:
             block.disable_gradient_checkpointing()
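
For reference, the model-name detection and T5XXL token-length defaulting added in this patch can be summarized as a small, self-contained sketch. `get_flux_model_name` mirrors the method added above; `resolve_t5xxl_max_token_length` is a hypothetical helper name introduced here for illustration only, since the patch inlines this logic in `FluxNetworkTrainer.get_tokenize_strategy`:

```python
from typing import Optional


def get_flux_model_name(pretrained_model_name_or_path: str) -> str:
    # Mirrors the patch: a model file whose name contains "schnell" is
    # treated as FLUX.1 schnell; anything else is treated as dev.
    return "schnell" if "schnell" in pretrained_model_name_or_path else "dev"


def resolve_t5xxl_max_token_length(name: str, t5xxl_max_token_length: Optional[int]) -> int:
    # An explicit --t5xxl_max_token_length always wins; otherwise the
    # default is 256 for schnell and 512 for dev.
    if t5xxl_max_token_length is not None:
        return t5xxl_max_token_length
    return 256 if name == "schnell" else 512


assert resolve_t5xxl_max_token_length(get_flux_model_name("flux1-schnell.safetensors"), None) == 256
assert resolve_t5xxl_max_token_length(get_flux_model_name("flux1-dev.safetensors"), None) == 512
assert resolve_t5xxl_max_token_length(get_flux_model_name("flux1-dev.safetensors"), 128) == 128
```

In practice, `--t5xxl_max_token_length` only needs to be passed to override the per-model default, e.g. `--t5xxl_max_token_length 512` on a schnell checkpoint to tokenize with the dev-length limit.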
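
The `library/flux_models.py` hunks guard the `guidance_in` calls because, in the reference FLUX code, the guidance embedder is built only when guidance embedding is enabled and is an `nn.Identity` placeholder otherwise (the schnell case), and `nn.Identity` has no `enable_gradient_checkpointing` method. A minimal sketch of the pattern, where `GuidanceEmbedder` is an illustrative stand-in rather than the library's actual class:

```python
import torch.nn as nn


class GuidanceEmbedder(nn.Module):
    # Illustrative stand-in for the guidance embedding module; the real
    # module in library/flux_models.py differs.
    def __init__(self) -> None:
        super().__init__()
        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self) -> None:
        self.gradient_checkpointing = True


for guidance_in in (GuidanceEmbedder(), nn.Identity()):
    # Same guard as the patch: skip the call when guidance_in is the
    # nn.Identity placeholder used when guidance embedding is disabled.
    if guidance_in.__class__ != nn.Identity:
        guidance_in.enable_gradient_checkpointing()
```

Comparing `__class__` directly, as the patch does, skips only the exact `nn.Identity` placeholder; an `isinstance` check would additionally skip any subclass of it.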