diff --git a/library/train_util.py b/library/train_util.py
index 51610e70..cae0df84 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3,6 +3,7 @@
 import argparse
 import ast
 import asyncio
+import datetime
 import importlib
 import json
 import pathlib
@@ -18,7 +19,7 @@ from typing import (
     Tuple,
     Union,
 )
-from accelerate import Accelerator
+from accelerate import Accelerator, InitProcessGroupKwargs
 import gc
 import glob
 import math
@@ -2855,6 +2856,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--ddp_timeout", type=int, default=30, help="DDP timeout (min) / DDPのタイムアウト(min)",
+    )
     parser.add_argument(
         "--clip_skip",
         type=int,
@@ -3786,6 +3790,7 @@ def prepare_accelerator(args: argparse.Namespace):
         mixed_precision=args.mixed_precision,
         log_with=log_with,
         project_dir=logging_dir,
+        kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
     )
     return accelerator

diff --git a/sdxl_train.py b/sdxl_train.py
index 55c11f9c..c368f27c 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -287,6 +287,8 @@ def train(args):
         training_models.append(text_encoder2)  # set require_grad=True later
     else:
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
         text_encoder1.requires_grad_(False)
         text_encoder2.requires_grad_(False)
         text_encoder1.eval()
         text_encoder2.eval()
@@ -295,7 +297,7 @@
         # TextEncoderの出力をキャッシュする
         if args.cache_text_encoder_outputs:
             # Text Encodes are eval and no grad
-            with torch.no_grad():
+            with torch.no_grad(), accelerator.autocast():
                 train_dataset_group.cache_text_encoder_outputs(
                     (tokenizer1, tokenizer2),
                     (text_encoder1, text_encoder2),
@@ -315,25 +317,23 @@
         m.requires_grad_(True)

     if block_lrs is None:
-        params = []
-        for m in training_models:
-            params.extend(m.parameters())
-        params_to_optimize = params
-
-        # calculate number of trainable parameters
-        n_params = 0
-        for p in params:
-            n_params += p.numel()
+        params_to_optimize = [
+            {"params": list(training_models[0].parameters()), "lr": args.learning_rate},
+        ]
     else:
         params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net
-        for m in training_models[1:]:  # Text Encoders if exists
-            params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
-        # calculate number of trainable parameters
-        n_params = 0
-        for params in params_to_optimize:
-            for p in params["params"]:
-                n_params += p.numel()
+    for m in training_models[1:]:  # Text Encoders if exists
+        params_to_optimize.append({
+            "params": list(m.parameters()),
+            "lr": args.learning_rate_te or args.learning_rate
+        })
+
+    # calculate number of trainable parameters
+    n_params = 0
+    for params in params_to_optimize:
+        for p in params["params"]:
+            n_params += p.numel()

     accelerator.print(f"number of models: {len(training_models)}")
     accelerator.print(f"number of trainable parameters: {n_params}")

@@ -396,8 +396,6 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
         (unet,) = train_util.transform_models_if_DDP([unet])
-    text_encoder1.to(weight_dtype)
-    text_encoder2.to(weight_dtype)

     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -728,6 +726,7 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
     sdxl_train_util.add_sdxl_training_arguments(parser)

+    parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")
     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

diff --git a/sdxl_train_network.py b/sdxl_train_network.py
index 2de57c0a..199c4e03 100644
--- a/sdxl_train_network.py
+++ b/sdxl_train_network.py
@@ -70,14 +70,16 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()

-            dataset.cache_text_encoder_outputs(
-                tokenizers,
-                text_encoders,
-                accelerator.device,
-                weight_dtype,
-                args.cache_text_encoder_outputs_to_disk,
-                accelerator.is_main_process,
-            )
+            # When the TE is not trained, it will not be prepared, so we need to use explicit autocast
+            with accelerator.autocast():
+                dataset.cache_text_encoder_outputs(
+                    tokenizers,
+                    text_encoders,
+                    accelerator.device,
+                    weight_dtype,
+                    args.cache_text_encoder_outputs_to_disk,
+                    accelerator.is_main_process,
+                )

             text_encoders[0].to("cpu", dtype=torch.float32)  # Text Encoder doesn't work with fp16 on CPU
             text_encoders[1].to("cpu", dtype=torch.float32)

diff --git a/train_network.py b/train_network.py
index 9deb5331..38934c74 100644
--- a/train_network.py
+++ b/train_network.py
@@ -109,6 +109,9 @@ class NetworkTrainer:
     def is_text_encoder_outputs_cached(self, args):
         return False

+    def is_train_text_encoder(self, args):
+        return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+
     def cache_text_encoder_outputs_if_needed(
         self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
     ):
@@ -310,7 +313,7 @@
             args.scale_weight_norms = False

         train_unet = not args.network_train_text_encoder_only
-        train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+        train_text_encoder = self.is_train_text_encoder(args)
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

         if args.network_weights is not None:
@@ -403,6 +406,8 @@
             unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 unet, network, optimizer, train_dataloader, lr_scheduler
             )
+            for t_enc in text_encoders:
+                t_enc.to(accelerator.device, dtype=weight_dtype)
         elif train_text_encoder:
             if len(text_encoders) > 1:
                 t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -767,7 +772,7 @@
                     latents = latents * self.vae_scale_factor
                 b_size = latents.shape[0]

-                with torch.set_grad_enabled(train_text_encoder):
+                with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                     # Get the text embedding for conditioning
                     if args.weighted_captions:
                         text_encoder_conds = get_weighted_text_embeddings(
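Note on the --ddp_timeout change in library/train_util.py: InitProcessGroupKwargs forwards the timeout to torch.distributed's process-group initialization, making it configurable instead of relying on the default (30 minutes for NCCL), so long single-rank work such as caching latents or text-encoder outputs can be given more headroom. A minimal, self-contained sketch of the same wiring, with illustrative values and no dependence on the rest of sd-scripts:

import argparse
import datetime

from accelerate import Accelerator, InitProcessGroupKwargs

parser = argparse.ArgumentParser()
# Mirrors the new train_util.py argument: the timeout is given in minutes.
parser.add_argument("--ddp_timeout", type=int, default=30)
args = parser.parse_args()

# The kwargs handler passes the timeout through to torch.distributed.init_process_group.
accelerator = Accelerator(
    kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))]
)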
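The `args.learning_rate_te or args.learning_rate` expression in sdxl_train.py works because --learning_rate_te defaults to 0.0, which is falsy in Python, so the text encoders fall back to the U-Net rate unless a rate is given explicitly. A small sketch of that parameter-group pattern, with toy torch.nn modules standing in for the U-Net and a text encoder:

import torch

unet = torch.nn.Linear(8, 8)          # stand-in for the U-Net
text_encoder = torch.nn.Linear(8, 8)  # stand-in for a text encoder

learning_rate = 1e-4
learning_rate_te = 0.0  # default: 0.0 means "use the U-Net rate"

params_to_optimize = [
    {"params": list(unet.parameters()), "lr": learning_rate},
    # `or` falls back to the U-Net rate because 0.0 is falsy
    {"params": list(text_encoder.parameters()), "lr": learning_rate_te or learning_rate},
]
optimizer = torch.optim.AdamW(params_to_optimize)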
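The explicit accelerator.autocast() contexts added in sdxl_train_network.py and train_network.py address the case the new comment names: only models passed through accelerator.prepare() get automatic mixed-precision handling, so a text encoder kept out of prepare() (frozen, or with its outputs cached) would otherwise run in full precision. A rough sketch of the effect, assuming a CUDA device and a toy module standing in for the frozen text encoder:

import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")  # fp16 mixed precision requires a GPU

text_encoder = torch.nn.Linear(8, 8).to(accelerator.device)  # never passed to prepare()
tokens = torch.randn(2, 8, device=accelerator.device)

with torch.no_grad():
    fp32_out = text_encoder(tokens)  # stays fp32 despite mixed_precision="fp16"

with torch.no_grad(), accelerator.autocast():
    fp16_out = text_encoder(tokens)  # autocast applies, matching prepared models

print(fp32_out.dtype, fp16_out.dtype)  # torch.float32 torch.float16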