diff --git a/train_network.py b/train_network.py index 9f8ccd2f..9dac2ed3 100644 --- a/train_network.py +++ b/train_network.py @@ -323,7 +323,6 @@ def train(args): subsets_metadata = [] for subset in dataset.subsets: subset_metadata = { - "image_dir": os.path.basename(subset.image_dir), "img_count": subset.img_count, "num_repeats": subset.num_repeats, "color_aug": bool(subset.color_aug), @@ -332,6 +331,9 @@ def train(args): "shuffle_caption": bool(subset.shuffle_caption), "keep_tokens": subset.keep_tokens, } + if subset.image_dir: + subset_metadata["image_dir"] = os.path.basename(subset.image_dir) + if is_dreambooth_dataset: subset_metadata["class_tokens"] = subset.class_tokens subset_metadata["is_reg"] = subset.is_reg