support SD3.5M

This commit is contained in:
Kohya S
2024-10-30 12:51:49 +09:00
parent 75554867ce
commit bdddc20d68
5 changed files with 99 additions and 58 deletions

View File

@@ -65,6 +65,7 @@ class Sd3NetworkTrainer(train_network.NetworkTrainer):
)
mmdit = sd3_utils.load_mmdit(state_dict, loading_dtype, "cpu")
self.model_type = mmdit.model_type
mmdit.set_pos_emb_random_crop_rate(args.pos_emb_random_crop_rate)
if args.fp8_base:
# check dtype of model