diff --git a/train_network.py b/train_network.py index 628c421c..4204bce3 100644 --- a/train_network.py +++ b/train_network.py @@ -1112,10 +1112,14 @@ class NetworkTrainer: if args.full_fp16: encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(