mirror of
https://github.com/kohya-ss/sd-scripts.git
synced 2026-04-08 22:35:09 +00:00
Support SD3.5M multi resolutional training
This commit is contained in:
@@ -26,8 +26,8 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
super().__init__()
|
||||
self.sample_prompts_te_outputs = None
|
||||
|
||||
def assert_extra_args(self, args, train_dataset_group):
|
||||
super().assert_extra_args(args, train_dataset_group)
|
||||
def assert_extra_args(self, args, train_dataset_group: train_util.DatasetGroup):
|
||||
# super().assert_extra_args(args, train_dataset_group)
|
||||
# sdxl_train_util.verify_sdxl_training_args(args)
|
||||
|
||||
if args.fp8_base_unet:
|
||||
@@ -53,6 +53,9 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
|
||||
train_dataset_group.verify_bucket_reso_steps(32) # TODO check this
|
||||
|
||||
# enumerate resolutions from dataset for positional embeddings
|
||||
self.resolutions = train_dataset_group.get_resolutions()
|
||||
|
||||
def load_target_model(self, args, weight_dtype, accelerator):
|
||||
# currently offload to cpu for some models
|
||||
|
||||
@@ -67,6 +70,12 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
|
||||
self.model_type = mmdit.model_type
|
||||
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
|
||||
|
||||
# set resolutions for positional embeddings
|
||||
if args.enable_scaled_pos_embed:
|
||||
latent_sizes = [round(math.sqrt(res[0] * res[1])) // 8 for res in self.resolutions] # 8 is stride for latent
|
||||
logger.info(f"Prepare scaled positional embeddings for resolutions: {self.resolutions}, sizes: {latent_sizes}")
|
||||
mmdit.enable_scaled_pos_embed(True, latent_sizes)
|
||||
|
||||
if args.fp8_base:
|
||||
# check dtype of model
|
||||
if mmdit.dtype == torch.float8_e4m3fnuz or mmdit.dtype == torch.float8_e5m2 or mmdit.dtype == torch.float8_e5m2fnuz:
|
||||
|
||||
Reference in New Issue
Block a user