mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'sd3' into multi-gpu-caching
This commit is contained in:
@@ -999,8 +999,9 @@ class Flux(nn.Module):
|
||||
|
||||
def prepare_block_swap_before_forward(self):
    """Stage blocks onto the compute device ahead of a forward pass.

    Ensures the first ``num_block_units - blocks_to_swap`` block units are
    resident on ``self.device``; the remaining (last) units are left where
    they are (CPU) to be swapped in on demand during the forward pass.

    No-op when block swapping is disabled (``blocks_to_swap`` is ``None``
    or ``0``) — callers may invoke this unconditionally.
    """
    # Fixed: the superimposed old guard (`raise ValueError("Block swap is
    # not enabled.")` for blocks_to_swap is None) contradicted the early
    # return below and made it unreachable for None; the intended behavior
    # is a silent no-op when swapping is disabled.
    if self.blocks_to_swap is None or self.blocks_to_swap == 0:
        return

    # make: first n blocks are on cuda, and last n blocks are on cpu
    for i in range(self.num_block_units - self.blocks_to_swap):
        for b in self.get_block_unit(i):
            b.to(self.device)
|
||||
|
||||
@@ -196,7 +196,6 @@ def sample_image_inference(
|
||||
tokens_and_masks = tokenize_strategy.tokenize(prompt)
|
||||
# strategy has apply_t5_attn_mask option
|
||||
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)
|
||||
print([x.shape if x is not None else None for x in encoded_text_encoder_conds])
|
||||
|
||||
# if text_encoder_conds is not cached, use encoded_text_encoder_conds
|
||||
if len(text_encoder_conds) == 0:
|
||||
@@ -313,6 +312,7 @@ def denoise(
|
||||
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
||||
for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
|
||||
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
||||
model.prepare_block_swap_before_forward()
|
||||
pred = model(
|
||||
img=img,
|
||||
img_ids=img_ids,
|
||||
@@ -325,7 +325,8 @@ def denoise(
|
||||
)
|
||||
|
||||
img = img + (t_prev - t_curr) * pred
|
||||
|
||||
|
||||
model.prepare_block_swap_before_forward()
|
||||
return img
|
||||
|
||||
|
||||
|
||||
@@ -661,6 +661,34 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
self.text_encoder_output_caching_strategy = TextEncoderOutputsCachingStrategy.get_strategy()
|
||||
self.latents_caching_strategy = LatentsCachingStrategy.get_strategy()
|
||||
|
||||
def adjust_min_max_bucket_reso_by_steps(
    self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int
) -> Tuple[int, int]:
    """Snap the bucket resolution bounds to multiples of ``bucket_reso_steps``.

    ``min_bucket_reso`` is rounded DOWN and ``max_bucket_reso`` is rounded UP
    to the nearest multiple, so the adjusted range always contains the
    requested one. A warning is logged for each adjustment.

    Args:
        resolution: Target (width, height) training resolution.
        min_bucket_reso: Requested minimum bucket resolution.
        max_bucket_reso: Requested maximum bucket resolution.
        bucket_reso_steps: Required divisor for both bounds.

    Returns:
        Tuple of (adjusted min_bucket_reso, adjusted max_bucket_reso).

    Raises:
        AssertionError: If the (adjusted) bounds do not bracket ``resolution``.
    """
    # make min/max bucket reso to be multiple of bucket_reso_steps
    if min_bucket_reso % bucket_reso_steps != 0:
        # round down so the minimum never rises above the requested value
        adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps
        logger.warning(
            f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps"
            f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}"
        )
        min_bucket_reso = adjusted_min_bucket_reso
    if max_bucket_reso % bucket_reso_steps != 0:
        # round up so the maximum never drops below the requested value
        adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps
        logger.warning(
            f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps"
            f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}"
        )
        max_bucket_reso = adjusted_max_bucket_reso

    # NOTE: asserts use the adjusted bounds; since min is rounded down and max
    # up, this can only be more permissive than checking the raw inputs.
    assert (
        min(resolution) >= min_bucket_reso
    ), "min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
    # fixed: the remedy in this message previously said to increase
    # min_bucket_reso, but the failing bound here is max_bucket_reso
    assert (
        max(resolution) <= max_bucket_reso
    ), "max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmax_bucket_resoを大きくしてください"

    return min_bucket_reso, max_bucket_reso
|
||||
|
||||
def set_seed(self, seed):
|
||||
self.seed = seed
|
||||
|
||||
@@ -1712,12 +1740,9 @@ class DreamBoothDataset(BaseDataset):
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
assert (
|
||||
min(resolution) >= min_bucket_reso
|
||||
), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください"
|
||||
assert (
|
||||
max(resolution) <= max_bucket_reso
|
||||
), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください"
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
@@ -2090,6 +2115,9 @@ class FineTuningDataset(BaseDataset):
|
||||
|
||||
self.enable_bucket = enable_bucket
|
||||
if self.enable_bucket:
|
||||
min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps(
|
||||
resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps
|
||||
)
|
||||
self.min_bucket_reso = min_bucket_reso
|
||||
self.max_bucket_reso = max_bucket_reso
|
||||
self.bucket_reso_steps = bucket_reso_steps
|
||||
@@ -4154,8 +4182,20 @@ def add_dataset_arguments(
|
||||
action="store_true",
|
||||
help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする",
|
||||
)
|
||||
parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
|
||||
parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度")
|
||||
parser.add_argument(
|
||||
"--min_bucket_reso",
|
||||
type=int,
|
||||
default=256,
|
||||
help="minimum resolution for buckets, must be divisible by bucket_reso_steps "
|
||||
" / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_bucket_reso",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="maximum resolution for buckets, must be divisible by bucket_reso_steps "
|
||||
" / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bucket_reso_steps",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user