mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Make ControlNet (cn) loading and validation work
This commit is contained in:
@@ -266,7 +266,7 @@ def train(args):
|
|||||||
flux.to(accelerator.device)
|
flux.to(accelerator.device)
|
||||||
|
|
||||||
# load controlnet
|
# load controlnet
|
||||||
controlnet = flux_utils.load_controlnet()
|
controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||||
controlnet.train()
|
controlnet.train()
|
||||||
|
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
@@ -568,7 +568,7 @@ def train(args):
|
|||||||
|
|
||||||
# For --sample_at_first
|
# For --sample_at_first
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
|
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet)
|
||||||
optimizer_train_fn()
|
optimizer_train_fn()
|
||||||
if len(accelerator.trackers) > 0:
|
if len(accelerator.trackers) > 0:
|
||||||
# log empty object to commit the sample images to wandb
|
# log empty object to commit the sample images to wandb
|
||||||
@@ -718,7 +718,7 @@ def train(args):
|
|||||||
|
|
||||||
optimizer_eval_fn()
|
optimizer_eval_fn()
|
||||||
flux_train_utils.sample_images(
|
flux_train_utils.sample_images(
|
||||||
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
|
||||||
)
|
)
|
||||||
|
|
||||||
# 指定ステップごとにモデルを保存
|
# 指定ステップごとにモデルを保存
|
||||||
@@ -774,7 +774,7 @@ def train(args):
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
flux_train_utils.sample_images(
|
flux_train_utils.sample_images(
|
||||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
|
||||||
)
|
)
|
||||||
optimizer_train_fn()
|
optimizer_train_fn()
|
||||||
|
|
||||||
@@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--controlnet_model_name_or_path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
|
||||||
)
|
|
||||||
# parser.add_argument(
|
|
||||||
# "--conditioning_data_dir",
|
|
||||||
# type=str,
|
|
||||||
# default=None,
|
|
||||||
# help="conditioning data directory / 条件付けデータのディレクトリ",
|
|
||||||
# )
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1142,11 +1142,11 @@ class ControlNetFlux(nn.Module):
|
|||||||
self.num_single_blocks = len(self.single_blocks)
|
self.num_single_blocks = len(self.single_blocks)
|
||||||
|
|
||||||
# add ControlNet blocks
|
# add ControlNet blocks
|
||||||
self.controlnet_blocks_for_double = nn.ModuleList([])
|
self.controlnet_blocks = nn.ModuleList([])
|
||||||
for _ in range(controlnet_depth):
|
for _ in range(controlnet_depth):
|
||||||
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
||||||
controlnet_block = zero_module(controlnet_block)
|
controlnet_block = zero_module(controlnet_block)
|
||||||
self.controlnet_blocks_for_double.append(controlnet_block)
|
self.controlnet_blocks.append(controlnet_block)
|
||||||
self.controlnet_blocks_for_single = nn.ModuleList([])
|
self.controlnet_blocks_for_single = nn.ModuleList([])
|
||||||
for _ in range(0): # TODO
|
for _ in range(0): # TODO
|
||||||
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
||||||
@@ -1312,7 +1312,7 @@ class ControlNetFlux(nn.Module):
|
|||||||
|
|
||||||
controlnet_block_samples = ()
|
controlnet_block_samples = ()
|
||||||
controlnet_single_block_samples = ()
|
controlnet_single_block_samples = ()
|
||||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double):
|
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
|
||||||
block_sample = controlnet_block(block_sample)
|
block_sample = controlnet_block(block_sample)
|
||||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single):
|
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single):
|
||||||
|
|||||||
@@ -175,10 +175,6 @@ def sample_image_inference(
|
|||||||
|
|
||||||
# if negative_prompt is None:
|
# if negative_prompt is None:
|
||||||
# negative_prompt = ""
|
# negative_prompt = ""
|
||||||
if controlnet_image is not None:
|
|
||||||
controlnet_image = Image.open(controlnet_image).convert("RGB")
|
|
||||||
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
|
|
||||||
|
|
||||||
height = max(64, height - height % 16) # round to divisible by 16
|
height = max(64, height - height % 16) # round to divisible by 16
|
||||||
width = max(64, width - width % 16) # round to divisible by 16
|
width = max(64, width - width % 16) # round to divisible by 16
|
||||||
logger.info(f"prompt: {prompt}")
|
logger.info(f"prompt: {prompt}")
|
||||||
@@ -232,6 +228,12 @@ def sample_image_inference(
|
|||||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||||
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
||||||
|
|
||||||
|
if controlnet_image is not None:
|
||||||
|
controlnet_image = Image.open(controlnet_image).convert("RGB")
|
||||||
|
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
|
||||||
|
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
|
||||||
|
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||||
|
|
||||||
with accelerator.autocast(), torch.no_grad():
|
with accelerator.autocast(), torch.no_grad():
|
||||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||||
|
|
||||||
@@ -315,6 +317,8 @@ def denoise(
|
|||||||
):
|
):
|
||||||
# this is ignored for schnell
|
# this is ignored for schnell
|
||||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||||
|
|
||||||
|
|
||||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||||
model.prepare_block_swap_before_forward()
|
model.prepare_block_swap_before_forward()
|
||||||
@@ -560,6 +564,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
|||||||
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
|
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
|
||||||
)
|
)
|
||||||
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--controlnet",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--t5xxl_max_token_length",
|
"--t5xxl_max_token_length",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
@@ -153,15 +153,22 @@ def load_ae(
|
|||||||
return ae
|
return ae
|
||||||
|
|
||||||
|
|
||||||
def load_controlnet():
|
def load_controlnet(
|
||||||
# TODO
|
ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
|
||||||
is_schnell = False
|
):
|
||||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
logger.info("Building ControlNet")
|
||||||
with torch.device("cuda:0"):
|
# is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
|
||||||
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
|
is_schnell = False
|
||||||
# if transformer is not None:
|
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||||
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
with torch.device("meta"):
|
||||||
return controlnet
|
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
logger.info(f"Loading state dict from {ckpt_path}")
|
||||||
|
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
|
||||||
|
info = controlnet.load_state_dict(sd, strict=False, assign=True)
|
||||||
|
logger.info(f"Loaded ControlNet: {info}")
|
||||||
|
return controlnet
|
||||||
|
|
||||||
|
|
||||||
def load_clip_l(
|
def load_clip_l(
|
||||||
|
|||||||
Reference in New Issue
Block a user