mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
train run
This commit is contained in:
@@ -103,11 +103,11 @@ def train(args):
|
||||
|
||||
# データセットを準備する
|
||||
if args.dataset_class is None:
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||
if args.dataset_config is not None:
|
||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||
user_config = config_util.load_user_config(args.dataset_config)
|
||||
ignored = ["train_data_dir", "conditioing_data_dir"]
|
||||
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||
if any(getattr(args, attr) is not None for attr in ignored):
|
||||
logger.warning(
|
||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||
@@ -263,10 +263,11 @@ def train(args):
|
||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||
)
|
||||
flux.requires_grad_(False)
|
||||
flux.to(accelerator.device)
|
||||
|
||||
# load controlnet
|
||||
controlnet = flux_utils.load_controlnet()
|
||||
controlnet.requires_grad_(True)
|
||||
controlnet.train()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
|
||||
@@ -443,7 +444,8 @@ def train(args):
|
||||
|
||||
clean_memory_on_device(accelerator.device)
|
||||
|
||||
if args.deepspeed:
|
||||
# if args.deepspeed:
|
||||
if True:
|
||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
|
||||
# Most ZeRO stages use optimizer partitioning, so the optimizer and ds_model have to be prepared at the same time. # pull/1139#issuecomment-1986790007
|
||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
@@ -612,8 +614,10 @@ def train(args):
|
||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
|
||||
)
|
||||
if args.full_fp16:
|
||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
# if args.full_fp16:
|
||||
# text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
# TODO: check
|
||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||
|
||||
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
||||
|
||||
@@ -629,10 +633,10 @@ def train(args):
|
||||
# pack latents and get img_ids
|
||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
||||
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype)
|
||||
|
||||
# get guidance: ensure args.guidance_scale is float
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# call model
|
||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||
@@ -640,10 +644,11 @@ def train(args):
|
||||
t5_attn_mask = None
|
||||
|
||||
with accelerator.autocast():
|
||||
print("control start")
|
||||
block_samples, block_single_samples = controlnet(
|
||||
img=packed_noisy_model_input,
|
||||
img_ids=img_ids,
|
||||
controlnet_img=batch["conditioing_image"].to(accelerator.device),
|
||||
controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype),
|
||||
txt=t5_out,
|
||||
txt_ids=txt_ids,
|
||||
y=l_pooled,
|
||||
@@ -651,6 +656,8 @@ def train(args):
|
||||
guidance=guidance_vec,
|
||||
txt_attention_mask=t5_attn_mask,
|
||||
)
|
||||
print("control end")
|
||||
print("dit start")
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
model_pred = flux(
|
||||
img=packed_noisy_model_input,
|
||||
@@ -796,7 +803,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
|
||||
add_logging_arguments(parser)
|
||||
train_util.add_sd_models_arguments(parser) # TODO split this
|
||||
train_util.add_dataset_arguments(parser, True, True, True)
|
||||
train_util.add_dataset_arguments(parser, False, True, True)
|
||||
train_util.add_training_arguments(parser, False)
|
||||
train_util.add_masked_loss_arguments(parser)
|
||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||
@@ -852,12 +859,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
default=None,
|
||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conditioning_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--conditioning_data_dir",
|
||||
# type=str,
|
||||
# default=None,
|
||||
# help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
# )
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user