Diffusion models exhibit impressive generative capabilities but are significantly impacted by exposure bias. In this paper, we make a key observation: the energy of the predicted noisy images decreases during the diffusion process. Building on this, we identify two important findings: 1) The reduction in energy follows distinct patterns in the low-frequency and high-frequency subbands; 2) This energy reduction results in amplitude variations between the network-reconstructed clean data and the real clean data. Based on the first finding, we introduce a frequency-domain regulation mechanism utilizing wavelet transforms, which separately adjusts the low- and high-frequency subbands. Leveraging the second insight, we provide a more accurate analysis of exposure bias in the two subbands. Our method is training-free and plug-and-play, significantly improving the generative quality of various diffusion models and providing a robust solution to exposure bias across different model architectures.
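The low- and high-frequency subband energies mentioned above can be made concrete with a small NumPy sketch. It assumes "energy" means the sum of squared wavelet coefficients (the paper's exact definition may differ). It also uses a hand-written one-level orthonormal 2D Haar transform instead of pytorch_wavelets. The names haar_dwt2 and subband_energy are hypothetical helpers, not part of this repository.

```python
import numpy as np

def haar_dwt2(x: np.ndarray):
    """One-level orthonormal 2D Haar DWT of an (H, W) array with even H and W.

    Returns the low-frequency subband and a tuple of the three
    high-frequency (detail) subbands, each of shape (H/2, W/2).
    """
    a = x[0::2, 0::2]  # top-left pixel of each 2x2 block
    b = x[0::2, 1::2]  # top-right
    c = x[1::2, 0::2]  # bottom-left
    d = x[1::2, 1::2]  # bottom-right
    low = (a + b + c + d) / 2     # scaled block average
    detail_1 = (a + b - c - d) / 2  # top row minus bottom row
    detail_2 = (a - b + c - d) / 2  # left column minus right column
    detail_3 = (a - b - c + d) / 2  # diagonal difference
    return low, (detail_1, detail_2, detail_3)

def subband_energy(x: np.ndarray):
    """Sum of squared coefficients in the low- and high-frequency subbands."""
    low, highs = haar_dwt2(x)
    e_low = float(np.sum(low ** 2))
    e_high = float(sum(np.sum(h ** 2) for h in highs))
    return e_low, e_high

rng = np.random.default_rng(0)
img = rng.standard_normal((32, 32))
e_low, e_high = subband_energy(img)
# The transform is orthonormal, so the two energies add up to the image energy
print(e_low, e_high, np.isclose(e_low + e_high, np.sum(img ** 2)))
```

Because the Haar step is orthonormal, the low- and high-frequency energies sum to the total image energy. Tracking the two terms separately across sampling steps is therefore one way to look at how an overall energy reduction splits between the subbands.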
This is the codebase for our paper "Frequency Regulation for Exposure Bias Mitigation in Diffusion Models" (ACM MM 2025). The repository is heavily based on EDM. For environment setup, dataset preparation, loading pre-trained models, and FID calculation, please refer to the official EDM code repository.
After setting up the environment required for the EDM base model, also install the packages for the wavelet transform:
pip install pytorch_wavelets
pip install PyWavelets
Now, we integrate the wavelet-based frequency-domain regulation mechanism into the EDM base model.
(1) Import the required packages.
import torch
from pytorch_wavelets import DWTForward, DWTInverse
(2) Define the DWT and IDWT.
# One-level (J=1) 2D DWT and its inverse: Haar wavelet, zero padding
dwt = DWTForward(J=1, mode='zero', wave='haar').cuda()
iwt = DWTInverse(mode='zero', wave='haar').cuda()
The wave argument can also be set to other wavelet bases, such as db4 and sym8.
(3) Define the frequency-regulation functions.
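def frequency_x_low(x, scaler):
    # Scale the low-frequency (approximation) subband of x by scaler
    x = x.to(torch.float32)
    xl, xh = dwt(x)        # xl: low-frequency subband; xh: list of high-frequency subbands
    xl = scaler * xl
    x_new = iwt((xl, xh))  # reconstruct from the modified subbands
    return x_new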
def frequency_x_high(x, scaler):
    # Scale all high-frequency (detail) subbands of x by scaler
    x = x.to(torch.float32)
    xl, xh = dwt(x)
    xh = [h * scaler for h in xh]  # one tensor per decomposition level
    x_new = iwt((xl, xh))
    return x_new
(4) Frequency regulation. Low-frequency regulation (decreasing type), as shown in Eq. 17 of our paper:
# At sampling step i, scale the low-frequency subband of x_next by 1 + 0.036 * t_i / sigma_max
x_next = frequency_x_low(x_next, 1 + (t_steps[i] / sigma_max) * 0.036)
High-frequency regulation (switch-off type), as shown in Eq. 18 of our paper:
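if i >= 10:
    # eps_scaler: the high-frequency scaling factor (cf. Eq. 18); applied only from step 10 onward
    x_next = frequency_x_high(x_next, eps_scaler)
We would appreciate it if you cited the following paper: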
@InProceedings{yumACMmm25,
author = {Meng Yu and Kun Zhan},
booktitle = {ACM Multimedia},
title = {Frequency Regulation for Exposure Bias Mitigation in Diffusion Models},
year = {2025},
volume = {33},
}
If you have any questions, feel free to contact me (email: ice.echo#gmail.com).
