Merge branch 'dev' into dev

This commit is contained in:
Kohya S
2023-03-10 13:00:49 +09:00
committed by GitHub
14 changed files with 462 additions and 197 deletions

View File

@@ -4,7 +4,7 @@
import math
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from safetensors.torch import load_file, save_file
@@ -916,7 +916,11 @@ def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
else:
converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
logging.set_verbosity_error() # don't show annoying warning
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
logging.set_verbosity_warning()
info = text_model.load_state_dict(converted_text_encoder_checkpoint)
print("loading text encoder:", info)

View File

@@ -924,7 +924,9 @@ class FineTuningDataset(BaseDataset):
elif tags is not None and len(tags) > 0:
caption = caption + ', ' + tags
tags_list.append(tags)
assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{abs_path}"
if caption is None:
caption = ""
image_info = ImageInfo(image_key, subset.num_repeats, caption, False, abs_path)
image_info.image_size = img_md.get('train_resolution')
@@ -2207,7 +2209,7 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
if epoch is None or epoch % args.sample_every_n_epochs != 0:
return
else:
if steps % args.sample_every_n_steps != 0:
if steps % args.sample_every_n_steps != 0 or epoch is not None: # steps is not divisible or end of epoch
return
print(f"generating sample images at step / サンプル画像生成 ステップ: {steps}")
@@ -2353,6 +2355,8 @@ def sample_images(accelerator, args: argparse.Namespace, epoch, steps, device, v
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
height = max(64, height - height % 8) # round to divisible by 8
width = max(64, width - width % 8) # round to divisible by 8
print(f"prompt: {prompt}")
print(f"negative_prompt: {negative_prompt}")
print(f"height: {height}")