Commit Graph

279 Commits

Author SHA1 Message Date
Cauldrath
e66f94a76c Adjustments to resuming training
Currently skips the resumed epoch if partway through
These changes make it resume mid epoch on the appropriate step
2024-06-30 23:40:53 -04:00
Kohya S
4dbcef429b update for corner cases 2024-06-04 21:26:55 +09:00
Kohaku-Blueleaf
3eb27ced52 Skip the final 1 step 2024-05-31 12:24:15 +08:00
Kohaku-Blueleaf
b2363f1021 Final implementation 2024-05-31 12:20:20 +08:00
Kohya S
da6fea3d97 simplify and update alpha mask to work with various cases 2024-05-19 21:26:18 +09:00
u-haru
db6752901f 画像のアルファチャンネルをlossのマスクとして使用するオプションを追加 (#1223)
* Add alpha_mask parameter and apply masked loss

* Fix type hint in trim_and_resize_if_required function

* Refactor code to use keyword arguments in train_util.py

* Fix alpha mask flipping logic

* Fix alpha mask initialization

* Fix alpha_mask transformation

* Cache alpha_mask

* Update alpha_masks to be on CPU

* Set flipped_alpha_masks to Null if option disabled

* Check if alpha_mask is None

* Set alpha_mask to None if option disabled

* Add description of alpha_mask option to docs
2024-05-19 19:07:25 +09:00
Kohya S
c68baae480 add --log_config option to enable/disable output training config 2024-05-19 17:21:04 +09:00
Kohya S
47187f7079 Merge pull request #1285 from ccharest93/main
Hyperparameter tracking
2024-05-19 16:31:33 +09:00
Kohya S
52e64c69cf add debug log 2024-05-04 18:43:52 +09:00
Kohya S
58c2d856ae support block dim/lr for sdxl 2024-05-03 22:18:20 +09:00
Kohya S
969f82ab47 move loraplus args from args to network_args, simplify log lr desc 2024-04-29 20:04:25 +09:00
Kohya S
834445a1d6 Merge pull request #1233 from rockerBOO/lora-plus
Add LoRA+ support
2024-04-29 18:05:12 +09:00
Kohya S
0540c33aca pop weights if available #1247 2024-04-21 17:45:29 +09:00
Kohya S
52652cba1a disable main process check for deepspeed #1247 2024-04-21 17:41:32 +09:00
Maatra
2c9db5d9f2 passing filtered hyperparameters to accelerate 2024-04-20 14:11:43 +01:00
rockerBOO
75833e84a1 Fix default LR, Add overall LoRA+ ratio, Add log
`--loraplus_ratio` added for both TE and UNet
Add log for lora+
2024-04-08 19:23:02 -04:00
Kohya S
d30ebb205c update readme, add metadata for network module 2024-04-07 14:58:17 +09:00
kabachuha
90b18795fc Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption (#1228)
* add huber loss and huber_c compute to train_util

* add reduction modes

* add huber_c retrieval from timestep getter

* move get timesteps and huber to own function

* add conditional loss to all training scripts

* add cond loss to train network

* add (scheduled) huber_loss to args

* fixup twice timesteps getting

* PHL-schedule should depend on noise scheduler's num timesteps

* *2 multiplier to huber loss cause of 1/2 a^2 conv.

The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another

* add option for smooth l1 (huber / delta)

* unify huber scheduling

* add snr huber scheduler

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-04-07 13:54:21 +09:00
ykume
cd587ce62c verify command line args if wandb is enabled 2024-04-05 08:23:03 +09:00
rockerBOO
f99fe281cb Add LoRA+ support 2024-04-01 15:38:26 -04:00
Kohya S
2258a1b753 add save/load hook to remove U-Net/TEs from state 2024-03-31 15:50:35 +09:00
Kohya S
c86e356013 Merge branch 'dev' into dataset-cache 2024-03-26 19:43:40 +09:00
Kohya S
ab1e389347 Merge branch 'dev' into masked-loss 2024-03-26 19:39:30 +09:00
Kohya S
a2b8531627 make each script consistent, fix to work w/o DeepSpeed 2024-03-25 22:28:46 +09:00
Kohya S
025347214d refactor metadata caching for DreamBooth dataset 2024-03-24 18:09:32 +09:00
Kohaku-Blueleaf
ae97c8bfd1 [Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization (#1178)
* support meta cached dataset

* add cache meta scripts

* random ip_noise_gamma strength

* random noise_offset strength

* use correct settings for parser

* cache path/caption/size only

* revert mess up commit

* revert mess up commit

* Update requirements.txt

* Add arguments for meta cache.

* remove pickle implementation

* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-03-24 15:40:18 +09:00
Kohya S
fbb98f144e Merge branch 'dev' into deep-speed 2024-03-20 18:15:26 +09:00
Kohya S
9b6b39f204 Merge branch 'dev' into masked-loss 2024-03-20 18:14:36 +09:00
Kohya S
bf6cd4b9da Merge pull request #1168 from gesen2egee/save_state_on_train_end
Save state on train end
2024-03-20 18:02:13 +09:00
Kohya S
3419c3de0d common masked loss func, apply to all training script 2024-03-17 19:30:20 +09:00
gesen2egee
d282c45002 Update train_network.py 2024-03-11 23:56:09 +08:00
gesen2egee
095b8035e6 save state on train end 2024-03-10 23:33:38 +08:00
Kohya S
e3ccf8fbf7 make deepspeed_utils 2024-02-27 21:30:46 +09:00
Kohya S
eefb3cc1e7 Merge branch 'deep-speed' into deepspeed 2024-02-27 18:57:42 +09:00
Kohya S
4a5546d40e fix typo 2024-02-26 23:39:56 +09:00
Kohya S
f2c727fc8c add minimal impl for masked loss 2024-02-26 23:19:58 +09:00
Kohya S
577e9913ca add some new dataset settings 2024-02-26 20:01:25 +09:00
Kohya S
f4132018c5 fix to work with cpu_count() == 1 closes #1134 2024-02-24 19:25:31 +09:00
BootsofLagrangian
4d5186d1cf refactored codes, some function moved into train_utils.py 2024-02-22 16:20:53 +09:00
Kohya S
baa0e97ced Merge branch 'dev' into dev_device_support 2024-02-17 11:54:07 +09:00
Kohya S
93bed60762 fix to work --console_log_xxx options 2024-02-12 14:49:29 +09:00
Kohya S
358ca205a3 Merge branch 'dev' into dev_device_support 2024-02-12 13:01:54 +09:00
Kohya S
e24d9606a2 add clean_memory_on_device and use it from training 2024-02-12 11:10:52 +09:00
Kohya S
055f02e1e1 add logging args for training scripts 2024-02-08 21:16:42 +09:00
BootsofLagrangian
62556619bd fix full_fp16 compatible and train_step 2024-02-07 16:42:05 +09:00
BootsofLagrangian
7d2a9268b9 apply offloading method runable for all trainer 2024-02-05 22:42:06 +09:00
BootsofLagrangian
4295f91dcd fix all trainer about vae 2024-02-05 20:19:56 +09:00
Yuta Hayashibe
5f6bf29e52 Replace print with logger if they are logs (#905)
* Add get_my_logger()

* Use logger instead of print

* Fix log level

* Removed line-breaks for readability

* Use setup_logging()

* Add rich to requirements.txt

* Make simple

* Use logger instead of print

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-02-04 18:14:34 +09:00
BootsofLagrangian
dfe08f395f support deepspeed 2024-02-04 03:12:42 +09:00
Disty0
a6a2b5a867 Fix IPEX support and add XPU device to device_utils 2024-01-31 17:32:37 +03:00