work cn load and validation

This commit is contained in:
minux302
2024-11-18 12:47:01 +00:00
parent 35778f0218
commit 4dd4cd6ec8
4 changed files with 37 additions and 32 deletions

View File

@@ -266,7 +266,7 @@ def train(args):
flux.to(accelerator.device)
# load controlnet
controlnet = flux_utils.load_controlnet()
controlnet = flux_utils.load_controlnet(args.controlnet, weight_dtype, "cpu", args.disable_mmap_load_safetensors)
controlnet.train()
if args.gradient_checkpointing:
@@ -568,7 +568,7 @@ def train(args):
# For --sample_at_first
optimizer_eval_fn()
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs)
flux_train_utils.sample_images(accelerator, args, 0, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet)
optimizer_train_fn()
if len(accelerator.trackers) > 0:
# log empty object to commit the sample images to wandb
@@ -718,7 +718,7 @@ def train(args):
optimizer_eval_fn()
flux_train_utils.sample_images(
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
accelerator, args, None, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
)
# 指定ステップごとにモデルを保存
@@ -774,7 +774,7 @@ def train(args):
# )
flux_train_utils.sample_images(
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs
accelerator, args, epoch + 1, global_step, flux, ae, [clip_l, t5xxl], sample_prompts_te_outputs, controlnet=controlnet
)
optimizer_train_fn()
@@ -850,18 +850,6 @@ def setup_parser() -> argparse.ArgumentParser:
action="store_true",
help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="controlnet model name or path / controlnetのモデル名またはパス",
)
# parser.add_argument(
# "--conditioning_data_dir",
# type=str,
# default=None,
# help="conditioning data directory / 条件付けデータのディレクトリ",
# )
return parser

View File

@@ -1142,11 +1142,11 @@ class ControlNetFlux(nn.Module):
self.num_single_blocks = len(self.single_blocks)
# add ControlNet blocks
self.controlnet_blocks_for_double = nn.ModuleList([])
self.controlnet_blocks = nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks_for_double.append(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.controlnet_blocks_for_single = nn.ModuleList([])
for _ in range(0): # TODO
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
@@ -1312,7 +1312,7 @@ class ControlNetFlux(nn.Module):
controlnet_block_samples = ()
controlnet_single_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double):
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single):

View File

@@ -175,10 +175,6 @@ def sample_image_inference(
# if negative_prompt is None:
# negative_prompt = ""
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
@@ -232,6 +228,12 @@ def sample_image_inference(
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None
if controlnet_image is not None:
controlnet_image = Image.open(controlnet_image).convert("RGB")
controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)
with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
@@ -315,6 +317,8 @@ def denoise(
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()
@@ -560,6 +564,12 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
help="path to t5xxl (*.sft or *.safetensors), should be float16 / t5xxlのパス*.sftまたは*.safetensors、float16が前提",
)
parser.add_argument("--ae", type=str, help="path to ae (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors")
parser.add_argument(
"--controlnet",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / aeのパス*.sftまたは*.safetensors"
)
parser.add_argument(
"--t5xxl_max_token_length",
type=int,

View File

@@ -153,15 +153,22 @@ def load_ae(
return ae
def load_controlnet():
# TODO
is_schnell = False
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device("cuda:0"):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
# if transformer is not None:
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
return controlnet
def load_controlnet(
ckpt_path: Optional[str], dtype: torch.dtype, device: Union[str, torch.device], disable_mmap: bool = False
):
logger.info("Building ControlNet")
# is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
is_schnell = False
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
with torch.device("meta"):
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params).to(dtype)
if ckpt_path is not None:
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = controlnet.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded ControlNet: {info}")
return controlnet
def load_clip_l(