mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
train run
This commit is contained in:
@@ -1042,20 +1042,20 @@ class Flux(nn.Module):
|
||||
if not self.blocks_to_swap:
|
||||
for block_idx, block in enumerate(self.double_blocks):
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
if block_controlnet_hidden_states is not None:
|
||||
if block_controlnet_hidden_states is not None and controlnet_depth > 0:
|
||||
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
for block in self.single_blocks:
|
||||
for block_idx, block in enumerate(self.single_blocks):
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
if block_controlnet_single_hidden_states is not None:
|
||||
if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
|
||||
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
|
||||
else:
|
||||
for block_idx, block in enumerate(self.double_blocks):
|
||||
self.offloader_double.wait_for_block(block_idx)
|
||||
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
if block_controlnet_hidden_states is not None:
|
||||
if block_controlnet_hidden_states is not None and controlnet_depth > 0:
|
||||
img = img + block_controlnet_hidden_states[block_idx % controlnet_depth]
|
||||
|
||||
self.offloader_double.submit_move_blocks(self.double_blocks, block_idx)
|
||||
@@ -1066,7 +1066,7 @@ class Flux(nn.Module):
|
||||
self.offloader_single.wait_for_block(block_idx)
|
||||
|
||||
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
if block_controlnet_single_hidden_states is not None:
|
||||
if block_controlnet_single_hidden_states is not None and controlnet_single_depth > 0:
|
||||
img = img + block_controlnet_single_hidden_states[block_idx % controlnet_single_depth]
|
||||
|
||||
self.offloader_single.submit_move_blocks(self.single_blocks, block_idx)
|
||||
@@ -1121,14 +1121,14 @@ class ControlNetFlux(nn.Module):
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
for _ in range(controlnet_depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
||||
for _ in range(0) # TMP
|
||||
for _ in range(0) # TODO
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1148,7 +1148,7 @@ class ControlNetFlux(nn.Module):
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks_for_double.append(controlnet_block)
|
||||
self.controlnet_blocks_for_single = nn.ModuleList([])
|
||||
for _ in range(controlnet_depth):
|
||||
for _ in range(0): # TODO
|
||||
controlnet_block = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_blocks_for_single.append(controlnet_block)
|
||||
@@ -1252,7 +1252,7 @@ class ControlNetFlux(nn.Module):
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
controlnet_img: Tensor,
|
||||
controlnet_cond: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
timesteps: Tensor,
|
||||
@@ -1265,10 +1265,10 @@ class ControlNetFlux(nn.Module):
|
||||
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
controlnet_img = self.input_hint_block(controlnet_img)
|
||||
controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
controlnet_img = self.pos_embed_input(controlnet_img)
|
||||
img = img + controlnet_img
|
||||
controlnet_cond = self.input_hint_block(controlnet_cond)
|
||||
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
||||
controlnet_cond = self.pos_embed_input(controlnet_cond)
|
||||
img = img + controlnet_cond
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256))
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
@@ -1283,7 +1283,7 @@ class ControlNetFlux(nn.Module):
|
||||
block_samples = ()
|
||||
block_single_samples = ()
|
||||
if not self.blocks_to_swap:
|
||||
for block_idx, block in enumerate(self.double_blocks):
|
||||
for block in self.double_blocks:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
|
||||
block_samples = block_samples + (img,)
|
||||
|
||||
@@ -1315,7 +1315,7 @@ class ControlNetFlux(nn.Module):
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_double):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_single_blocks_for_single):
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks_for_single):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)
|
||||
|
||||
|
||||
@@ -460,7 +460,7 @@ def get_noisy_model_input_and_timesteps(
|
||||
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
|
||||
|
||||
return noisy_model_input, timesteps, sigmas
|
||||
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
|
||||
|
||||
|
||||
def apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas):
|
||||
|
||||
@@ -157,7 +157,7 @@ def load_controlnet():
|
||||
# TODO
|
||||
is_schnell = False
|
||||
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL
|
||||
with torch.device("meta"):
|
||||
with torch.device("cuda:0"):
|
||||
controlnet = flux_models.ControlNetFlux(flux_models.configs[name].params)
|
||||
# if transformer is not None:
|
||||
# controlnet.load_state_dict(transformer.state_dict(), strict=False)
|
||||
|
||||
Reference in New Issue
Block a user