Make a folder for separating blip files, add seed.

This commit is contained in:
Kohya S
2022-12-20 08:18:24 +09:00
parent e53adbdbcc
commit dadea12ad2
5 changed files with 18 additions and 5 deletions

View File

@@ -10,8 +10,8 @@ warnings.filterwarnings("ignore")
 # from models.vit import VisionTransformer, interpolate_pos_embed
 # from models.med import BertConfig, BertModel, BertLMHeadModel
-from vit import VisionTransformer, interpolate_pos_embed
+from blip.vit import VisionTransformer, interpolate_pos_embed
-from med import BertConfig, BertModel, BertLMHeadModel
+from blip.med import BertConfig, BertModel, BertLMHeadModel
 from transformers import BertTokenizer
 import torch

View File

@@ -2,6 +2,7 @@ import argparse
 import glob
 import os
 import json
+import random
 from PIL import Image
 from tqdm import tqdm
@@ -9,20 +10,31 @@ import numpy as np
 import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
-from blip import blip_decoder
+from blip.blip import blip_decoder
 # from Salesforce_BLIP.models.blip import blip_decoder
 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 def main(args):
# fix the seed for reproducibility
seed = args.seed # + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if not os.path.exists("blip"):
cwd = os.getcwd()
print('Current Working Directory is: ', cwd)
os.chdir('finetune')
 image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
     glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
 print(f"found {len(image_paths)} images.")
 print(f"loading BLIP caption: {args.caption_weights}")
 image_size = 384
-model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./med_config.json")
+model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
 model.eval()
 model = model.to(DEVICE)
 print("BLIP loaded")
@@ -84,6 +96,7 @@ if __name__ == '__main__':
 parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
 parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
 parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
+parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
 parser.add_argument("--debug", action="store_true", help="debug mode")
 args = parser.parse_args()