From 4dd4cd6ec8c55fa94b53217181ed9c95e59eed56 Mon Sep 17 00:00:00 2001 From: minux302 Date: Mon, 18 Nov 2024 12:47:01 +0000 Subject: [PATCH] work cn load and validation --- flux_train_control_net.py | 20 ++++---------------- library/flux_models.py | 6 +++--- library/flux_train_utils.py | 18 ++++++++++++++---- library/flux_utils.py | 25 ++++++++++++++++--------- 4 files changed, 37 insertions(+), 32 deletions(-) diff --git a/flux_train_control_net.py b/flux_train_control_net.py index 791900d1..cbfac418 100644 --- a/flux_train_control_net.py +++ b/flux_train_control_net.py @@ -266,7 +266,7 @@ def train(args): flux.to(accelerator.device) # load controlnet - controlnet = flux_utils.load_controlnet() + controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors) controlnet.train() if args.gradient_checkpointing: @@ -568,7 +568,7 @@ def train(args): # For --sample_at_first optimizer_eval_fn() - flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs) + flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet) optimizer_train_fn() if len(accelerator.trackers) > 0: # log empty object to commit the sample images to wandb @@ -718,7 +718,7 @@ def train(args): optimizer_eval_fn() flux_train_utils.sample_images( - accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) # 指定ステップごとにモデルを保存 @@ -774,7 +774,7 @@ def train(args): # ) flux_train_utils.sample_images( - accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs + accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet ) optimizer_train_fn() @@ -850,18 +850,6 @@ def 
setup_parser() -> argparse.ArgumentParser: action="store_true", help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする", ) - parser.add_argument( - "--controlnet_model_name_or_path", - type=str, - default=None, - help="controlnet model name or path / controlnetのモデル名またはパス", - ) - # parser.add_argument( - # "--conditioning_data_dir", - # type=str, - # default=None, - # help="conditioning data directory / 条件付けデータのディレクトリ", - # ) return parser diff --git a/library/flux_models.py b/library/flux_models.py index 2fc21db9..4123b40e 100644 --- a/library/flux_models.py +++ b/library/flux_models.py @@ -1142,11 +1142,11 @@ class ControlNetFlux(nn.Module): self.num_single_blocks = len(self.single_blocks) # add ControlNet blocks - self.controlnet_blocks_for_double = nn.ModuleList([]) + self.controlnet_blocks = nn.ModuleList([]) for _ in range(controlnet_depth): controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) controlnet_block = zero_module(controlnet_block) - self.controlnet_blocks_for_double.append(controlnet_block) + self.controlnet_blocks.append(controlnet_block) self.controlnet_blocks_for_single = nn.ModuleList([]) for _ in range(0): # TODO controlnet_block = nn.Linear(self.hidden_size, self.hidden_size) @@ -1312,7 +1312,7 @@ class ControlNetFlux(nn.Module): controlnet_block_samples = () controlnet_single_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double): + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single): diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index de2ee030..dbbaba73 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -175,10 +175,6 @@ def 
sample_image_inference( # if negative_prompt is None: # negative_prompt = "" - if controlnet_image is not None: - controlnet_image = Image.open(controlnet_image).convert("RGB") - controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) - height = max(64, height - height % 16) # round to divisible by 16 width = max(64, width - width % 16) # round to divisible by 16 logger.info(f"prompt: {prompt}") @@ -232,6 +228,12 @@ def sample_image_inference( img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype) t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None + if controlnet_image is not None: + controlnet_image = Image.open(controlnet_image).convert("RGB") + controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS) + controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1) + controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device) + with accelerator.autocast(), torch.no_grad(): x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image) @@ -315,6 +317,8 @@ def denoise( ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) + + for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) model.prepare_block_swap_before_forward() @@ -560,6 +564,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス(*.sftまたは*.safetensors)、float16が前提", ) parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)") + parser.add_argument( + "--controlnet", + type=str, + default=None, + 
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)" + ) parser.add_argument( "--t5xxl_max_token_length", type=int, diff --git a/library/flux_utils.py b/library/flux_utils.py index 4a3817fd..fb7a3074 100644 --- a/library/flux_utils.py +++ b/library/flux_utils.py @@ -153,15 +153,22 @@ def load_ae( return ae -def load_controlnet(): - # TODO - is_schnell = False - name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL - with torch.device("cuda:0"): - controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params) - # if transformer is not None: - # controlnet.load_state_dict(transformer.state_dict(), strict=False) - return controlnet +def load_controlnet( + ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False +): + logger.info("Building ControlNet") + # is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path) + is_schnell = False + name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL + with torch.device("meta"): + controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype) + + if ckpt_path is not None: + logger.info(f"Loading state dict from {ckpt_path}") + sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype) + info = controlnet.load_state_dict(sd, strict=False, assign=True) + logger.info(f"Loaded ControlNet: {info}") + return controlnet def load_clip_l(