mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
support sdxl in prepare scipt
This commit is contained in:
@@ -124,11 +124,11 @@ class BucketManager:
|
||||
|
||||
self.resos = []
|
||||
self.reso_to_id = {}
|
||||
self.buckets = [] # 前処理時は (image_key, image)、学習時は image_key
|
||||
self.buckets = [] # 前処理時は (image_key, image, original size, crop left/top)、学習時は image_key
|
||||
|
||||
def add_image(self, reso, image):
|
||||
def add_image(self, reso, image_or_info):
|
||||
bucket_id = self.reso_to_id[reso]
|
||||
self.buckets[bucket_id].append(image)
|
||||
self.buckets[bucket_id].append(image_or_info)
|
||||
|
||||
def shuffle(self):
|
||||
for bucket in self.buckets:
|
||||
@@ -767,7 +767,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
img = np.array(image, np.uint8)
|
||||
return img
|
||||
|
||||
def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_size):
|
||||
# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top)
|
||||
def trim_and_resize_if_required(
|
||||
self, subset: BaseSubset, image, reso, resized_size
|
||||
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int]]:
|
||||
image_height, image_width = image.shape[0:2]
|
||||
|
||||
if image_width != resized_size[0] or image_height != resized_size[1]:
|
||||
@@ -907,19 +910,13 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
|
||||
# check NaN
|
||||
for info, latents1 in zip(batch, latents):
|
||||
if torch.isnan(latents1).any():
|
||||
for info, latent in zip(batch, latents):
|
||||
# check NaN
|
||||
if torch.isnan(latents).any():
|
||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||
|
||||
for info, latent in zip(batch, latents):
|
||||
if cache_to_disk:
|
||||
np.savez(
|
||||
info.latents_npz,
|
||||
latents=latent.float().numpy(),
|
||||
original_size=np.array(info.latents_original_size),
|
||||
crop_left_top=np.array(info.latents_crop_left_top),
|
||||
)
|
||||
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_left_top)
|
||||
else:
|
||||
info.latents = latent
|
||||
|
||||
@@ -927,12 +924,14 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
img_tensors = torch.flip(img_tensors, dims=[3])
|
||||
latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
|
||||
for info, latent in zip(batch, latents):
|
||||
# check NaN
|
||||
if torch.isnan(latents).any():
|
||||
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
|
||||
|
||||
if cache_to_disk:
|
||||
np.savez(
|
||||
info.latents_npz_flipped,
|
||||
latents=latent.float().numpy(),
|
||||
original_size=np.array(info.latents_original_size),
|
||||
crop_left_top=np.array(info.latents_crop_left_top), # reverse horizontally when use flipped latents
|
||||
# crop_left_top is reversed when making batch
|
||||
save_latents_to_disk(
|
||||
info.latents_npz_flipped, latent, info.latents_original_size, info.latents_crop_left_top
|
||||
)
|
||||
else:
|
||||
info.latents_flipped = latent
|
||||
@@ -1005,18 +1004,7 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
def load_latents_from_npz(self, image_info: ImageInfo, flipped):
|
||||
npz_file = image_info.latents_npz_flipped if flipped else image_info.latents_npz
|
||||
if npz_file is None:
|
||||
return None, None, None
|
||||
|
||||
npz = np.load(npz_file)
|
||||
if "latents" not in npz:
|
||||
print(f"error: npz is old format. please re-generate {npz_file}")
|
||||
return None, None, None
|
||||
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
crop_left_top = npz["crop_left_top"].tolist()
|
||||
return latents, original_size, crop_left_top
|
||||
return load_latents_from_disk(npz_file)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
@@ -1762,6 +1750,31 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
dataset.disable_token_padding()
|
||||
|
||||
|
||||
# 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
|
||||
def load_latents_from_disk(npz_path) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]]]:
|
||||
if npz_path is None: # flipped doesn't exist
|
||||
return None, None, None
|
||||
|
||||
npz = np.load(npz_path)
|
||||
if "latents" not in npz:
|
||||
print(f"error: npz is old format. please re-generate {npz_path}")
|
||||
return None, None, None
|
||||
|
||||
latents = npz["latents"]
|
||||
original_size = npz["original_size"].tolist()
|
||||
crop_left_top = npz["crop_left_top"].tolist()
|
||||
return latents, original_size, crop_left_top
|
||||
|
||||
|
||||
def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_left_top):
|
||||
np.savez(
|
||||
npz_path,
|
||||
latents=latents_tensor.float().cpu().numpy(),
|
||||
original_size=np.array(original_size),
|
||||
crop_left_top=np.array(crop_left_top),
|
||||
)
|
||||
|
||||
|
||||
def debug_dataset(train_dataset, show_input_ids=False):
|
||||
print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
|
||||
print("`S` for next step, `E` for next epoch no. , Escape for exit. / Sキーで次のステップ、Eキーで次のエポック、Escキーで中断、終了します")
|
||||
|
||||
Reference in New Issue
Block a user