mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
add sdxl fine-tuning and LoRA
This commit is contained in:
@@ -798,6 +798,19 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def is_latent_cacheable(self):
|
||||
return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
|
||||
|
||||
def is_text_encoder_output_cacheable(self):
|
||||
return all(
|
||||
[
|
||||
not (
|
||||
subset.caption_dropout_rate > 0
|
||||
or subset.shuffle_caption
|
||||
or subset.token_warmup_step > 0
|
||||
or subset.caption_tag_dropout_rate > 0
|
||||
)
|
||||
for subset in self.subsets
|
||||
]
|
||||
)
|
||||
|
||||
def is_disk_cached_latents_is_expected(self, reso, npz_path, flipped_npz_path):
|
||||
expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意
|
||||
|
||||
@@ -850,7 +863,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
continue
|
||||
|
||||
cache_available = self.is_disk_cached_latents_is_expected(
|
||||
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if self.flip_aug else None
|
||||
info.bucket_reso, info.latents_npz, info.latents_npz_flipped if subset.flip_aug else None
|
||||
)
|
||||
|
||||
if cache_available: # do not add to batch
|
||||
@@ -1719,6 +1732,9 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
def is_latent_cacheable(self) -> bool:
|
||||
return all([dataset.is_latent_cacheable() for dataset in self.datasets])
|
||||
|
||||
def is_text_encoder_output_cacheable(self) -> bool:
|
||||
return all([dataset.is_text_encoder_output_cacheable() for dataset in self.datasets])
|
||||
|
||||
def set_current_epoch(self, epoch):
|
||||
for dataset in self.datasets:
|
||||
dataset.set_current_epoch(epoch)
|
||||
@@ -3284,11 +3300,17 @@ def _load_target_model(args: argparse.Namespace, weight_dtype, device="cpu", une
|
||||
return text_encoder, vae, unet, load_stable_diffusion_format
|
||||
|
||||
|
||||
# TODO remove this function in the future
|
||||
def transform_if_model_is_DDP(text_encoder, unet, network=None):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return (model.module if type(model) == DDP else model for model in [text_encoder, unet, network] if model is not None)
|
||||
|
||||
|
||||
def transform_models_if_DDP(models):
|
||||
# Transform text_encoder, unet and network from DistributedDataParallel
|
||||
return [model.module if type(model) == DDP else model for model in models if model is not None]
|
||||
|
||||
|
||||
def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projection_in_v2=False):
|
||||
# load models for each process
|
||||
for pi in range(accelerator.state.num_processes):
|
||||
@@ -3430,6 +3452,42 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
text_encoder,
|
||||
unet,
|
||||
vae,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args,
|
||||
on_epoch_end,
|
||||
accelerator,
|
||||
save_stable_diffusion_format,
|
||||
use_safetensors,
|
||||
epoch,
|
||||
num_train_epochs,
|
||||
global_step,
|
||||
sd_saver,
|
||||
diffusers_saver,
|
||||
)
|
||||
|
||||
|
||||
def save_sd_model_on_epoch_end_or_stepwise_common(
|
||||
args: argparse.Namespace,
|
||||
on_epoch_end: bool,
|
||||
accelerator,
|
||||
save_stable_diffusion_format: bool,
|
||||
use_safetensors: bool,
|
||||
epoch: int,
|
||||
num_train_epochs: int,
|
||||
global_step: int,
|
||||
sd_saver,
|
||||
diffusers_saver,
|
||||
):
|
||||
if on_epoch_end:
|
||||
epoch_no = epoch + 1
|
||||
@@ -3457,9 +3515,7 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
print(f"\nsaving checkpoint: {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
sd_saver(ckpt_file, epoch_no, global_step)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name)
|
||||
@@ -3483,9 +3539,8 @@ def save_sd_model_on_epoch_end_or_stepwise(
|
||||
out_dir = os.path.join(args.output_dir, STEP_DIFFUSERS_DIR_NAME.format(model_name, global_step))
|
||||
|
||||
print(f"\nsaving model: {out_dir}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
diffusers_saver(out_dir)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name)
|
||||
|
||||
@@ -3578,6 +3633,30 @@ def save_sd_model_on_train_end(
|
||||
text_encoder,
|
||||
unet,
|
||||
vae,
|
||||
):
|
||||
def sd_saver(ckpt_file, epoch_no, global_step):
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch_no, global_step, save_dtype, vae
|
||||
)
|
||||
|
||||
def diffusers_saver(out_dir):
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
|
||||
save_sd_model_on_train_end_common(
|
||||
args, save_stable_diffusion_format, use_safetensors, epoch, global_step, sd_saver, diffusers_saver
|
||||
)
|
||||
|
||||
|
||||
def save_sd_model_on_train_end_common(
|
||||
args: argparse.Namespace,
|
||||
save_stable_diffusion_format: bool,
|
||||
use_safetensors: bool,
|
||||
epoch: int,
|
||||
global_step: int,
|
||||
sd_saver,
|
||||
diffusers_saver,
|
||||
):
|
||||
model_name = default_if_none(args.output_name, DEFAULT_LAST_OUTPUT_NAME)
|
||||
|
||||
@@ -3588,9 +3667,8 @@ def save_sd_model_on_train_end(
|
||||
ckpt_file = os.path.join(args.output_dir, ckpt_name)
|
||||
|
||||
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
|
||||
model_util.save_stable_diffusion_checkpoint(
|
||||
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
|
||||
)
|
||||
sd_saver(ckpt_file, epoch, global_step)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, ckpt_file, "/" + ckpt_name, force_sync_upload=True)
|
||||
else:
|
||||
@@ -3598,9 +3676,8 @@ def save_sd_model_on_train_end(
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
print(f"save trained model as Diffusers to {out_dir}")
|
||||
model_util.save_diffusers_checkpoint(
|
||||
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
|
||||
)
|
||||
diffusers_saver(out_dir)
|
||||
|
||||
if args.huggingface_repo_id is not None:
|
||||
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user