
Protect alphas_cumprod during refiner switchover #14979

Merged: 5 commits merged into dev on Mar 2, 2024


drhead (Contributor) commented Feb 21, 2024

Description

  • There is currently a bug, mentioned in "Make refiner switchover based on model timesteps instead of sampling steps" #14978: when the refiner switches on, the first step it performs uses the original alphas_cumprod schedule, which causes problems if zero terminal SNR is enabled.
  • This PR stores the model's alphas_cumprod before the refiner switchover and re-applies it afterwards. This fixes outputs in the rare cases where the change in noise schedule is large enough to move the called timestep outside the refiner's range.
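A minimal sketch of the store-and-reapply idea, in plain Python with lists standing in for torch tensors. `Model`, `load_checkpoint`, and `switch_model_keeping_schedule` are hypothetical stand-ins, not the webui's actual classes or functions:

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class Model:
    # Hypothetical stand-in for a loaded checkpoint; only the schedule matters here.
    name: str
    alphas_cumprod: List[float] = field(default_factory=list)


def load_checkpoint(name: str) -> Model:
    # Hypothetical loader: a freshly loaded model brings its own, un-overridden schedule.
    return Model(name=name, alphas_cumprod=[0.99, 0.90, 0.50, 0.10])


def switch_model_keeping_schedule(current: Model, new_name: str,
                                  loader: Callable[[str], Model]) -> Model:
    # Store the schedule currently in effect (possibly already overridden,
    # e.g. rescaled for zero terminal SNR) before the switch...
    saved = list(current.alphas_cumprod)
    new_model = loader(new_name)
    # ...and re-apply it, so the step right after the switch uses the same schedule.
    new_model.alphas_cumprod = saved
    return new_model


base = Model("base", alphas_cumprod=[0.99, 0.85, 0.40, 4.9e-08])  # overridden schedule
refiner = switch_model_keeping_schedule(base, "refiner", load_checkpoint)
print(refiner.name, refiner.alphas_cumprod)
```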

Screenshots/videos:

Before the fix: image generated with DPM++ 2M, overridden with the Karras schedule and a sigma_max of 1500, 50 steps:
[Image 99942-36463525: output before the fix]
This specific schedule causes one of the sampling steps to be changed from timestep 190 to timestep 200. The highest timestep a typical refiner is trained for is 199 (it covers the last 200 timesteps, zero-indexed), so timestep 200 is outside the model's range and causes extra noise in the output.

After the fix is applied:
[Image 99943-36463525: output after the fix]
The resulting image looks much cleaner (particularly the background at the top left).
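The out-of-range condition from the before-fix example reduces to a bounds check. This sketch assumes, as described above, that the refiner covers timesteps 0 through 199; `REFINER_MAX_TIMESTEP` and `refiner_can_handle` are illustrative names:

```python
REFINER_MAX_TIMESTEP = 199  # last 200 timesteps, zero-indexed (0 through 199)


def refiner_can_handle(timestep: int) -> bool:
    # A timestep is usable only if the refiner was trained on it.
    return 0 <= timestep <= REFINER_MAX_TIMESTEP


for t in (190, 200):
    print(t, refiner_can_handle(t))
```

With the timesteps from the example, 190 passes the check and 200 fails it.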


drhead closed this Feb 21, 2024
drhead reopened this Feb 21, 2024
drhead changed the base branch from master to dev February 21, 2024 00:37
AUTOMATIC1111 (Owner) commented:

reload_model_weights is used in many places, and if there are issues with it working, those should be solved in it, rather than outside of it.

drhead (Contributor, Author) commented Feb 26, 2024

I'm not sure this should really be considered an issue with reload_model_weights so much as an issue with how the main processing loop handles alphas_cumprod: as temporary state that reflects the current settings for the current sampling step and is updated again on the next step. That doesn't really apply to any other context where reload_model_weights is called (three of which are initialization-related and one of which is related to the XYZ plot), and I would prefer to avoid a change that could cause problems in places where this doesn't apply.

I could look into moving it into the reload function, but I think that would have more potential to cause issues in other parts of the code, and I would need to test it further to verify that it doesn't. In the meantime, I've removed the protection of alphas_cumprod_original, since it isn't necessary to fix the bug.
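To illustrate the "temporary state" behavior: the override is re-derived from the original schedule at the start of every step, so a model reload in the middle of sampling leaves exactly one step on the un-overridden schedule. The simulation below uses plain lists; `apply_override`, `simulate`, and the step and switch numbers are hypothetical:

```python
from typing import List

ORIGINAL = [0.99, 0.90, 0.50, 0.10]  # hypothetical checkpoint schedule


def apply_override(schedule: List[float]) -> List[float]:
    # Hypothetical runtime override: push the final value close to zero,
    # the way a zero-terminal-SNR rescale does.
    return schedule[:-1] + [4.8973451890853435e-08]


def simulate(steps: int, switch_at: int, protect: bool) -> List[bool]:
    """Return, for each step, whether that step ran with the overridden schedule."""
    ran_with_override = []
    for step in range(steps):
        # The override is re-applied from the original values at the start of each step.
        schedule = apply_override(ORIGINAL)
        if step == switch_at:
            saved = schedule
            schedule = list(ORIGINAL)  # reloading weights restores the checkpoint schedule
            if protect:
                schedule = saved  # carry the override across the switch
        ran_with_override.append(schedule == apply_override(ORIGINAL))
    return ran_with_override


print(simulate(steps=6, switch_at=4, protect=False))
print(simulate(steps=6, switch_at=4, protect=True))
```

Without protection only the switch step (index 4) runs on the checkpoint schedule and the next step is correct again, which matches the single bad refiner step described in this thread.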

AUTOMATIC1111 (Owner) commented Feb 26, 2024

I just don't get this or why it helps. How does alphas_cumprod end up incorrect after sd_models.reload_model_weights?

Edit: I guess it's because of this code in processing.py:

                if opts.sd_noise_schedule == "Zero Terminal SNR":
                    p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
                    p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)

But then the same problem should also occur when switching models during hires fix.

drhead (Contributor, Author) commented Feb 26, 2024

It is related to the implementation of zero terminal SNR and the compatibility fix that came with it. Either of those requires changing the model's alphas_cumprod, and the current implementation does this at runtime, before each sampling step, as an override.

When the model weights are reloaded, this override is undone, because the newly loaded model comes with its own alphas_cumprod (which in almost every case will be the same as the other model's default, but which will differ from the override if we changed the values). As a result, the wrong timestep is called for the sigma used on that sampling step. In my example, a sigma of 0.5725 corresponds to timestep 190 under the zero terminal SNR schedule but to timestep 200 under the default schedule. That is a problem not only because the schedule is now wrong, but because the refiner is now sampling from a timestep it was never meant to sample from. The alphas_cumprod override is reapplied on the next step, but by then the bad refiner step has already introduced artifacts into the final image. The logical solution, as I see it, is to maintain the override through the model switch.

I'll look into whether hires fix has the same issue, though I would expect it to be far more benign there in any case.
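The sigma-to-timestep mismatch can be sketched in plain Python. This assumes the Stable Diffusion "scaled linear" beta schedule (0.00085 to 0.012 over 1000 timesteps), the relation sigma = sqrt((1 - alphas_cumprod) / alphas_cumprod), and a nearest-match lookup in log-sigma space; the rescale is a list-based port of the function quoted in this thread, and the helper names are illustrative:

```python
import math
from typing import List


def scaled_linear_alphas_cumprod(n: int = 1000, beta_start: float = 0.00085,
                                 beta_end: float = 0.012) -> List[float]:
    # Betas are spaced linearly in sqrt space, then squared ("scaled linear").
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, prod = [], 1.0
    for i in range(n):
        beta = (lo + (hi - lo) * i / (n - 1)) ** 2
        prod *= 1.0 - beta
        out.append(prod)
    return out


def rescale_zero_terminal_snr_abar(alphas_cumprod: List[float]) -> List[float]:
    # Shift sqrt(alphas_cumprod) so the last value is zero, then rescale so the first is unchanged.
    s = [math.sqrt(a) for a in alphas_cumprod]
    s0, s_t = s[0], s[-1]
    s = [(x - s_t) * s0 / (s0 - s_t) for x in s]
    abar = [x * x for x in s]
    abar[-1] = 4.8973451890853435e-08  # keep the terminal sigma finite
    return abar


def sigmas(alphas_cumprod: List[float]) -> List[float]:
    return [math.sqrt((1 - a) / a) for a in alphas_cumprod]


def nearest_timestep(sigma: float, schedule_sigmas: List[float]) -> int:
    # Nearest match in log-sigma space.
    log_sigma = math.log(sigma)
    return min(range(len(schedule_sigmas)),
               key=lambda t: abs(math.log(schedule_sigmas[t]) - log_sigma))


default = scaled_linear_alphas_cumprod()
ztsnr = rescale_zero_terminal_snr_abar(default)
sigma = 0.5725
print("default schedule:", nearest_timestep(sigma, sigmas(default)))
print("zero terminal SNR:", nearest_timestep(sigma, sigmas(ztsnr)))
```

Because the rescale lowers alphas_cumprod at every timestep after the first, a given sigma maps to an earlier timestep under zero terminal SNR than under the default schedule. The exact indices depend on how the lookup rounds, so this sketch is not guaranteed to reproduce the 190 versus 200 pair from the thread exactly.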

AUTOMATIC1111 (Owner) commented:

If I'm understanding this right, the proper fix, I think, would be to take this code from processing:

            def rescale_zero_terminal_snr_abar(alphas_cumprod):
                alphas_bar_sqrt = alphas_cumprod.sqrt()

                # Store old values.
                alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
                alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

                # Shift so the last timestep is zero.
                alphas_bar_sqrt -= (alphas_bar_sqrt_T)

                # Scale so the first timestep is back to the old value.
                alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

                # Square to convert alphas_bar_sqrt back to alphas_cumprod
                alphas_bar = alphas_bar_sqrt**2
                # Keep the final value slightly above zero so the derived sigma stays finite
                alphas_bar[-1] = 4.8973451890853435e-08
                return alphas_bar

            if hasattr(p.sd_model, 'alphas_cumprod') and hasattr(p.sd_model, 'alphas_cumprod_original'):
                p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod_original.to(shared.device)

                if opts.use_downcasted_alpha_bar:
                    p.extra_generation_params['Downcast alphas_cumprod'] = opts.use_downcasted_alpha_bar
                    p.sd_model.alphas_cumprod = p.sd_model.alphas_cumprod.half().to(shared.device)
                if opts.sd_noise_schedule == "Zero Terminal SNR":
                    p.extra_generation_params['Noise Schedule'] = opts.sd_noise_schedule
                    p.sd_model.alphas_cumprod = rescale_zero_terminal_snr_abar(p.sd_model.alphas_cumprod).to(shared.device)

and put it in a separate function, then call that function both from processing, where the code currently lives, and from sd_models.load_model_weights.
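A rough sketch of that refactor, with the override pulled into one function that both the processing path and the weight-loading path call. `apply_alphas_cumprod_override`, `Options`, `Model`, `load_model_weights`, `prepare_sampling_step`, and `pin_last_value` are hypothetical stand-ins (the real code operates on torch tensors and shared options), and the zero-terminal-SNR step is passed in as a callable to keep the example short:

```python
import struct
from dataclasses import dataclass
from typing import Callable, List

Rescale = Callable[[List[float]], List[float]]


@dataclass
class Options:
    # Hypothetical mirror of the two settings read by the quoted code.
    use_downcasted_alpha_bar: bool = False
    sd_noise_schedule: str = "Default"


@dataclass
class Model:
    alphas_cumprod_original: List[float]
    alphas_cumprod: List[float]


def to_half(x: float) -> float:
    # Round-trip through IEEE 754 half precision, similar to tensor.half().
    return struct.unpack("e", struct.pack("e", x))[0]


def apply_alphas_cumprod_override(model: Model, opts: Options, rescale: Rescale) -> None:
    # Always start from the checkpoint's own schedule, then layer the overrides on top.
    schedule = list(model.alphas_cumprod_original)
    if opts.use_downcasted_alpha_bar:
        schedule = [to_half(a) for a in schedule]
    if opts.sd_noise_schedule == "Zero Terminal SNR":
        schedule = rescale(schedule)
    model.alphas_cumprod = schedule


def load_model_weights(checkpoint_schedule: List[float], opts: Options, rescale: Rescale) -> Model:
    # Hypothetical loader: re-applying the override here means a mid-sampling
    # model switch never exposes the raw checkpoint schedule.
    model = Model(list(checkpoint_schedule), list(checkpoint_schedule))
    apply_alphas_cumprod_override(model, opts, rescale)
    return model


def prepare_sampling_step(model: Model, opts: Options, rescale: Rescale) -> None:
    # The processing path calls the same function, so both paths agree.
    apply_alphas_cumprod_override(model, opts, rescale)


def pin_last_value(schedule: List[float]) -> List[float]:
    # Stand-in for rescale_zero_terminal_snr_abar: only pins the final value.
    return schedule[:-1] + [4.8973451890853435e-08]


opts = Options(sd_noise_schedule="Zero Terminal SNR")
base = Model([0.99, 0.9, 0.5, 0.1], [0.99, 0.9, 0.5, 0.1])
prepare_sampling_step(base, opts, pin_last_value)
refiner = load_model_weights([0.99, 0.9, 0.5, 0.1], opts, pin_last_value)
print(refiner.alphas_cumprod == base.alphas_cumprod)  # True
```

Keeping a single override function means the two call sites cannot drift apart when new schedule options are added.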

drhead (Contributor, Author) commented Feb 27, 2024

I'm not entirely sure where the best place for the alpha override function is. I had to put it in sd_models.py to avoid circular imports, but it can easily be moved to any other location that doesn't cause circular import errors. The fix works as it stands.

AUTOMATIC1111 merged commit 06b9200 into AUTOMATIC1111:dev Mar 2, 2024
3 checks passed
AUTOMATIC1111 added three commits that referenced this pull request on Mar 2, 2024, including "Protect alphas_cumprod during refiner switchover".