mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Handle tuple return from generate_dataset_group_by_blueprint
This commit is contained in:
@@ -91,9 +91,9 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -138,9 +138,10 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
val_dataset_group = None
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -126,9 +126,9 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group, val_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -467,7 +467,7 @@ class BlueprintGenerator:
|
|||||||
|
|
||||||
return default_value
|
return default_value
|
||||||
|
|
||||||
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint):
|
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint) -> Tuple[DatasetGroup, Optional[DatasetGroup]]:
|
||||||
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
datasets: List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
|
||||||
|
|
||||||
for dataset_blueprint in dataset_group_blueprint.datasets:
|
for dataset_blueprint in dataset_group_blueprint.datasets:
|
||||||
|
|||||||
@@ -149,9 +149,10 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
val_dataset_group = None
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -176,9 +176,10 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
val_dataset_group = None
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=[tokenizer1, tokenizer2])
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -116,10 +116,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
# use arbitrary dataset class
|
# use arbitrary dataset class
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
val_dataset_group = None
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
logger.info("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
|
|||||||
@@ -103,10 +103,11 @@ def cache_to_disk(args: argparse.Namespace) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
# use arbitrary dataset class
|
# use arbitrary dataset class
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
val_dataset_group = None
|
||||||
|
|
||||||
# acceleratorを準備する
|
# acceleratorを準備する
|
||||||
logger.info("prepare accelerator")
|
logger.info("prepare accelerator")
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -89,9 +89,10 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
val_dataset_group = None
|
||||||
|
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
@@ -320,9 +320,10 @@ class TextualInversionTrainer:
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args)
|
blueprint = blueprint_generator.generate(user_config, args)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
else:
|
else:
|
||||||
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
train_dataset_group = train_util.load_arbitrary_dataset(args)
|
||||||
|
val_dataset_group = None
|
||||||
|
|
||||||
self.assert_extra_args(args, train_dataset_group)
|
self.assert_extra_args(args, train_dataset_group)
|
||||||
|
|
||||||
|
|||||||
@@ -239,7 +239,7 @@ def train(args):
|
|||||||
}
|
}
|
||||||
|
|
||||||
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
blueprint = blueprint_generator.generate(user_config, args, tokenizer=tokenizer)
|
||||||
train_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
train_dataset_group, val_dataset_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
||||||
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
train_dataset_group.enable_XTI(XTI_layers, token_strings=token_strings)
|
||||||
current_epoch = Value("i", 0)
|
current_epoch = Value("i", 0)
|
||||||
current_step = Value("i", 0)
|
current_step = Value("i", 0)
|
||||||
|
|||||||
Reference in New Issue
Block a user