Commit Graph

404 Commits

Author SHA1 Message Date
gesen2egee
cdb2d9c516 Update train_network.py 2024-08-04 17:36:34 +08:00
gesen2egee
aa850aa531 Update train_network.py 2024-08-04 17:34:20 +08:00
gesen2egee
f6dbf7c419 Update train_network.py 2024-08-04 15:18:53 +08:00
gesen2egee
a593e837f3 Update train_network.py 2024-08-04 15:17:30 +08:00
gesen2egee
b9bdd10129 Update train_network.py 2024-08-04 15:11:26 +08:00
gesen2egee
31507b9901 Remove unnecessary is_train changes and use apply_debiased_estimation to calculate validation loss. Balances the influence of different time steps on training performance (without affecting actual training results) 2024-08-02 13:15:21 +08:00
Kohya S
41dee60383 Refactor caching mechanism for latents and text encoder outputs, etc. 2024-07-27 13:50:05 +09:00
Kohya S
4dbcef429b update for corner cases 2024-06-04 21:26:55 +09:00
Kohaku-Blueleaf
3eb27ced52 Skip the final 1 step 2024-05-31 12:24:15 +08:00
Kohaku-Blueleaf
b2363f1021 Final implementation 2024-05-31 12:20:20 +08:00
Kohya S
da6fea3d97 simplify and update alpha mask to work with various cases 2024-05-19 21:26:18 +09:00
u-haru
db6752901f 画像のアルファチャンネルをlossのマスクとして使用するオプションを追加 (#1223)
* Add alpha_mask parameter and apply masked loss

* Fix type hint in trim_and_resize_if_required function

* Refactor code to use keyword arguments in train_util.py

* Fix alpha mask flipping logic

* Fix alpha mask initialization

* Fix alpha_mask transformation

* Cache alpha_mask

* Update alpha_masks to be on CPU

* Set flipped_alpha_masks to Null if option disabled

* Check if alpha_mask is None

* Set alpha_mask to None if option disabled

* Add description of alpha_mask option to docs
2024-05-19 19:07:25 +09:00
Kohya S
c68baae480 add --log_config option to enable/disable output training config 2024-05-19 17:21:04 +09:00
Kohya S
47187f7079 Merge pull request #1285 from ccharest93/main
Hyperparameter tracking
2024-05-19 16:31:33 +09:00
Kohya S
52e64c69cf add debug log 2024-05-04 18:43:52 +09:00
Kohya S
58c2d856ae support block dim/lr for sdxl 2024-05-03 22:18:20 +09:00
Kohya S
969f82ab47 move loraplus args from args to network_args, simplify log lr desc 2024-04-29 20:04:25 +09:00
Kohya S
834445a1d6 Merge pull request #1233 from rockerBOO/lora-plus
Add LoRA+ support
2024-04-29 18:05:12 +09:00
Kohya S
0540c33aca pop weights if available #1247 2024-04-21 17:45:29 +09:00
Kohya S
52652cba1a disable main process check for deepspeed #1247 2024-04-21 17:41:32 +09:00
Maatra
2c9db5d9f2 passing filtered hyperparameters to accelerate 2024-04-20 14:11:43 +01:00
gesen2egee
086f6000f2 Merge branch 'main' into val 2024-04-11 01:14:46 +08:00
rockerBOO
75833e84a1 Fix default LR, Add overall LoRA+ ratio, Add log
`--loraplus_ratio` added for both TE and UNet
Add log for lora+
2024-04-08 19:23:02 -04:00
Kohya S
d30ebb205c update readme, add metadata for network module 2024-04-07 14:58:17 +09:00
kabachuha
90b18795fc Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption (#1228)
* add huber loss and huber_c compute to train_util

* add reduction modes

* add huber_c retrieval from timestep getter

* move get timesteps and huber to own function

* add conditional loss to all training scripts

* add cond loss to train network

* add (scheduled) huber_loss to args

* fixup twice timesteps getting

* PHL-schedule should depend on noise scheduler's num timesteps

* *2 multiplier to huber loss cause of 1/2 a^2 conv.

The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another

* add option for smooth l1 (huber / delta)

* unify huber scheduling

* add snr huber scheduler

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-04-07 13:54:21 +09:00
ykume
cd587ce62c verify command line args if wandb is enabled 2024-04-05 08:23:03 +09:00
rockerBOO
f99fe281cb Add LoRA+ support 2024-04-01 15:38:26 -04:00
Kohya S
2258a1b753 add save/load hook to remove U-Net/TEs from state 2024-03-31 15:50:35 +09:00
Kohya S
c86e356013 Merge branch 'dev' into dataset-cache 2024-03-26 19:43:40 +09:00
Kohya S
ab1e389347 Merge branch 'dev' into masked-loss 2024-03-26 19:39:30 +09:00
Kohya S
a2b8531627 make each script consistent, fix to work w/o DeepSpeed 2024-03-25 22:28:46 +09:00
Kohya S
025347214d refactor metadata caching for DreamBooth dataset 2024-03-24 18:09:32 +09:00
Kohaku-Blueleaf
ae97c8bfd1 [Experimental] Add cache mechanism for dataset groups to avoid long waiting time for initilization (#1178)
* support meta cached dataset

* add cache meta scripts

* random ip_noise_gamma strength

* random noise_offset strength

* use correct settings for parser

* cache path/caption/size only

* revert mess up commit

* revert mess up commit

* Update requirements.txt

* Add arguments for meta cache.

* remove pickle implementation

* Return sizes when enable cache

---------

Co-authored-by: Kohya S <52813779+kohya-ss@users.noreply.github.com>
2024-03-24 15:40:18 +09:00
Kohya S
fbb98f144e Merge branch 'dev' into deep-speed 2024-03-20 18:15:26 +09:00
Kohya S
9b6b39f204 Merge branch 'dev' into masked-loss 2024-03-20 18:14:36 +09:00
Kohya S
bf6cd4b9da Merge pull request #1168 from gesen2egee/save_state_on_train_end
Save state on train end
2024-03-20 18:02:13 +09:00
Kohya S
3419c3de0d common masked loss func, apply to all training script 2024-03-17 19:30:20 +09:00
gesen2egee
d05965dbad Update train_network.py 2024-03-13 18:33:51 +08:00
gesen2egee
5d7ed0dff0 Merge remote-tracking branch 'kohya-ss/dev' into val 2024-03-13 18:00:49 +08:00
gesen2egee
bd7e2295b7 fix 2024-03-13 17:54:21 +08:00
gesen2egee
d282c45002 Update train_network.py 2024-03-11 23:56:09 +08:00
gesen2egee
a6c41c6bea Update train_network.py 2024-03-11 19:23:48 +08:00
gesen2egee
63e58f78e3 Update train_network.py 2024-03-11 19:15:55 +08:00
gesen2egee
befbec5335 Update train_network.py 2024-03-11 18:47:04 +08:00
gesen2egee
a51723cc2a fix timesteps 2024-03-11 09:42:58 +08:00
gesen2egee
095b8035e6 save state on train end 2024-03-10 23:33:38 +08:00
gesen2egee
47359b8fac Update train_network.py 2024-03-10 20:17:40 +08:00
gesen2egee
923b761ce3 Update train_network.py 2024-03-10 20:01:40 +08:00
gesen2egee
78cfb01922 improve 2024-03-10 18:55:48 +08:00
gesen2egee
b558a5b73d val 2024-03-10 04:37:16 +08:00