Reinstantiate weighted captions after a necessary revert to Main

This commit is contained in:
AI-Casanova
2023-04-02 19:43:34 +00:00
parent f037b09c2d
commit 1892c82a60
5 changed files with 360 additions and 20 deletions

View File

@@ -23,8 +23,7 @@ from library.config_util import (
BlueprintGenerator,
)
import library.custom_train_functions as custom_train_functions
from library.custom_train_functions import apply_snr_weight
from library.custom_train_functions import apply_snr_weight, get_weighted_text_embeddings
def train(args):
train_util.verify_training_args(args)
@@ -273,10 +272,19 @@ def train(args):
# Get the text embedding for conditioning
with torch.set_grad_enabled(global_step < args.stop_text_encoder_training):
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
if args.weighted_captions:
encoder_hidden_states = get_weighted_text_embeddings(tokenizer,
text_encoder,
batch["captions"],
accelerator.device,
args.max_token_length // 75 if args.max_token_length else 1,
clip_skip=args.clip_skip,
)
else:
input_ids = batch["input_ids"].to(accelerator.device)
encoder_hidden_states = train_util.get_hidden_states(
args, input_ids, tokenizer, text_encoder, None if not args.full_fp16 else weight_dtype
)
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
@@ -426,4 +434,4 @@ if __name__ == "__main__":
args = parser.parse_args()
args = train_util.read_config_from_file(args, parser)
train(args)
train(args)