mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-10 06:54:17 +00:00
add t5xxl max token length, support schnell
This commit is contained in:
@@ -9,6 +9,14 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
|
||||
The command to install PyTorch is as follows:
|
||||
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`
|
||||
|
||||
Aug 16, 2024:
|
||||
|
||||
FLUX.1 schnell model based training is now supported (but not tested). If the name of the model file contains `schnell`, the model is treated as a schnell model.
|
||||
|
||||
Added `--t5xxl_max_token_length` option to specify the maximum token length of T5XXL. The default is 512 in dev and 256 in schnell.
|
||||
|
||||
Previously, when `--max_token_length` was specified, that value was used, and 512 was used when omitted (default). Therefore, there is no impact if `--max_token_length` was not specified. If `--max_token_length` was specified, please specify `--t5xxl_max_token_length` instead. `--max_token_length` is ignored during FLUX.1 training.
|
||||
|
||||
Aug 14, 2024: Sample image generation during training is now supported. Specify options such as `--sample_prompts` and `--sample_every_n_epochs`. It will be very slow when `--split_mode` is specified.
|
||||
|
||||
Aug 13, 2024:
|
||||
|
||||
@@ -44,11 +44,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
args.network_train_unet_only or not args.cache_text_encoder_outputs
|
||||
), "network for Text Encoder cannot be trained with caching Text Encoder outputs / Text Encoderの出力をキャッシュしながらText Encoderのネットワークを学習することはできません"
|
||||
|
||||
if args.max_token_length is not None:
|
||||
logger.warning("max_token_length is not used in Flux training / max_token_lengthはFluxのトレーニングでは使用されません")
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
def get_flux_model_name(self, args):
|
||||
return "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev"
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
name = "schnell" if "schnell" in args.pretrained_model_name_or_path else "dev" # TODO change this to a more robust way
|
||||
name = self.get_flux_model_name(args)
|
||||
|
||||
# if we load to cpu, flux.to(fp8) takes a long time
|
||||
model = flux_utils.load_flow_model(name, args.pretrained_model_name_or_path, weight_dtype, "cpu")
|
||||
|
||||
@@ -104,7 +111,18 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
|
||||
return flux_lower
|
||||
|
||||
def get_tokenize_strategy(self, args):
|
||||
return strategy_flux.FluxTokenizeStrategy(args.max_token_length, args.tokenizer_cache_dir)
|
||||
name = self.get_flux_model_name(args)
|
||||
|
||||
if args.t5xxl_max_token_length is None:
|
||||
if name == "schnell":
|
||||
t5xxl_max_token_length = 256
|
||||
else:
|
||||
t5xxl_max_token_length = 512
|
||||
else:
|
||||
t5xxl_max_token_length = args.t5xxl_max_token_length
|
||||
|
||||
logger.info(f"t5xxl_max_token_length: {t5xxl_max_token_length}")
|
||||
return strategy_flux.FluxTokenizeStrategy(t5xxl_max_token_length, args.tokenizer_cache_dir)
|
||||
|
||||
def get_tokenizers(self, tokenize_strategy: strategy_flux.FluxTokenizeStrategy):
|
||||
return [tokenize_strategy.clip_l, tokenize_strategy.t5xxl]
|
||||
@@ -470,7 +488,13 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
help="[EXPERIMENTAL] use split mode for Flux model, network arg `train_blocks=single` is required"
|
||||
+ "/[実験的] Fluxモデルの分割モードを使用する。ネットワーク引数`train_blocks=single`が必要",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--t5xxl_max_token_length",
|
||||
type=int,
|
||||
default=None,
|
||||
help="maximum token length for T5-XXL. if omitted, 256 for schnell and 512 for dev"
|
||||
" / T5-XXLの最大トークン長。省略された場合、schnellの場合は256、devの場合は512",
|
||||
)
|
||||
# copy from Diffusers
|
||||
parser.add_argument(
|
||||
"--weighting_scheme",
|
||||
|
||||
@@ -863,6 +863,7 @@ class Flux(nn.Module):
|
||||
|
||||
self.time_in.enable_gradient_checkpointing()
|
||||
self.vector_in.enable_gradient_checkpointing()
|
||||
if self.guidance_in.__class__ != nn.Identity:
|
||||
self.guidance_in.enable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks + self.single_blocks:
|
||||
@@ -875,6 +876,7 @@ class Flux(nn.Module):
|
||||
|
||||
self.time_in.disable_gradient_checkpointing()
|
||||
self.vector_in.disable_gradient_checkpointing()
|
||||
if self.guidance_in.__class__ != nn.Identity:
|
||||
self.guidance_in.disable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks + self.single_blocks:
|
||||
@@ -972,6 +974,7 @@ class FluxUpper(nn.Module):
|
||||
|
||||
self.time_in.enable_gradient_checkpointing()
|
||||
self.vector_in.enable_gradient_checkpointing()
|
||||
if self.guidance_in.__class__ != nn.Identity:
|
||||
self.guidance_in.enable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks:
|
||||
@@ -984,6 +987,7 @@ class FluxUpper(nn.Module):
|
||||
|
||||
self.time_in.disable_gradient_checkpointing()
|
||||
self.vector_in.disable_gradient_checkpointing()
|
||||
if self.guidance_in.__class__ != nn.Identity:
|
||||
self.guidance_in.disable_gradient_checkpointing()
|
||||
|
||||
for block in self.double_blocks:
|
||||
|
||||
Reference in New Issue
Block a user