Mirror of https://github.com/kohya-ss/sd-scripts.git
Better implementation for te autocast (#895)
* Better implementation for te
* Fix some misunderstanding
* as same as unet, add explicit convert
* Better cache TE and TE lr
* Fix with list
* Add timeout settings
* Fix arg style
@@ -3,6 +3,7 @@
 import argparse
 import ast
 import asyncio
+import datetime
 import importlib
 import json
 import pathlib
@@ -18,7 +19,7 @@ from typing import (
     Tuple,
     Union,
 )
-from accelerate import Accelerator
+from accelerate import Accelerator, InitProcessGroupKwargs
 import gc
 import glob
 import math
@@ -2855,6 +2856,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--ddp_timeout", type=int, default=30, help="DDP timeout (min) / DDPのタイムアウト(min)",
+    )
     parser.add_argument(
         "--clip_skip",
         type=int,
@@ -3786,6 +3790,7 @@ def prepare_accelerator(args: argparse.Namespace):
         mixed_precision=args.mixed_precision,
         log_with=log_with,
         project_dir=logging_dir,
+        kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
     )
     return accelerator

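For context, a minimal standalone sketch (not the repository's code; everything except the --ddp_timeout flag shown in the diff is illustrative) of how the new option reaches Accelerate: InitProcessGroupKwargs raises the process-group timeout so that long single-rank work such as caching does not trip the default DDP watchdog.

# Standalone sketch of wiring a DDP timeout into Accelerate.
# Only --ddp_timeout comes from the diff; the rest is illustrative.
import argparse
import datetime

from accelerate import Accelerator, InitProcessGroupKwargs

parser = argparse.ArgumentParser()
parser.add_argument("--ddp_timeout", type=int, default=30, help="DDP timeout (minutes)")
args = parser.parse_args()

accelerator = Accelerator(
    # Forwarded to torch.distributed.init_process_group(timeout=...), so a slow
    # step on one rank (e.g. caching) does not abort the other ranks.
    kwargs_handlers=[InitProcessGroupKwargs(timeout=datetime.timedelta(minutes=args.ddp_timeout))],
)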
@@ -287,6 +287,8 @@ def train(args):
         training_models.append(text_encoder2)
         # set require_grad=True later
     else:
+        text_encoder1.to(weight_dtype)
+        text_encoder2.to(weight_dtype)
         text_encoder1.requires_grad_(False)
         text_encoder2.requires_grad_(False)
         text_encoder1.eval()
@@ -295,7 +297,7 @@ def train(args):
     # TextEncoderの出力をキャッシュする
     if args.cache_text_encoder_outputs:
         # Text Encodes are eval and no grad
-        with torch.no_grad():
+        with torch.no_grad(), accelerator.autocast():
             train_dataset_group.cache_text_encoder_outputs(
                 (tokenizer1, tokenizer2),
                 (text_encoder1, text_encoder2),
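A minimal sketch of the caching pattern these two hunks set up (the stand-in module and tensor shapes are assumptions, and it presumes a CUDA device with mixed precision enabled): the frozen text encoders are cast to weight_dtype and kept in eval mode, and their forward pass runs under both no_grad and the accelerator's autocast so the cached embeddings match the dtype seen during training.

# Sketch only: a stand-in "text encoder" cached under no_grad + autocast.
# Assumes a CUDA device and mixed_precision="fp16".
import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")
weight_dtype = torch.float32

text_encoder = torch.nn.Linear(768, 1280).to(accelerator.device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.eval()

tokens = torch.randn(4, 768, device=accelerator.device)
with torch.no_grad(), accelerator.autocast():
    cached_outputs = text_encoder(tokens)  # fp16 activations, no autograd graph built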
@@ -315,19 +317,17 @@ def train(args):
         m.requires_grad_(True)

     if block_lrs is None:
-        params = []
-        for m in training_models:
-            params.extend(m.parameters())
-        params_to_optimize = params
-
-        # calculate number of trainable parameters
-        n_params = 0
-        for p in params:
-            n_params += p.numel()
+        params_to_optimize = [
+            {"params": list(training_models[0].parameters()), "lr": args.learning_rate},
+        ]
     else:
         params_to_optimize = get_block_params_to_optimize(training_models[0], block_lrs)  # U-Net

-        for m in training_models[1:]:  # Text Encoders if exists
-            params_to_optimize.append({"params": m.parameters(), "lr": args.learning_rate})
+    for m in training_models[1:]:  # Text Encoders if exists
+        params_to_optimize.append({
+            "params": list(m.parameters()),
+            "lr": args.learning_rate_te or args.learning_rate
+        })
+
     # calculate number of trainable parameters
     n_params = 0
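The effect of this hunk is to give the text encoders their own optimizer param group with an optional separate learning rate. A self-contained sketch of the pattern (toy modules and values, not the script's real objects), including the falsy-0.0 fallback in `args.learning_rate_te or args.learning_rate`:

# Sketch of per-group learning rates: each dict gets its own lr; a learning_rate_te
# of 0.0 is falsy, so `learning_rate_te or learning_rate` falls back to the base lr.
import torch

unet = torch.nn.Linear(8, 8)           # stand-in for the U-Net
text_encoder = torch.nn.Linear(8, 8)   # stand-in for a text encoder
learning_rate = 1e-4
learning_rate_te = 0.0                 # default: inherit the base lr

params_to_optimize = [
    {"params": list(unet.parameters()), "lr": learning_rate},
    {"params": list(text_encoder.parameters()), "lr": learning_rate_te or learning_rate},
]
optimizer = torch.optim.AdamW(params_to_optimize)
print([group["lr"] for group in optimizer.param_groups])  # [0.0001, 0.0001]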
@@ -396,8 +396,6 @@ def train(args):
     else:
         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
         (unet,) = train_util.transform_models_if_DDP([unet])
-    text_encoder1.to(weight_dtype)
-    text_encoder2.to(weight_dtype)

     # TextEncoderの出力をキャッシュするときにはCPUへ移動する
     if args.cache_text_encoder_outputs:
@@ -728,6 +726,7 @@ def setup_parser() -> argparse.ArgumentParser:
     config_util.add_config_arguments(parser)
     custom_train_functions.add_custom_train_arguments(parser)
     sdxl_train_util.add_sdxl_training_arguments(parser)
+    parser.add_argument("--learning_rate_te", type=float, default=0.0, help="learning rate for text encoder")

     parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
     parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")

@@ -70,6 +70,8 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()

-            dataset.cache_text_encoder_outputs(
-                tokenizers,
-                text_encoders,
+            # When TE is not be trained, it will not be prepared so we need to use explicit autocast
+            with accelerator.autocast():
+                dataset.cache_text_encoder_outputs(
+                    tokenizers,
+                    text_encoders,
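Why the explicit context matters, as a sketch (assumed shapes, assumed CUDA device, not the trainer's code): with native mixed precision, Accelerate wraps the forward of models passed through accelerator.prepare() in an autocast context, but a frozen text encoder that is never prepared gets no such wrapper, so its forward must be wrapped by hand; accelerator.autocast() resolves to an autocast context like the one below based on the configured precision.

# Sketch: a module that skipped accelerator.prepare() needs an explicit autocast
# context to run its forward in the mixed-precision dtype. Assumes a CUDA device.
import torch

frozen_text_encoder = torch.nn.Linear(16, 16).cuda()  # never passed to prepare()
tokens = torch.randn(2, 16, device="cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    out = frozen_text_encoder(tokens)

print(out.dtype)  # torch.float16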
@@ -109,6 +109,9 @@ class NetworkTrainer:
     def is_text_encoder_outputs_cached(self, args):
         return False

+    def is_train_text_encoder(self, args):
+        return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+
     def cache_text_encoder_outputs_if_needed(
         self, args, accelerator, unet, vae, tokenizers, text_encoders, data_loader, weight_dtype
     ):
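The new hook centralizes the "should the text encoder be trained" decision so trainer subclasses can influence it instead of repeating the flag logic. A simplified sketch of that design (class names are illustrative, and the override shown is only a plausible one, not quoted from this diff):

# Simplified sketch of the hook: a subclass that caches text encoder outputs
# can report that fact, which automatically disables text encoder training.
class BaseTrainerSketch:
    def is_text_encoder_outputs_cached(self, args):
        return False

    def is_train_text_encoder(self, args):
        return not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)


class SdxlTrainerSketch(BaseTrainerSketch):
    def is_text_encoder_outputs_cached(self, args):
        # hypothetical override: cached TE outputs imply the TE stays frozen
        return args.cache_text_encoder_outputs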
@@ -310,7 +313,7 @@ class NetworkTrainer:
             args.scale_weight_norms = False

         train_unet = not args.network_train_text_encoder_only
-        train_text_encoder = not args.network_train_unet_only and not self.is_text_encoder_outputs_cached(args)
+        train_text_encoder = self.is_train_text_encoder(args)
         network.apply_to(text_encoder, unet, train_text_encoder, train_unet)

         if args.network_weights is not None:
@@ -403,6 +406,8 @@ class NetworkTrainer:
             unet, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
                 unet, network, optimizer, train_dataloader, lr_scheduler
             )
+            for t_enc in text_encoders:
+                t_enc.to(accelerator.device, dtype=weight_dtype)
         elif train_text_encoder:
             if len(text_encoders) > 1:
                 t_enc1, t_enc2, network, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
@@ -767,7 +772,7 @@ class NetworkTrainer:
                 latents = latents * self.vae_scale_factor
                 b_size = latents.shape[0]

-                with torch.set_grad_enabled(train_text_encoder):
+                with torch.set_grad_enabled(train_text_encoder), accelerator.autocast():
                     # Get the text embedding for conditioning
                     if args.weighted_captions:
                         text_encoder_conds = get_weighted_text_embeddings(
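A small sketch of the combined context managers (toy encoder; CPU bfloat16 autocast is used here only for portability, whereas the real code uses accelerator.autocast()): set_grad_enabled decides whether an autograd graph is built for the text encoder pass, and autocast keeps the resulting conditioning tensors in the mixed-precision dtype either way.

# Sketch: conditional grad for the text encoder pass, plus autocast for dtype.
import torch

text_encoder = torch.nn.Linear(32, 32)
train_text_encoder = False  # e.g. only the U-Net LoRA modules are being trained
tokens = torch.randn(2, 32)

with torch.set_grad_enabled(train_text_encoder), torch.autocast("cpu", dtype=torch.bfloat16):
    conds = text_encoder(tokens)

print(conds.requires_grad, conds.dtype)  # False torch.bfloat16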