From 34f2315047f8d5b89b7a8a6093bb56679bff13c3 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 22:33:37 +0800 Subject: [PATCH 1/2] fix: text_encoder_conds referenced before assignment --- train_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 048c7e7b..628c421c 100644 --- a/train_network.py +++ b/train_network.py @@ -1081,12 +1081,12 @@ class NetworkTrainer: # print(f"set multiplier: {multipliers}") accelerator.unwrap_model(network).set_multiplier(multipliers) + text_encoder_conds = [] text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None) if text_encoder_outputs_list is not None: text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs if ( - text_encoder_conds is None - or len(text_encoder_conds) == 0 + len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder ): From 35882f8d5bbd076a97622cf6193c988621481803 Mon Sep 17 00:00:00 2001 From: Akegarasu Date: Thu, 29 Aug 2024 23:03:43 +0800 Subject: [PATCH 2/2] fix --- train_network.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/train_network.py b/train_network.py index 628c421c..4204bce3 100644 --- a/train_network.py +++ b/train_network.py @@ -1112,10 +1112,14 @@ class NetworkTrainer: if args.full_fp16: encoded_text_encoder_conds = [c.to(weight_dtype) for c in encoded_text_encoder_conds] - # if encoded_text_encoder_conds is not None, update cached text_encoder_conds - for i in range(len(encoded_text_encoder_conds)): - if encoded_text_encoder_conds[i] is not None: - text_encoder_conds[i] = encoded_text_encoder_conds[i] + # if text_encoder_conds is not cached, use encoded_text_encoder_conds + if len(text_encoder_conds) == 0: + text_encoder_conds = encoded_text_encoder_conds + else: + # if encoded_text_encoder_conds is not None, update cached text_encoder_conds + for i in range(len(encoded_text_encoder_conds)): + if encoded_text_encoder_conds[i] is not None: + text_encoder_conds[i] = encoded_text_encoder_conds[i] # sample noise, call unet, get target noise_pred, target, timesteps, huber_c, weighting = self.get_noise_pred_and_target(