diff --git a/train_native.py b/train_native.py
index 79418f4d..d8ff94f0 100644
--- a/train_native.py
+++ b/train_native.py
@@ -1258,6 +1258,7 @@ class NativeTrainer:
                     minimum_metadata[key] = metadata[key]
 
         # calculate steps to skip when resuming or starting from a specific step
+        # this is not used for logging or saving files; use global_step for those.
         initial_step = 0
         if args.initial_epoch is not None or args.initial_step is not None:
             # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming
@@ -1305,6 +1306,7 @@ class NativeTrainer:
                 epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
                 initial_step = 0  # do not skip
 
+        # This variable is used for logging and saving files:
         global_step = 0
 
         noise_scheduler = self.get_noise_scheduler(args, accelerator.device)
@@ -1343,7 +1345,8 @@ class NativeTrainer:
             for skip_epoch in range(epoch_to_start):  # skip epochs
                 logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}")
                 initial_step -= len(train_dataloader)
-            global_step = initial_step
+            # the leftover initial_step was multiplied by gradient_accumulation_steps, so divide it back; otherwise the logged global_step is wrong
+            global_step = int(initial_step / args.gradient_accumulation_steps)
 
             # log device and dtype for each model
             logger.info(f"unet dtype: {unet_weight_dtype}, device: {unet.device}")
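
For the last hunk, the key point is that the leftover `initial_step` counts dataloader batches (the log message calls it "multiplied"), while `global_step` counts optimizer steps, so the remainder after skipping whole epochs has to be divided back by `gradient_accumulation_steps`. Below is a minimal sketch of that accounting under that assumption; the helper name `resumed_global_step` and its parameters are illustrative only and not part of the patch:

```python
def resumed_global_step(initial_step: int, batches_per_epoch: int, accumulation_steps: int) -> tuple[int, int]:
    """Sketch of the resume accounting: initial_step is in dataloader batches,
    global_step is in optimizer steps (one optimizer step per accumulation_steps batches)."""
    epochs_to_skip = initial_step // batches_per_epoch                  # whole epochs already consumed
    leftover_batches = initial_step - epochs_to_skip * batches_per_epoch
    # this division mirrors the fixed line: global_step = int(initial_step / args.gradient_accumulation_steps)
    return epochs_to_skip, leftover_batches // accumulation_steps


# e.g. resuming after 2500 batches with 1000 batches per epoch and accumulation of 4:
# skip 2 full epochs, then resume logging from global_step = 500 // 4 = 125
print(resumed_global_step(2500, 1000, 4))  # -> (2, 125)
```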