mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Merge branch 'dev' into gradual_latent_hires_fix
This commit is contained in:
@@ -1722,6 +1722,7 @@ class ControlNetDataset(BaseDataset):
|
|||||||
subset.caption_extension,
|
subset.caption_extension,
|
||||||
subset.num_repeats,
|
subset.num_repeats,
|
||||||
subset.shuffle_caption,
|
subset.shuffle_caption,
|
||||||
|
subset.caption_separator,
|
||||||
subset.keep_tokens,
|
subset.keep_tokens,
|
||||||
subset.color_aug,
|
subset.color_aug,
|
||||||
subset.flip_aug,
|
subset.flip_aug,
|
||||||
@@ -2979,9 +2980,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
|
"--sample_every_n_steps", type=int, default=None, help="generate sample images every N steps / 学習中のモデルで指定ステップごとにサンプル出力する"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--sample_at_first", action="store_true", help="generate sample images before training / 学習前にサンプル出力する")
|
||||||
"--sample_at_first", action='store_true', help="generate sample images before training / 学習前にサンプル出力する"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sample_every_n_epochs",
|
"--sample_every_n_epochs",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -3115,12 +3114,8 @@ def add_dataset_arguments(
|
|||||||
):
|
):
|
||||||
# dataset common
|
# dataset common
|
||||||
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
|
||||||
parser.add_argument(
|
parser.add_argument("--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする")
|
||||||
"--shuffle_caption", action="store_true", help="shuffle separated caption / 区切られたcaptionの各要素をshuffleする"
|
parser.add_argument("--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字")
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--caption_separator", type=str, default=",", help="separator for caption / captionの区切り文字"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
|
"--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子"
|
||||||
)
|
)
|
||||||
@@ -4048,6 +4043,7 @@ def get_hidden_states_sdxl(
|
|||||||
text_encoder1: CLIPTextModel,
|
text_encoder1: CLIPTextModel,
|
||||||
text_encoder2: CLIPTextModelWithProjection,
|
text_encoder2: CLIPTextModelWithProjection,
|
||||||
weight_dtype: Optional[str] = None,
|
weight_dtype: Optional[str] = None,
|
||||||
|
accelerator: Optional[Accelerator] = None,
|
||||||
):
|
):
|
||||||
# input_ids: b,n,77 -> b*n, 77
|
# input_ids: b,n,77 -> b*n, 77
|
||||||
b_size = input_ids1.size()[0]
|
b_size = input_ids1.size()[0]
|
||||||
@@ -4063,7 +4059,8 @@ def get_hidden_states_sdxl(
|
|||||||
hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer
|
hidden_states2 = enc_out["hidden_states"][-2] # penultimate layer
|
||||||
|
|
||||||
# pool2 = enc_out["text_embeds"]
|
# pool2 = enc_out["text_embeds"]
|
||||||
pool2 = pool_workaround(text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
unwrapped_text_encoder2 = text_encoder2 if accelerator is None else accelerator.unwrap_model(text_encoder2)
|
||||||
|
pool2 = pool_workaround(unwrapped_text_encoder2, enc_out["last_hidden_state"], input_ids2, tokenizer2.eos_token_id)
|
||||||
|
|
||||||
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
# b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
|
||||||
n_size = 1 if max_token_length is None else max_token_length // 75
|
n_size = 1 if max_token_length is None else max_token_length // 75
|
||||||
@@ -4451,6 +4448,7 @@ SCHEDULER_LINEAR_END = 0.0120
|
|||||||
SCHEDULER_TIMESTEPS = 1000
|
SCHEDULER_TIMESTEPS = 1000
|
||||||
SCHEDLER_SCHEDULE = "scaled_linear"
|
SCHEDLER_SCHEDULE = "scaled_linear"
|
||||||
|
|
||||||
|
|
||||||
def get_my_scheduler(
|
def get_my_scheduler(
|
||||||
*,
|
*,
|
||||||
sample_sampler: str,
|
sample_sampler: str,
|
||||||
@@ -4495,10 +4493,7 @@ def get_my_scheduler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# clip_sample=Trueにする
|
# clip_sample=Trueにする
|
||||||
if (
|
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is False:
|
||||||
hasattr(scheduler.config, "clip_sample")
|
|
||||||
and scheduler.config.clip_sample is False
|
|
||||||
):
|
|
||||||
# print("set clip_sample to True")
|
# print("set clip_sample to True")
|
||||||
scheduler.config.clip_sample = True
|
scheduler.config.clip_sample = True
|
||||||
|
|
||||||
@@ -4513,48 +4508,48 @@ def line_to_prompt_dict(line: str) -> dict:
|
|||||||
# subset of gen_img_diffusers
|
# subset of gen_img_diffusers
|
||||||
prompt_args = line.split(" --")
|
prompt_args = line.split(" --")
|
||||||
prompt_dict = {}
|
prompt_dict = {}
|
||||||
prompt_dict['prompt'] = prompt_args[0]
|
prompt_dict["prompt"] = prompt_args[0]
|
||||||
|
|
||||||
for parg in prompt_args:
|
for parg in prompt_args:
|
||||||
try:
|
try:
|
||||||
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
prompt_dict['width'] = int(m.group(1))
|
prompt_dict["width"] = int(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"h (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
prompt_dict['height'] = int(m.group(1))
|
prompt_dict["height"] = int(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"d (\d+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
prompt_dict['seed'] = int(m.group(1))
|
prompt_dict["seed"] = int(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
m = re.match(r"s (\d+)", parg, re.IGNORECASE)
|
||||||
if m: # steps
|
if m: # steps
|
||||||
prompt_dict['sample_steps'] = max(1, min(1000, int(m.group(1))))
|
prompt_dict["sample_steps"] = max(1, min(1000, int(m.group(1))))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
m = re.match(r"l ([\d\.]+)", parg, re.IGNORECASE)
|
||||||
if m: # scale
|
if m: # scale
|
||||||
prompt_dict['scale'] = float(m.group(1))
|
prompt_dict["scale"] = float(m.group(1))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
m = re.match(r"n (.+)", parg, re.IGNORECASE)
|
||||||
if m: # negative prompt
|
if m: # negative prompt
|
||||||
prompt_dict['negative_prompt'] = m.group(1)
|
prompt_dict["negative_prompt"] = m.group(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
|
m = re.match(r"ss (.+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
prompt_dict['sample_sampler'] = m.group(1)
|
prompt_dict["sample_sampler"] = m.group(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
m = re.match(r"cn (.+)", parg, re.IGNORECASE)
|
||||||
if m:
|
if m:
|
||||||
prompt_dict['controlnet_image'] = m.group(1)
|
prompt_dict["controlnet_image"] = m.group(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
except ValueError as ex:
|
except ValueError as ex:
|
||||||
@@ -4563,6 +4558,7 @@ def line_to_prompt_dict(line: str) -> dict:
|
|||||||
|
|
||||||
return prompt_dict
|
return prompt_dict
|
||||||
|
|
||||||
|
|
||||||
def sample_images_common(
|
def sample_images_common(
|
||||||
pipe_class,
|
pipe_class,
|
||||||
accelerator,
|
accelerator,
|
||||||
@@ -4663,7 +4659,7 @@ def sample_images_common(
|
|||||||
seed = prompt_dict.get("seed")
|
seed = prompt_dict.get("seed")
|
||||||
controlnet_image = prompt_dict.get("controlnet_image")
|
controlnet_image = prompt_dict.get("controlnet_image")
|
||||||
prompt: str = prompt_dict.get("prompt", "")
|
prompt: str = prompt_dict.get("prompt", "")
|
||||||
sampler_name:str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
@@ -4671,7 +4667,10 @@ def sample_images_common(
|
|||||||
|
|
||||||
scheduler = schedulers.get(sampler_name)
|
scheduler = schedulers.get(sampler_name)
|
||||||
if scheduler is None:
|
if scheduler is None:
|
||||||
scheduler = get_my_scheduler(sample_sampler=sampler_name, v_parameterization=args.v_parameterization,)
|
scheduler = get_my_scheduler(
|
||||||
|
sample_sampler=sampler_name,
|
||||||
|
v_parameterization=args.v_parameterization,
|
||||||
|
)
|
||||||
schedulers[sampler_name] = scheduler
|
schedulers[sampler_name] = scheduler
|
||||||
pipeline.scheduler = scheduler
|
pipeline.scheduler = scheduler
|
||||||
|
|
||||||
|
|||||||
@@ -505,6 +505,7 @@ def train(args):
|
|||||||
# else:
|
# else:
|
||||||
input_ids1 = input_ids1.to(accelerator.device)
|
input_ids1 = input_ids1.to(accelerator.device)
|
||||||
input_ids2 = input_ids2.to(accelerator.device)
|
input_ids2 = input_ids2.to(accelerator.device)
|
||||||
|
# unwrap_model is fine for models not wrapped by accelerator
|
||||||
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
encoder_hidden_states1, encoder_hidden_states2, pool2 = train_util.get_hidden_states_sdxl(
|
||||||
args.max_token_length,
|
args.max_token_length,
|
||||||
input_ids1,
|
input_ids1,
|
||||||
@@ -514,6 +515,7 @@ def train(args):
|
|||||||
text_encoder1,
|
text_encoder1,
|
||||||
text_encoder2,
|
text_encoder2,
|
||||||
None if not args.full_fp16 else weight_dtype,
|
None if not args.full_fp16 else weight_dtype,
|
||||||
|
accelerator=accelerator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
if torch.xpu.is_available():
|
if torch.xpu.is_available():
|
||||||
from library.ipex import ipex_init
|
from library.ipex import ipex_init
|
||||||
|
|
||||||
ipex_init()
|
ipex_init()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
@@ -123,6 +126,7 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
|
|||||||
text_encoders[0],
|
text_encoders[0],
|
||||||
text_encoders[1],
|
text_encoders[1],
|
||||||
None if not args.full_fp16 else weight_dtype,
|
None if not args.full_fp16 else weight_dtype,
|
||||||
|
accelerator=accelerator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
encoder_hidden_states1 = batch["text_encoder_outputs1_list"].to(accelerator.device).to(weight_dtype)
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ class SdxlTextualInversionTrainer(train_textual_inversion.TextualInversionTraine
|
|||||||
text_encoders[0],
|
text_encoders[0],
|
||||||
text_encoders[1],
|
text_encoders[1],
|
||||||
None if not args.full_fp16 else weight_dtype,
|
None if not args.full_fp16 else weight_dtype,
|
||||||
|
accelerator=accelerator,
|
||||||
)
|
)
|
||||||
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
return encoder_hidden_states1, encoder_hidden_states2, pool2
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user