mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 14:34:23 +00:00
fix: use strategy for tokenizer and latent caching
This commit is contained in:
@@ -12,7 +12,7 @@ import toml
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from library import deepspeed_utils
|
from library import deepspeed_utils, strategy_base, strategy_sd
|
||||||
from library.device_utils import init_ipex, clean_memory_on_device
|
from library.device_utils import init_ipex, clean_memory_on_device
|
||||||
|
|
||||||
init_ipex()
|
init_ipex()
|
||||||
@@ -73,7 +73,14 @@ def train(args):
|
|||||||
args.seed = random.randint(0, 2**32)
|
args.seed = random.randint(0, 2**32)
|
||||||
set_seed(args.seed)
|
set_seed(args.seed)
|
||||||
|
|
||||||
tokenizer = train_util.load_tokenizer(args)
|
tokenize_strategy = strategy_sd.SdTokenizeStrategy(args.v2, args.max_token_length, args.tokenizer_cache_dir)
|
||||||
|
strategy_base.TokenizeStrategy.set_strategy(tokenize_strategy)
|
||||||
|
tokenizer = tokenize_strategy.tokenizer
|
||||||
|
# prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
|
||||||
|
latents_caching_strategy = strategy_sd.SdSdxlLatentsCachingStrategy(
|
||||||
|
True, args.cache_latents_to_disk, args.vae_batch_size, False
|
||||||
|
)
|
||||||
|
strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)
|
||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||||
@@ -100,7 +107,7 @@ def train(args):
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
@@ -243,12 +250,7 @@ def train(args):
|
|||||||
vae.requires_grad_(False)
|
vae.requires_grad_(False)
|
||||||
vae.eval()
|
vae.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
train_dataset_group.cache_latents(
|
train_dataset_group.new_cache_latents(vae, accelerator)
|
||||||
vae,
|
|
||||||
args.vae_batch_size,
|
|
||||||
args.cache_latents_to_disk,
|
|
||||||
accelerator.is_main_process,
|
|
||||||
)
|
|
||||||
vae.to("cpu")
|
vae.to("cpu")
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
@@ -267,6 +269,7 @@ def train(args):
|
|||||||
|
|
||||||
# dataloaderを準備する
|
# dataloaderを準備する
|
||||||
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
# DataLoaderのプロセス数:0 は persistent_workers が使えないので注意
|
||||||
|
train_dataset_group.set_current_strategies()
|
||||||
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers
|
||||||
|
|
||||||
train_dataloader = torch.utils.data.DataLoader(
|
train_dataloader = torch.utils.data.DataLoader(
|
||||||
@@ -451,7 +454,7 @@ def train(args):
|
|||||||
latents = latents * 0.18215
|
latents = latents * 0.18215
|
||||||
b_size = latents.shape[0]
|
b_size = latents.shape[0]
|
||||||
|
|
||||||
input_ids = batch["input_ids"].to(accelerator.device)
|
input_ids = batch["input_ids_list"][0].to(accelerator.device)
|
||||||
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
encoder_hidden_states = train_util.get_hidden_states(args, input_ids, tokenizer, text_encoder, weight_dtype)
|
||||||
|
|
||||||
# Sample noise that we'll add to the latents
|
# Sample noise that we'll add to the latents
|
||||||
|
|||||||
Reference in New Issue
Block a user