fix dataloader

This commit is contained in:
minux302
2024-11-16 14:49:29 +09:00
parent 42f6edf3a8
commit e358b118af
2 changed files with 52 additions and 49 deletions

View File

@@ -2,15 +2,15 @@
# license: Apache-2.0 License
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
import math
import os
import time
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from library import utils
from library.device_utils import init_ipex, clean_memory_on_device
from library.device_utils import clean_memory_on_device, init_ipex
init_ipex()
@@ -18,6 +18,7 @@ import torch
from einops import rearrange
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint
from library import custom_offloading_utils
# USE_REENTRANT = True
@@ -1251,7 +1252,7 @@ class ControlNetFlux(nn.Module):
self,
img: Tensor,
img_ids: Tensor,
controlnet_cond: Tensor,
controlnet_img: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
@@ -1264,10 +1265,10 @@ class ControlNetFlux(nn.Module):
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
controlnet_img = self.input_hint_block(controlnet_img)
controlnet_img = rearrange(controlnet_img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_img = self.pos_embed_input(controlnet_img)
img = img + controlnet_img
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None: