mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-09 06:45:09 +00:00
Implement XTI
This commit is contained in:
@@ -391,6 +391,8 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
|
||||
self.token_padding_disabled = False
|
||||
self.tag_frequency = {}
|
||||
self.XTI_layers = None
|
||||
self.token_strings = None
|
||||
|
||||
self.enable_bucket = False
|
||||
self.bucket_manager: BucketManager = None # not initialized
|
||||
@@ -437,6 +439,10 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
def disable_token_padding(self):
|
||||
self.token_padding_disabled = True
|
||||
|
||||
def enable_XTI(self, layers=None, token_strings=None):
|
||||
self.XTI_layers = layers
|
||||
self.token_strings = token_strings
|
||||
|
||||
def add_replacement(self, str_from, str_to):
|
||||
self.replacements[str_from] = str_to
|
||||
|
||||
@@ -870,9 +876,22 @@ class BaseDataset(torch.utils.data.Dataset):
|
||||
latents_list.append(latents)
|
||||
|
||||
caption = self.process_caption(subset, image_info.caption)
|
||||
captions.append(caption)
|
||||
if self.XTI_layers:
|
||||
caption_layer = []
|
||||
for layer in self.XTI_layers:
|
||||
token_strings_from = " ".join(self.token_strings)
|
||||
token_strings_to = " ".join([f"{x}_{layer}" for x in self.token_strings])
|
||||
caption_ = caption.replace(token_strings_from, token_strings_to)
|
||||
caption_layer.append(caption_)
|
||||
captions.append(caption_layer)
|
||||
else:
|
||||
captions.append(caption)
|
||||
if not self.token_padding_disabled: # this option might be omitted in future
|
||||
input_ids_list.append(self.get_input_ids(caption))
|
||||
if self.XTI_layers:
|
||||
token_caption = self.get_input_ids(caption_layer)
|
||||
else:
|
||||
token_caption = self.get_input_ids(caption)
|
||||
input_ids_list.append(token_caption)
|
||||
|
||||
example = {}
|
||||
example["loss_weights"] = torch.FloatTensor(loss_weights)
|
||||
@@ -1273,6 +1292,10 @@ class DatasetGroup(torch.utils.data.ConcatDataset):
|
||||
# for dataset in self.datasets:
|
||||
# dataset.make_buckets()
|
||||
|
||||
def enable_XTI(self, *args, **kwargs):
|
||||
for dataset in self.datasets:
|
||||
dataset.enable_XTI(*args, **kwargs)
|
||||
|
||||
def cache_latents(self, vae, vae_batch_size=1):
|
||||
for i, dataset in enumerate(self.datasets):
|
||||
print(f"[Dataset {i}]")
|
||||
|
||||
Reference in New Issue
Block a user