mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Merge branch 'dev' into onnx
This commit is contained in:
2
.github/workflows/typos.yml
vendored
2
.github/workflows/typos.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: typos-action
|
- name: typos-action
|
||||||
uses: crate-ci/typos@v1.16.15
|
uses: crate-ci/typos@v1.16.15
|
||||||
|
|||||||
@@ -197,9 +197,27 @@ def main(args):
|
|||||||
if len(character_tag_text) > 0:
|
if len(character_tag_text) > 0:
|
||||||
character_tag_text = character_tag_text[2:]
|
character_tag_text = character_tag_text[2:]
|
||||||
|
|
||||||
|
caption_file = os.path.splitext(image_path)[0] + args.caption_extension
|
||||||
|
|
||||||
tag_text = ", ".join(combined_tags)
|
tag_text = ", ".join(combined_tags)
|
||||||
|
|
||||||
with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding="utf-8") as f:
|
if args.append_tags:
|
||||||
|
# Check if file exists
|
||||||
|
if os.path.exists(caption_file):
|
||||||
|
with open(caption_file, "rt", encoding="utf-8") as f:
|
||||||
|
# Read file and remove new lines
|
||||||
|
existing_content = f.read().strip("\n") # Remove newlines
|
||||||
|
|
||||||
|
# Split the content into tags and store them in a list
|
||||||
|
existing_tags = [tag.strip() for tag in existing_content.split(",") if tag.strip()]
|
||||||
|
|
||||||
|
# Check and remove repeating tags in tag_text
|
||||||
|
new_tags = [tag for tag in combined_tags if tag not in existing_tags]
|
||||||
|
|
||||||
|
# Create new tag_text
|
||||||
|
tag_text = ", ".join(existing_tags + new_tags)
|
||||||
|
|
||||||
|
with open(caption_file, "wt", encoding="utf-8") as f:
|
||||||
f.write(tag_text + "\n")
|
f.write(tag_text + "\n")
|
||||||
if args.debug:
|
if args.debug:
|
||||||
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
print(f"\n{image_path}:\n Character tags: {character_tag_text}\n General tags: {general_tag_text}")
|
||||||
@@ -316,9 +334,11 @@ def setup_parser() -> argparse.ArgumentParser:
|
|||||||
)
|
)
|
||||||
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
parser.add_argument("--frequency_tags", action="store_true", help="Show frequency of tags for images / 画像ごとのタグの出現頻度を表示する")
|
||||||
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference")
|
parser.add_argument("--onnx", action="store_true", help="use onnx model for inference")
|
||||||
|
parser.add_argument("--append_tags", action="store_true", help="Append captions instead of overwriting / 上書きではなくキャプションを追記する")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = setup_parser()
|
parser = setup_parser()
|
||||||
|
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# JPEG-XL on Linux
|
||||||
try:
|
try:
|
||||||
from jxlpy import JXLImagePlugin
|
from jxlpy import JXLImagePlugin
|
||||||
|
|
||||||
@@ -103,6 +104,14 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# JPEG-XL on Windows
|
||||||
|
try:
|
||||||
|
import pillow_jxl
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
IMAGE_TRANSFORMS = transforms.Compose(
|
IMAGE_TRANSFORMS = transforms.Compose(
|
||||||
[
|
[
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
|
|||||||
@@ -283,7 +283,10 @@ class NetworkTrainer:
|
|||||||
if args.dim_from_weights:
|
if args.dim_from_weights:
|
||||||
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
network, _ = network_module.create_network_from_weights(1, args.network_weights, vae, text_encoder, unet, **net_kwargs)
|
||||||
else:
|
else:
|
||||||
# LyCORIS will work with this...
|
if "dropout" not in net_kwargs:
|
||||||
|
# workaround for LyCORIS (;^ω^)
|
||||||
|
net_kwargs["dropout"] = args.network_dropout
|
||||||
|
|
||||||
network = network_module.create_network(
|
network = network_module.create_network(
|
||||||
1.0,
|
1.0,
|
||||||
args.network_dim,
|
args.network_dim,
|
||||||
|
|||||||
Reference in New Issue
Block a user