mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
train run
This commit is contained in:
@@ -103,11 +103,11 @@ def train(args):
|
|||||||
|
|
||||||
# データセットを準備する
|
# データセットを準備する
|
||||||
if args.dataset_class is None:
|
if args.dataset_class is None:
|
||||||
blueprint_generator = BlueprintGenerator(ConfigSanitizer(True, True, args.masked_loss, True))
|
blueprint_generator = BlueprintGenerator(ConfigSanitizer(False, False, True, True))
|
||||||
if args.dataset_config is not None:
|
if args.dataset_config is not None:
|
||||||
logger.info(f"Load dataset config from {args.dataset_config}")
|
logger.info(f"Load dataset config from {args.dataset_config}")
|
||||||
user_config = config_util.load_user_config(args.dataset_config)
|
user_config = config_util.load_user_config(args.dataset_config)
|
||||||
ignored = ["train_data_dir", "conditioing_data_dir"]
|
ignored = ["train_data_dir", "conditioning_data_dir"]
|
||||||
if any(getattr(args, attr) is not None for attr in ignored):
|
if any(getattr(args, attr) is not None for attr in ignored):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
"ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
|
||||||
@@ -263,10 +263,11 @@ def train(args):
|
|||||||
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
|
||||||
)
|
)
|
||||||
flux.requires_grad_(False)
|
flux.requires_grad_(False)
|
||||||
|
flux.to(accelerator.device)
|
||||||
|
|
||||||
# load controlnet
|
# load controlnet
|
||||||
controlnet = flux_utils.load_controlnet()
|
controlnet = flux_utils.load_controlnet()
|
||||||
controlnet.requires_grad_(True)
|
controlnet.train()
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
|
controlnet.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
|
||||||
@@ -443,7 +444,8 @@ def train(args):
|
|||||||
|
|
||||||
clean_memory_on_device(accelerator.device)
|
clean_memory_on_device(accelerator.device)
|
||||||
|
|
||||||
if args.deepspeed:
|
# if args.deepspeed:
|
||||||
|
if True:
|
||||||
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
|
ds_model = deepspeed_utils.prepare_deepspeed_model(args, mmdit=controlnet)
|
||||||
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
|
||||||
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
@@ -612,8 +614,10 @@ def train(args):
|
|||||||
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
text_encoder_conds = text_encoding_strategy.encode_tokens(
|
||||||
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
|
flux_tokenize_strategy, [clip_l, t5xxl], input_ids, args.apply_t5_attn_mask
|
||||||
)
|
)
|
||||||
if args.full_fp16:
|
# if args.full_fp16:
|
||||||
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
# text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||||
|
# TODO: check
|
||||||
|
text_encoder_conds = [c.to(weight_dtype) for c in text_encoder_conds]
|
||||||
|
|
||||||
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
# TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps
|
||||||
|
|
||||||
@@ -629,10 +633,10 @@ def train(args):
|
|||||||
# pack latents and get img_ids
|
# pack latents and get img_ids
|
||||||
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
packed_noisy_model_input = flux_utils.pack_latents(noisy_model_input) # b, c, h*2, w*2 -> b, h*w, c*4
|
||||||
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
packed_latent_height, packed_latent_width = noisy_model_input.shape[2] // 2, noisy_model_input.shape[3] // 2
|
||||||
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device)
|
img_ids = flux_utils.prepare_img_ids(bsz, packed_latent_height, packed_latent_width).to(device=accelerator.device).to(weight_dtype)
|
||||||
|
|
||||||
# get guidance: ensure args.guidance_scale is float
|
# get guidance: ensure args.guidance_scale is float
|
||||||
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)
|
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device, dtype=weight_dtype)
|
||||||
|
|
||||||
# call model
|
# call model
|
||||||
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
|
||||||
@@ -640,10 +644,11 @@ def train(args):
|
|||||||
t5_attn_mask = None
|
t5_attn_mask = None
|
||||||
|
|
||||||
with accelerator.autocast():
|
with accelerator.autocast():
|
||||||
|
print("control start")
|
||||||
block_samples, block_single_samples = controlnet(
|
block_samples, block_single_samples = controlnet(
|
||||||
img=packed_noisy_model_input,
|
img=packed_noisy_model_input,
|
||||||
img_ids=img_ids,
|
img_ids=img_ids,
|
||||||
controlnet_img=batch["conditioing_image"].to(accelerator.device),
|
controlnet_cond=batch["conditioning_images"].to(accelerator.device).to(weight_dtype),
|
||||||
txt=t5_out,
|
txt=t5_out,
|
||||||
txt_ids=txt_ids,
|
txt_ids=txt_ids,
|
||||||
y=l_pooled,
|
y=l_pooled,
|
||||||
@@ -651,6 +656,8 @@ def train(args):
|
|||||||
guidance=guidance_vec,
|
guidance=guidance_vec,
|
||||||
txt_attention_mask=t5_attn_mask,
|
txt_attention_mask=t5_attn_mask,
|
||||||
)
|
)
|
||||||
|
print("control end")
|
||||||
|
print("dit start")
|
||||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||||
model_pred = flux(
|
model_pred = flux(
|
||||||
img=packed_noisy_model_input,
|
img=packed_noisy_model_input,
|
||||||
@@ -796,7 +803,7 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
add_logging_arguments(parser)
|
add_logging_arguments(parser)
|
||||||
train_util.add_sd_models_arguments(parser) # TODO split this
|
train_util.add_sd_models_arguments(parser) # TODO split this
|
||||||
train_util.add_dataset_arguments(parser, True, True, True)
|
train_util.add_dataset_arguments(parser, False, True, True)
|
||||||
train_util.add_training_arguments(parser, False)
|
train_util.add_training_arguments(parser, False)
|
||||||
train_util.add_masked_loss_arguments(parser)
|
train_util.add_masked_loss_arguments(parser)
|
||||||
deepspeed_utils.add_deepspeed_arguments(parser)
|
deepspeed_utils.add_deepspeed_arguments(parser)
|
||||||
@@ -852,12 +859,12 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
default=None,
|
default=None,
|
||||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
# parser.add_argument(
|
||||||
"--conditioning_data_dir",
|
# "--conditioning_data_dir",
|
||||||
type=str,
|
# type=str,
|
||||||
default=None,
|
# default=None,
|
||||||
help="conditioning data directory / 条件付けデータのディレクトリ",
|
# help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||||
)
|
# )
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1042,20 +1042,20 @@ class Flux(nn.Module):
|
|||||||
if not self.blocks_to_swap:
|
if not self.blocks_to_swap:
|
||||||
for block_idx, block in enumerate(self.double_blocks):
|
for block_idx, block in enumerate(self.double_blocks):
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
if block_controlnet_hidden_states is not None:
|
if block_controlnet_hidden_states is not None and controlnet_depth > 0:
|
||||||
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
|
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
|
||||||
|
|
||||||
img = torch.cat((txt, img), 1)
|
img = torch.cat((txt, img), 1)
|
||||||
for block in self.single_blocks:
|
for block_idx, block in enumerate(self.single_blocks):
|
||||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
if block_controlnet_single_hidden_states is not None:
|
if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
|
||||||
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
|
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
|
||||||
else:
|
else:
|
||||||
for block_idx, block in enumerate(self.double_blocks):
|
for block_idx, block in enumerate(self.double_blocks):
|
||||||
self.offloader_double.wait_for_block(block_idx)
|
self.offloader_double.wait_for_block(block_idx)
|
||||||
|
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
if block_controlnet_hidden_states is not None:
|
if block_controlnet_hidden_states is not None and controlnet_depth > 0:
|
||||||
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
|
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
|
||||||
|
|
||||||
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
|
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
|
||||||
@@ -1066,7 +1066,7 @@ class Flux(nn.Module):
|
|||||||
self.offloader_single.wait_for_block(block_idx)
|
self.offloader_single.wait_for_block(block_idx)
|
||||||
|
|
||||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
if block_controlnet_single_hidden_states is not None:
|
if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
|
||||||
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
|
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
|
||||||
|
|
||||||
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
|
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
|
||||||
@@ -1121,14 +1121,14 @@ class ControlNetFlux(nn.Module):
|
|||||||
mlp_ratio=params.mlp_ratio,
|
mlp_ratio=params.mlp_ratio,
|
||||||
qkv_bias=params.qkv_bias,
|
qkv_bias=params.qkv_bias,
|
||||||
)
|
)
|
||||||
for _ in range(params.depth)
|
for _ in range(controlnet_depth)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.single_blocks = nn.ModuleList(
|
self.single_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||||
for _ in range(0) # TMP
|
for _ in range(0) # TODO
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1148,7 +1148,7 @@ class ControlNetFlux(nn.Module):
|
|||||||
controlnet_block = zero_module(controlnet_block)
|
controlnet_block = zero_module(controlnet_block)
|
||||||
self.controlnet_blocks_for_double.append(controlnet_block)
|
self.controlnet_blocks_for_double.append(controlnet_block)
|
||||||
self.controlnet_blocks_for_single = nn.ModuleList([])
|
self.controlnet_blocks_for_single = nn.ModuleList([])
|
||||||
for _ in range(controlnet_depth):
|
for _ in range(0): # TODO
|
||||||
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
||||||
controlnet_block = zero_module(controlnet_block)
|
controlnet_block = zero_module(controlnet_block)
|
||||||
self.controlnet_blocks_for_single.append(controlnet_block)
|
self.controlnet_blocks_for_single.append(controlnet_block)
|
||||||
@@ -1252,7 +1252,7 @@ class ControlNetFlux(nn.Module):
|
|||||||
self,
|
self,
|
||||||
img: Tensor,
|
img: Tensor,
|
||||||
img_ids: Tensor,
|
img_ids: Tensor,
|
||||||
controlnet_img: Tensor,
|
controlnet_cond: Tensor,
|
||||||
txt: Tensor,
|
txt: Tensor,
|
||||||
txt_ids: Tensor,
|
txt_ids: Tensor,
|
||||||
timesteps: Tensor,
|
timesteps: Tensor,
|
||||||
@@ -1265,10 +1265,10 @@ class ControlNetFlux(nn.Module):
|
|||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
img = self.img_in(img)
|
img = self.img_in(img)
|
||||||
controlnet_img = self.input_hint_block(controlnet_img)
|
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||||
controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||||
controlnet_img = self.pos_embed_input(controlnet_img)
|
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||||
img = img + controlnet_img
|
img = img + controlnet_cond
|
||||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||||
if self.params.guidance_embed:
|
if self.params.guidance_embed:
|
||||||
if guidance is None:
|
if guidance is None:
|
||||||
@@ -1283,7 +1283,7 @@ class ControlNetFlux(nn.Module):
|
|||||||
block_samples = ()
|
block_samples = ()
|
||||||
block_single_samples = ()
|
block_single_samples = ()
|
||||||
if not self.blocks_to_swap:
|
if not self.blocks_to_swap:
|
||||||
for block_idx, block in enumerate(self.double_blocks):
|
for block in self.double_blocks:
|
||||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||||
block_samples = block_samples + (img,)
|
block_samples = block_samples + (img,)
|
||||||
|
|
||||||
@@ -1315,7 +1315,7 @@ class ControlNetFlux(nn.Module):
|
|||||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double):
|
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double):
|
||||||
block_sample = controlnet_block(block_sample)
|
block_sample = controlnet_block(block_sample)
|
||||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single):
|
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single):
|
||||||
block_sample = controlnet_block(block_sample)
|
block_sample = controlnet_block(block_sample)
|
||||||
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)
|
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)
|
||||||
|
|
||||||
|
|||||||
@@ -460,7 +460,7 @@ def get_noisy_model_input_and_timesteps(
|
|||||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||||
|
|
||||||
return noisy_model_input, timesteps, sigmas
|
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||||
|
|
||||||
|
|
||||||
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ def load_controlnet():
|
|||||||
# TODO
|
# TODO
|
||||||
is_schnell = False
|
is_schnell = False
|
||||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||||
with torch.device("meta"):
|
with torch.device("cuda:0"):
|
||||||
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
|
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
|
||||||
# if transformer is not None:
|
# if transformer is not None:
|
||||||
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user