Add cleaning patterns

This commit is contained in:
Kohya S
2023-02-03 21:04:37 +09:00
parent 58a809eaff
commit 73d612ff9c

View File

@@ -13,17 +13,26 @@ import library.train_util as train_util
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PATTERN_REPLACE = [re.compile(r'with the (words?|letters?) (" ?[^"]*"|\w+)( on (the)? ?\w+)?'), PATTERN_REPLACE = [
re.compile(r'that says (" ?[^"]*"|\w+)')] re.compile(r'(has|with|and) the (words?|letters?|name) (" ?[^"]*"|\w+)( ?(is )?(on|in) (the |her |their |him )?\w+)?'),
re.compile(r'(with a sign )?that says ?(" ?[^"]*"|\w+)( ?on it)?'),
re.compile(r"(with a sign )?that says ?(' ?(i'm)?[^']*'|\w+)( ?on it)?"),
re.compile(r'with the number \d+ on (it|\w+ \w+)'),
re.compile(r'with the words "'),
re.compile(r'word \w+ on it'),
re.compile(r'that says the word \w+ on it'),
re.compile('that says\'the word "( on it)?'),
]
# 誤検知しまくりの with the word xxxx を消す # 誤検知しまくりの with the word xxxx を消す
def remove_words(captions, debug): def remove_words(captions, debug):
removed_caps = [] removed_caps = []
for caption in captions: for caption in captions:
cap = caption cap = caption
for pat in PATTERN_REPLACE: for pat in PATTERN_REPLACE:
cap = pat.sub("", caption) cap = pat.sub("", cap)
if debug and cap != caption: if debug and cap != caption:
print(caption) print(caption)
print(cap) print(cap)
@@ -87,7 +96,7 @@ def main(args):
if args.max_data_loader_n_workers is not None: if args.max_data_loader_n_workers is not None:
dataset = train_util.ImageLoadingDataset(image_paths) dataset = train_util.ImageLoadingDataset(image_paths)
data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, data = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False) num_workers=args.max_data_loader_n_workers, collate_fn=collate_fn_remove_corrupted, drop_last=False)
else: else:
data = [[(None, ip)] for ip in image_paths] data = [[(None, ip)] for ip in image_paths]
@@ -96,7 +105,7 @@ def main(args):
for data in data_entry: for data in data_entry:
if data is None: if data is None:
continue continue
image, image_path = data image, image_path = data
if image is None: if image is None:
try: try: