mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Make a folder for seperating blip files, add seed.
This commit is contained in:
@@ -10,8 +10,8 @@ warnings.filterwarnings("ignore")
|
|||||||
|
|
||||||
# from models.vit import VisionTransformer, interpolate_pos_embed
|
# from models.vit import VisionTransformer, interpolate_pos_embed
|
||||||
# from models.med import BertConfig, BertModel, BertLMHeadModel
|
# from models.med import BertConfig, BertModel, BertLMHeadModel
|
||||||
from vit import VisionTransformer, interpolate_pos_embed
|
from blip.vit import VisionTransformer, interpolate_pos_embed
|
||||||
from med import BertConfig, BertModel, BertLMHeadModel
|
from blip.med import BertConfig, BertModel, BertLMHeadModel
|
||||||
from transformers import BertTokenizer
|
from transformers import BertTokenizer
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -2,6 +2,7 @@ import argparse
|
|||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import random
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -9,20 +10,31 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from torchvision.transforms.functional import InterpolationMode
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
from blip import blip_decoder
|
from blip.blip import blip_decoder
|
||||||
# from Salesforce_BLIP.models.blip import blip_decoder
|
# from Salesforce_BLIP.models.blip import blip_decoder
|
||||||
|
|
||||||
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
# fix the seed for reproducibility
|
||||||
|
seed = args.seed # + utils.get_rank()
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
if not os.path.exists("blip"):
|
||||||
|
cwd = os.getcwd()
|
||||||
|
print('Current Working Directory is: ', cwd)
|
||||||
|
os.chdir('finetune')
|
||||||
|
|
||||||
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
|
||||||
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
|
||||||
print(f"found {len(image_paths)} images.")
|
print(f"found {len(image_paths)} images.")
|
||||||
|
|
||||||
print(f"loading BLIP caption: {args.caption_weights}")
|
print(f"loading BLIP caption: {args.caption_weights}")
|
||||||
image_size = 384
|
image_size = 384
|
||||||
model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./med_config.json")
|
model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large', med_config="./blip/med_config.json")
|
||||||
model.eval()
|
model.eval()
|
||||||
model = model.to(DEVICE)
|
model = model.to(DEVICE)
|
||||||
print("BLIP loaded")
|
print("BLIP loaded")
|
||||||
@@ -84,6 +96,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
|
||||||
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
|
||||||
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
|
||||||
|
parser.add_argument('--seed', default=42, type=int, help='seed for reproducibility / 再現性を確保するための乱数seed')
|
||||||
parser.add_argument("--debug", action="store_true", help="debug mode")
|
parser.add_argument("--debug", action="store_true", help="debug mode")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
Reference in New Issue
Block a user