mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
work cn load and validation
This commit is contained in:
@@ -266,7 +266,7 @@ def train(args):
|
||||
flux.to(accelerator.device)
|
||||
|
||||
# load controlnet
|
||||
controlnet = flux_utils.load_controlnet()
|
||||
controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
|
||||
controlnet.train()
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
@@ -568,7 +568,7 @@ def train(args):
|
||||
|
||||
# For --sample_at_first
|
||||
optimizer_eval_fn()
|
||||
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
|
||||
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet)
|
||||
optimizer_train_fn()
|
||||
if len(accelerator.trackers) > 0:
|
||||
# log empty object to commit the sample images to wandb
|
||||
@@ -718,7 +718,7 @@ def train(args):
|
||||
|
||||
optimizer_eval_fn()
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
|
||||
)
|
||||
|
||||
# 指定ステップごとにモデルを保存
|
||||
@@ -774,7 +774,7 @@ def train(args):
|
||||
# )
|
||||
|
||||
flux_train_utils.sample_images(
|
||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
|
||||
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
|
||||
)
|
||||
optimizer_train_fn()
|
||||
|
||||
@@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser:
|
||||
action="store_true",
|
||||
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--controlnet_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="controlnet model name or path / controlnetのモデル名またはパス",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--conditioning_data_dir",
|
||||
# type=str,
|
||||
# default=None,
|
||||
# help="conditioning data directory / 条件付けデータのディレクトリ",
|
||||
# )
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@@ -1142,11 +1142,11 @@ class ControlNetFlux(nn.Module):
|
||||
self.num_single_blocks = len(self.single_blocks)
|
||||
|
||||
# add ControlNet blocks
|
||||
self.controlnet_blocks_for_double = nn.ModuleList([])
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(controlnet_depth):
|
||||
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks_for_double.append(controlnet_block)
|
||||
self.controlnet_blocks.append(controlnet_block)
|
||||
self.controlnet_blocks_for_single = nn.ModuleList([])
|
||||
for _ in range(0): # TODO
|
||||
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
@@ -1312,7 +1312,7 @@ class ControlNetFlux(nn.Module):
|
||||
|
||||
controlnet_block_samples = ()
|
||||
controlnet_single_block_samples = ()
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double):
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single):
|
||||
|
||||
@@ -175,10 +175,6 @@ def sample_image_inference(
|
||||
|
||||
# if negative_prompt is None:
|
||||
# negative_prompt = ""
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = Image.open(controlnet_image).convert("RGB")
|
||||
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
|
||||
|
||||
height = max(64, height - height % 16) # round to divisible by 16
|
||||
width = max(64, width - width % 16) # round to divisible by 16
|
||||
logger.info(f"prompt: {prompt}")
|
||||
@@ -232,6 +228,12 @@ def sample_image_inference(
|
||||
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
|
||||
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
|
||||
|
||||
if controlnet_image is not None:
|
||||
controlnet_image = Image.open(controlnet_image).convert("RGB")
|
||||
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
|
||||
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
|
||||
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
|
||||
|
||||
with accelerator.autocast(), torch.no_grad():
|
||||
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
|
||||
|
||||
@@ -315,6 +317,8 @@ def denoise(
|
||||
):
|
||||
# this is ignored for schnell
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
|
||||
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
@@ -560,6 +564,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
|
||||
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提",
|
||||
)
|
||||
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)")
|
||||
parser.add_argument(
|
||||
"--controlnet",
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to controlnet (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t5xxl_max_token_length",
|
||||
type=int,
|
||||
|
||||
@@ -153,15 +153,22 @@ def load_ae(
|
||||
return ae
|
||||
|
||||
|
||||
def load_controlnet():
|
||||
# TODO
|
||||
is_schnell = False
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
with torch.device("cuda:0"):
|
||||
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
|
||||
# if transformer is not None:
|
||||
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
||||
return controlnet
|
||||
def load_controlnet(
    ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
):
    """Build a FLUX ControlNet and optionally load its weights from a safetensors file.

    Args:
        ckpt_path: Path to a ControlNet checkpoint, or None to return a freshly
            initialized (untrained) ControlNet.
        dtype: dtype the module is cast to and the checkpoint tensors are loaded as.
        device: Device the module is created on and the checkpoint tensors are loaded to.
        disable_mmap: Forwarded to ``load_safetensors``.

    Returns:
        The constructed ``ControlNetFlux`` module, with real (non-meta) parameters.
    """
    logger.info("Building ControlNet")

    # TODO: derive dev/schnell from the checkpoint (e.g. via analyze_checkpoint_state)
    # instead of always using the dev configuration.
    is_schnell = False
    name = MODEL_NAME_SCHNELL if is_schnell else MODEL_NAME_DEV

    # Build on the real target device rather than "meta". A meta-device module has
    # no storage, so when no checkpoint is given, or when the checkpoint lacks some
    # keys (tolerated by strict=False below), the returned module would still hold
    # meta parameters and fail on the first .to()/forward/optimizer step.
    with torch.device(device):
        controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)

    if ckpt_path is not None:
        logger.info(f"Loading state dict from {ckpt_path}")
        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
        # assign=True swaps the loaded tensors in as the parameters instead of copying
        # into the existing ones; keys absent from the checkpoint keep their
        # freshly initialized values.
        info = controlnet.load_state_dict(sd, strict=False, assign=True)
        logger.info(f"Loaded ControlNet: {info}")

    return controlnet
|
||||
|
||||
|
||||
def load_clip_l(
|
||||
|
||||
Reference in New Issue
Block a user