
Conversation

@araleza commented Jul 23, 2025

If you've been doing full fine tuning with Flux, you've likely seen the image quality improvements that the stochastic rounding feature (available when using --fused_backward_pass with the supplied Adafactor optimizer) can offer. That feature improves the casting of the f32 'weight + gradient update' values back to bf16 on each training step: instead of always rounding to the nearest bf16 value, it randomly rounds either up or down to one of the two adjacent quantized bf16 values, with a probability proportional to where the f32 value falls between them.
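
For anyone who hasn't seen it, the core of that stochastic rounding trick fits in a few lines of PyTorch. This is just a minimal sketch of the idea rather than the exact code in sd-scripts: a random 16-bit integer is added to the raw f32 bits before the lower 16 bits are truncated away, so the value rounds up with a probability proportional to how far it sits between the two adjacent bf16 values.

import torch

def stochastic_round_to_bf16_(target: torch.Tensor, source: torch.Tensor) -> None:
    # reinterpret the f32 source as raw 32-bit integers
    bits = source.view(torch.int32)
    # a random value in [0, 2^16) nudges the lower mantissa bits
    noise = torch.randint(0, 1 << 16, bits.shape, dtype=torch.int32, device=source.device)
    # add the noise, then zero the lower 16 bits (-65536 == 0xFFFF0000)
    rounded = (bits + noise) & -65536
    # reinterpret as f32; the copy into the bf16 target is then exact
    target.copy_(rounded.view(torch.float32))

In expectation this matches the f32 update, but any individual weight can bounce up or down from one step to the next, which is the randomness the paper below is referring to.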

But although stochastic rounding offers a large improvement to bf16 training, its randomness doesn't always produce ideal results. As the "Revisiting BFloat16 Training" paper - https://arxiv.org/pdf/2010.06192 - notes:

"using stochastic rounding alone could not fully match 32-bit training on all models."

This same paper then goes on to note a different technique - Kahan summation - that does match the performance of 32-bit training:

"To address this, we show that Kahan summation for model weight updates closes remaining gaps on all the models we consider"

Kahan summation is a technique where the f32 value is rounded to the nearest bf16 value, and the 16-bit offset between the original f32 value and the quantized bf16 value is recorded. This offset is then sent to the CPU so it doesn't use up VRAM between steps. Then, on the next training step, instead of just taking the bf16 weight value plus that step's gradient update, the offset that was lost on the previous step is added back on as well. That process is then repeated every step.

This means that a bf16 weight whose update is (e.g.) 20% of the distance to the next adjacent bf16 value will step up exactly once in every 5 training steps, rather than stochastic rounding's random approximation of once in every 5 steps. Stochastically rounded values bounce up and down fairly unpredictably, whereas the Kahan summed values are more stable.
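
As a rough sketch of one weight update with that scheme (illustrative only; the residual buffer here is a plain f32 tensor kept in main memory, and the real code in this branch differs in the details):

import torch

def kahan_update_(param_bf16: torch.Tensor, update_f32: torch.Tensor,
                  residual_cpu: torch.Tensor) -> None:
    # rebuild a near-f32 view of the weight: bf16 value + the offset lost last step
    full = param_bf16.to(torch.float32) + residual_cpu.to(param_bf16.device)
    # apply this step's gradient update in f32
    full += update_f32
    # round back down to the nearest bf16 value
    new_bf16 = full.to(torch.bfloat16)
    # record what the rounding threw away and park it back in main memory
    residual_cpu.copy_(full - new_bf16.to(torch.float32))
    param_bf16.copy_(new_bf16)

Because the residual is carried forward, the 20% example above works out exactly: the leftover fraction accumulates until it crosses a full bf16 step, so the weight moves once every 5 steps.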

The technique comes at a price: training steps take around 40% longer in my tests so far (I'm training at batch size 5 on an RTX 5090 card). The slowdown comes from copying the values from the GPU to main memory and back again on each step. It's possible to use stochastic rounding for most of the training and then switch to Kahan summation with --kahan_summation for the final polish phase, but doing full runs start to finish with Kahan summation works fine too - and I'd recommend that, for best quality.

Now that I've implemented Kahan summation in sd-scripts, the quality improvements it achieves are impressive. Flux.dev training does seem to be one of the cases where Kahan summation significantly exceeds stochastic rounding in terms of image quality. Very low LRs such as 5e-7 (which suffer from quantization randomness under stochastic rounding) work great with Kahan summation.

If you (e.g. @kohya-ss) want to see what this feature can do, then why not grab this branch, switch on --kahan_summation, set the LR to 5e-7, and then try running a final polish pass on an already-trained FFT model that you have? The quality improvements are almost immediate, as it allows the weights to 'settle' into a more stable pattern.

@araleza (Author) commented Jul 29, 2025

I've now accelerated the Kahan summation function by sending only the 16 lower bits that are clipped from the f32 values to the CPU, instead of a floating point offset. With only half the number of bytes needing to be transferred, it's much quicker. It might be higher quality too, as the bit-identical f32 value that was quantized to bf16 is now restored.
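
Roughly speaking, the split and restore now look like this (an illustrative sketch, not the exact code in the branch; bf16 keeps the upper 16 bits of the f32 pattern, so the clipped lower 16 bits are enough to restore the original value exactly):

import torch

def split_f32(full_f32: torch.Tensor):
    bits = full_f32.view(torch.int32)
    # the 16 bits that bf16 throws away; the int16 cast just keeps the raw bits
    low16 = (bits & 0xFFFF).to(torch.int16).cpu()
    # keeping only the upper 16 bits gives exactly the bf16 value
    bf16 = (bits & -65536).view(torch.float32).to(torch.bfloat16)
    return bf16, low16

def restore_f32(bf16: torch.Tensor, low16_cpu: torch.Tensor) -> torch.Tensor:
    high = bf16.to(torch.float32).view(torch.int32)
    low = low16_cpu.to(high.device).to(torch.int32) & 0xFFFF
    # OR the two halves back together to get the bit-identical f32 weight
    return (high | low).view(torch.float32)

Only the int16 tensor crosses to main memory and back, which is where the 2x transfer saving comes from.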

@FurkanGozukara commented

@araleza this sounds amazing.

I recently compared the majority of optimizers and Adafactor was king.

I'm definitely going to test your branch.

Can you share your example toml / json? I know I need to add --kahan_summation, but what else?

@araleza (Author) commented Jul 29, 2025

I don't usually use toml / json, so I don't have one to share @FurkanGozukara. But yeah, switching this on is literally just adding --kahan_summation to the command line and that's that.

I don't know how to add that to a toml / json file, but it's just a command line option like the rest, so however you add those, it should work for --kahan_summation too.

I should really add a debug output message that says it's switched on, so you can confirm it's working. Give me a minute and I'll check that in now.

@FurkanGozukara commented

I don't usually use toml / json, so I don't have one to share @FurkanGozukara. But yeah, switching this on is literally just adding --kahan_summation to the command line and that's that.

Whatever format you use is fine; if you put it here as a txt file, I can compare what I am missing. I will definitely run a comparison training.

@araleza (Author) commented Jul 29, 2025

I'm running directly from the command line like this, @FurkanGozukara:

python flux_train.py --clip_l /home/ara/Dev/sd3/clip_l.safetensors --t5xxl /home/ara/Dev/flux/flanT5XXLTextEncorder_fp16.safetensors --ae /home/ara/Dev/flux/ae.safetensors --cache_latents_to_disk --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 --logging_dir="/home/ara/Dev/training/earthscape/kohya/log" --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 --full_bf16 --cache_text_encoder_outputs --output_name earthscape --cache_text_encoder_outputs_to_disk --max_train_epochs 900 --save_every_n_steps 40 --save_last_n_steps 80 --enable_bucket --min_bucket_reso=64 --bucket_no_upscale --max_bucket_reso=1024 --resolution="1024,1024" --caption_extension=".txt" --sample_sampler=k_dpm_2 --train_data_dir="/home/ara/Dev/training/earthscape/kohya/img" --enable_wildcard --output_dir output --apply_t5_attn_mask --guidance_scale 1.0 --model_prediction_type raw --sample_prompts="/home/ara/Dev/training/earthscape/kohya/dreambooth/sample/prompt.txt" --sample_every_n_steps="10" --debiased_estimation_loss --alpha_mask --timestep_sampling flux_shift --seed 42 --kahan_summation --loss_type huber --huber_c 1.0 --huber_scale 2.5 --huber_schedule='exponential' --optimizer_type adafactor --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False" --train_batch_size 5 --blocks_to_swap 35 --fused_backward_pass --flip_aug --learning_rate 4e-7 --pretrained_model_name_or_path "/home/ara/Dev/flux/flux1-dev.safetensors"

(I switched out my already-trained FFT model for the base flux1-dev model in that command line. If you're actually training from the base with Kahan summation, you probably want a higher LR of something like 2e-6, rather than a polishing LR of 4e-7.)

Edit: I removed the --ip_noise_gamma and --ip_noise_gamma_random_strength parameters from my command line, and learning seems much stronger without them, at least for my training datasets.

@blackmagic24 commented

The technique comes at a price: training steps take around 40% longer in my tests so far.

The StableAdamW optimizer also supports Kahan summation and can be used with sd-scripts without any issues. However, I haven’t noticed any significant performance differences compared to other optimizers. It might still be worth a look.

https://pytorch-optimizers.readthedocs.io/en/latest/optimizer/#pytorch_optimizer.StableAdamW

optimizer_type = "pytorch_optimizer.StableAdamW"
optimizer_args = [
"kahan_sum=True",
"weight_decay=0.01",
"weight_decouple=True",
]

@araleza (Author) commented Aug 1, 2025

Hi @blackmagic24, thanks for your reply.

I did notice that pytorch_optimizer's AdamW implementation has a feature that it refers to as kahan_sum, but I think it might be an unrelated feature that just shares the name of the technique, as Kahan summation is an idea that's applicable to more than one area. When I looked at the source code for pytorch_optimizer's AdamW implementation, the only change that activating its kahan_sum feature makes is this:

[screenshot of the kahan_sum code path in pytorch_optimizer's adamw.py]

That's from:

https://github.com/kozistr/pytorch_optimizer/blob/5a985ade9c6a887c6507a061d1ce16673882a155/pytorch_optimizer/optimizer/adamw.py#L16

I'm not exactly certain what that section of code does, but it doesn't seem to be keeping the lower bits of the weights between steps; rather, it's doing something with the exponential averages. And it doesn't seem to be sending the bits to the CPU between training steps, so it doesn't seem to have the same purpose as what my Kahan summation code is doing. But if you have any more information about this feature for AdamW, I'd be interested to know about it.

@araleza (Author) commented Aug 1, 2025

Actually, looking more at that AdamW code, it does seem to be doing a floating point sub_() of the bfloat16 value against the 'real' value, and an add_() to the kahan_comp variable. The parts with the exponential moving averages are the same between the kahan version and the non-kahan version, so it's not related to that at all.

I think this code is more like the first version of my pull request, where I kept an f32 offset on the CPU between training steps, rather than directly keeping the lower 16 bits that were lost. The offset isn't sent to the CPU between steps in that AdamW code though, so it wouldn't work in low-memory conditions like the Adafactor version in my pull request would.
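
For comparison, the generic GPU-resident form of the pattern looks something like this (a sketch of the general technique, not pytorch_optimizer's actual code):

import torch

def kahan_compensated_step_(p_bf16: torch.Tensor, update: torch.Tensor,
                            kahan_comp: torch.Tensor) -> None:
    # fold the previously lost precision into this step's update
    update = update + kahan_comp
    new = p_bf16.to(torch.float32) + update
    # quantize the weight back down to bf16
    p_bf16.copy_(new.to(torch.bfloat16))
    # whatever the bf16 round trip lost becomes the new compensation
    kahan_comp.copy_(new - p_bf16.to(torch.float32))

The arithmetic is the same idea as what I'm doing; the difference is just where the compensation buffer lives, which is what decides the VRAM cost.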

An interesting thing is that pytorch_optimizer's version of AdamW doesn't do stochastic rounding. While it's an interesting question whether Kahan summation beats stochastic rounding for Flux training or not, both of these should strongly outperform an optimizer that does neither. So you should see a large improvement for that AdamW optimizer if you switch on kahan_sum, and the fact that you aren't is curious.

I'm sure this is unlikely, but you're not actually training a LoRA instead of doing full fine-tuning, are you? Because LoRAs are f32 all the time, so kahan_sum would have no effect there.

I might give that AdamW optimizer a try at some point soon, although I'll need to check if it even can be run on my 32 GB graphics card when doing full fine tuning.

@FurkanGozukara commented

I have recently compared AdamW vs Adafactor, and Adafactor yielded the best realistic results.

@araleza (Author) commented Aug 1, 2025

I have recently compared AdamW vs Adafactor, and Adafactor yielded the best realistic results.

That might be because sd-scripts' current Adafactor does have stochastic rounding, and the AdamW implementation doesn't. AdamW should typically outperform Adafactor if both are using stochastic rounding - or both aren't - but AdamW uses much more memory, which is why projects often switch to Adafactor in low memory conditions.

@araleza (Author) commented Aug 1, 2025

I might give that AdamW optimizer a try at some point soon, although I'll need to check if it even can be run on my 32 GB graphics card when doing full fine tuning.

Yep, I can't run fine tuning with pytorch_optimizer.StableAdamW, even on my 32 GB graphics card. If I try to use that optimizer, I get the error message:

fused_backward_pass currently only works with optimizer_type Adafactor

If I switch off --fused_backward_pass to get the optimizer to work, I get a CUDA out of memory error, even with batch size 1 and 35 swapped blocks. So for anyone fine tuning on a GPU with 32 GB or less, whether the StableAdamW optimizer supports Kahan summation or not seems a moot point.

@FurkanGozukara commented

stochastic rounding

This makes sense, thanks for the info.

@kohya-ss (Owner) commented

Thanks, sorry for the late response. This is very interesting. I'll take a closer look at the code.

@kohya-ss (Owner) left a comment

I understand that having kahan_residuals on the CPU reduces the amount of GPU memory required. This is a kind of CPU offloading.

I think the code would be more organized if kahan_residuals were stored in state, but is that possible?


kahan_residuals = []
tensor_index = 0
prev_step = 0
@kohya-ss (Owner) commented on the code above:

Since step starts from 0, it would be better to set this to -1. The tensor_index of the first step starts from 1, which will cause a mismatch with the next step.

@araleza (Author) commented Aug 20, 2025

I think the code would be more organized if kahan_residuals were stored in state, but is that possible?

Yes, I've now made this change - and the code is simpler now too.

I think the step and tensor_index values that you mentioned were already correct, but perhaps the code was not clear there. However, now that I'm using the optimizer state to store the Kahan residuals, both step and tensor_index have been deleted from the copy_kahan_() function, so they are definitely not a problem now. 😊
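
In case it's useful for review, the residual lookup is now along these lines (an illustrative sketch; the state key name here is just an example, not necessarily what the PR uses):

import torch

def get_kahan_residual(optimizer: torch.optim.Optimizer, p: torch.Tensor) -> torch.Tensor:
    state = optimizer.state[p]
    if "kahan_residual" not in state:
        # lazily create a CPU buffer for the 16 bits clipped from this parameter
        state["kahan_residual"] = torch.zeros(p.shape, dtype=torch.int16, device="cpu")
    return state["kahan_residual"]

A nice side effect is that anything kept in state gets saved and restored with the optimizer's state_dict.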

@kohya-ss requested a review from Copilot on August 24, 2025 at 08:42
@Copilot (Copilot AI) left a comment

Pull Request Overview

This PR introduces Kahan summation as an alternative to stochastic rounding for bfloat16 training with the Adafactor optimizer in Flux model training. This technique preserves the precision lost during bf16 quantization by storing the residual bits and reapplying them in subsequent training steps.

Key changes:

  • Implements Kahan summation algorithm that offloads quantization residuals to CPU memory
  • Adds --kahan_summation command-line argument with compatibility validation
  • Modifies the Adafactor optimizer step to conditionally use Kahan summation instead of stochastic rounding

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

Files reviewed:

library/adafactor_fused.py - Implements the copy_kahan_ function and integrates it into the Adafactor parameter update process
flux_train.py - Adds the Kahan summation CLI argument, validation logic, and passes the setting to the optimizer


Comment on lines +36 to +46
"""
Copies source into target using Kahan summation.
The lower bits of the float32 weight that are lost on conversion to bfloat16
are sent to the CPU until the next step, where they are re-added onto the weights
before adding the gradient update. This produces near float32-like weight behavior,
although the copies back and forth to main memory result in slower training steps.
Args:
target: the target tensor with dtype=bfloat16
source: the target tensor with dtype=float32
@Copilot (Copilot AI) commented Aug 24, 2025

The docstring incorrectly describes the source parameter. Both target and source are described as 'the target tensor' - the source parameter should be described as 'the source tensor with dtype=float32'.

Suggested change:

"""
Copies source into target using Kahan summation.
The lower bits of the float32 weight that are lost on conversion to bfloat16
are sent to the CPU until the next step, where they are re-added onto the weights
before adding the gradient update. This produces near float32-like weight behavior,
although the copies back and forth to main memory result in slower training steps.
Args:
target: the target tensor with dtype=bfloat16
- source: the target tensor with dtype=float32
+ source: the source tensor with dtype=float32


@Copilot (Copilot AI) left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@araleza (Author) commented Aug 24, 2025

I've also included the Kahan summation code in my separate (and more impactful / higher image quality) AdamW for 5090 branch here:

#2187
