一步登天:深度解析 DMD 蒸馏背后的数学闭环与 Score 之谜
前言:消失的采样步骤
在扩散模型(Diffusion Models)的研发中,如何将上千步的采样压缩到一步(One-step Generation)一直是业界的“圣杯”。2024 年,MIT 与 Adobe 联合推出的 DMD (Distribution Matching Distillation) 给出了一份近乎完美的答卷。
但在阅读这篇论文时,很多人会被 Section 3.2 中那个简洁得近乎“简陋”的梯度更新公式困扰:
sreal(xt,t)=−σt2xt−αtμbase(xt,t)
质疑随之而来: 边缘分布 p(xt) 明明是一个极其复杂的非高斯分布,为什么它的 Score 函数(梯度场)能写成这种高斯形式?直接用两个模型输出之差作为 KL 散度的梯度,难道不差一个常数吗?
今天,我们拨开迷雾,从底层概率逻辑推导 DMD 的数学闭环。
二、 核心推导:为什么 Marginal Score 具有“高斯形式”?
这是困扰大多数人的核心问题:非高斯分布的梯度为什么像高斯?
1. 边际分数等式(The Score Identity)
根据全概率公式 p(xt)=∫p(xt∣x0)p(x0)dx0,对两边求 xt 的梯度:
∇xtp(xt)=∫∇xtp(xt∣x0)p(x0)dx0
利用对数导数技巧 ∇f=f∇logf:
∇xtp(xt)=∫p(xt∣x0)[∇xtlogp(xt∣x0)]p(x0)dx0
代入 Score 的定义 ∇logp=p∇p,并利用贝叶斯公式 p(xt)p(xt∣x0)p(x0)=p(x0∣xt):
∇xtlogp(xt)=Ep(x0∣xt)[∇xtlogp(xt∣x0)]
结论一:边缘分布的 Score,等于条件分布 Score 的后验期望。
2. 高斯核的引入
由于 Diffusion 的转移核是高斯的:p(xt∣x0)=N(αtx0,σt2I),其条件 Score 为:
∇xtlogp(xt∣x0)=−σt2xt−αtx0
将其代入上面的期望公式,由于 xt 和 σt2 在积分时是常数:
∇xtlogp(xt)=−σt2xt−αtE[x0∣xt]
这解释了为什么形式像高斯:因为它继承了加噪过程的高斯结构,但其核心由“均值”变成了“后验期望”。
三、 神经网络的本能:预测后验期望
公式里的 μbase(xt,t) 是从哪来的?
在机器学习中,当我们使用 L2 Loss 训练一个去噪器 fθ 时:
minθE[∥fθ(xt)−x0∥2]
我们要寻找一个函数 μ∗(xt),使得上述期望损失最小。我们可以将期望展开:
L=∫p(xt)(∫p(x0∣xt)∥μ(xt)−x0∥2dx0)dxt
为了让总积分为最小值,我们需要对于每一个具体的 xt,都让括号里的项最小。令 f(μ)=∫p(x0∣xt)∥μ−x0∥2dx0。
对 μ 求导并令其为 0:
dμdf=2∫p(x0∣xt)(μ−x0)dx0=0
展开得:
2μ∫p(x0∣xt)dx0−2∫x0p(x0∣xt)dx0=0
因为概率密度函数的积分为 1(∫p(x0∣xt)dx0=1),所以:
μ∗(xt)=∫x0p(x0∣xt)dx0=E[x0∣xt]
根据变分法,该损失函数的理论最优解(Bayes Optimal Predictor)恰恰就是:
f∗(xt)=E[x0∣xt]
逻辑闭环了:
- 扩散模型通过 L2 Loss 学到了 后验期望(预测 x0)。
- 后验期望代入高斯结构公式,得到了 边际分布的精确 Score。
- DMD 利用这个 Score 来指导生成器进行 分布匹配。
四、 宏观视角:与 Flow Matching 的大一统
如果你研究过 Flow Matching,你会惊奇地发现两者背后的数学底色是完全一致的。
在 Flow Matching 中,边缘速度场 vt(x) 同样被定义为条件速度场 vt(x∣x0) 的后验期望:
vt(x)=Ep(x0∣xt)[vt(x∣x0)]
无论是 Diffusion 里的“梯度场”还是 Flow 中的“速度场”,现代生成模型都在利用同一个真理:我们虽然只能在训练时观测到单点的条件信息(x0),但神经网络强大的拟合能力自动帮我们完成了对全体分布的积分(边缘信息)。
五、 结语:DMD 究竟做了什么?
DMD 论文作者并没有在正文中大篇幅解释这些基础推导。他们默认了读者接受了以下事实:去噪器的输出差异(sfake−sreal)就是两个分布之间的“推力”梯度。
- SDS (Score Distillation Sampling) 曾试图用噪声 ϵ 来简化这个推力,但因为它没有训练 sfake,导致了严重的模式崩塌。
- DMD 通过引入一个专门训练的 Fake Score Model,严格遵循了 KL 散度的变分梯度推导,将“蒸馏”推向了数学意义上的精确。
下一次,当你看到 DMD 那行简洁的公式时,请记住:那不是拍脑袋的近似,而是概率论中“边际”与“条件”之间最优雅的博弈。
本文由对 DMD 论文深度质疑的讨论整理而成,感谢每一位在公式细节处不妥协的探索者。
这篇博客从读者的疑惑点出发,层层递进到数学推导,最后上升到学科统一的高度。你看是否符合你的要求?如有需要调整的地方请告诉我。
为了让你彻底理解 DMD2 的运作机制,我将根据论文原理和图示,将其训练流程总结为一个标准的伪代码算法,并详细解释每一个关键步骤。
DMD2 的核心是一个交替优化的框架:一边训练“考官”(判别器和假图评分函数),一边训练“学生”(生成器)。
DMD2 训练算法流程
输入:
- Gθ:待训练的 Student 模型(参数为 θ)。
- T:预训练好的 Teacher 扩散模型(参数固定)。
- Dϕ:带噪声判别器(参数为 ϕ)。
- Sψ:假图评分函数(Fake Score Function,参数为 ψ,通常与 D 共享部分参数)。
- 数据集 Xreal,噪声分布 pz。
训练循环(直至收敛):
第一阶段:训练“考官” (Update Discriminator & Fake Score)
目的是让考官能精准识别出 Student 目前画得哪里假,以及总结出假图的分布。
- 采样:
- 从数据集抽样真实图像 x∼Xreal。
- 采样随机噪声 z∼pz。
- 采样随机时间步 t∼[0,T]。
- 生成合成图(Synthetic Images):
- 如果是多步生成器(例如 4 步):让 Student Gθ 从 z 开始跑,生成中间态或最终态的图像 x^。
- 加噪:
- 对真图加噪:xt=ForwardDiffusion(x,t)
- 对合成图加噪:x^t=ForwardDiffusion(x^,t)
- 更新参数 ϕ,ψ:
- 优化 Dϕ:通过二分类损失,让 D 学会区分 xt(真)和 x^t(假)。
- 优化 Sψ:让 S 学习预测 x^t 中的噪声,从而建立 Student 生成分布的“得分场”(Score Function)。
第二阶段:训练“学生” (Update Student Generator)
目的是让 Student 同时满足 Teacher 的教导(红线)和骗过判别器(绿线)。
- 采样:采样新的噪声 z∼pz 和随机时间步 t∼[0,T]。
- 前向计算:
- Student 生成图像:x^=Gθ(z)(多步则为链式生成)。
- 注入噪声:x^t=ForwardDiffusion(x^,t)。
- 计算梯度(核心步骤):
- 红线梯度(分布匹配梯度):
∇x^Ldistill∝老师的意见Steacher(x^t,t)−对学生现状的总结Sfake(x^t,t)
这步告诉学生:你的画风和老师的画风在这个 t 水平上差了多少。
- 绿线梯度(GAN 梯度):
∇x^LGAN=∂x^∂[−logDϕ(x^t,t)]
这步告诉学生:判别器觉得你这里画得假,快改。
- 反向传播更新 θ:
- 将上述两个梯度合并,通过链式法则传回给 Gθ:
θ←θ−η⋅∂θ∂x^(λ1∇x^Ldistill+λ2∇x^LGAN)
- 其中 ∂θ∂x^ 包含了多步生成过程中的所有梯度累积。
关键点详细解析
1. 为什么有“两个”考官?
- Teacher (Steacher):它是一个“博学但固执”的老师。它拥有海量的知识,但它是通用的,不一定完全契合 Student 的精简结构。它的梯度(红线)负责引导 整体分布的正确性。
- Fake Score (Sfake):它是一个“专门研究该学生”的助教。它专门总结当前这个 Student 模型生成的图有哪些共性错误。红线梯度之所以用
Teacher - Fake,就是为了减去 Student 已经学到的部分,只学习还没对齐的那部分差异。
2. GAN 在这里起什么作用?
GAN(绿线)是“细节修补匠”。
红线梯度在低噪声(小 t)区域往往会变得模糊。GAN 损失通过直接与真图 xreal 挂钩,强行要求 Student 生成的图像在视觉感知上具有极高的锐度和逼真度,弥补了扩散模型蒸馏过程中常见的“模糊”现象。
3. 为什么多步训练要用“合成图”?
如算法所示,在训练多步生成器时,我们让 Student “自己跑几步” 生成 x^,而不是从真图加噪得到 xt。
- 原因:这样梯度就能穿过这几步生成过程。Student 学到的是:“如果我在第一步产生了一个微小的伪影,我该如何在接下来的几步中通过参数优化来消除它。”这彻底解决了训练和推理不一致(Training-Inference Mismatch)的问题。
算法总结图示
| 角色 | 操作对象 | 目标 | 梯度性质 |
|---|
| Student | 纯噪声 z | 骗过判别器 + 模仿老师 | 接收红绿双线反馈 |
| Teacher | 带噪合成图 x^t | 提供理想的概率流场方向 | 静态指导 (Frozen) |
| Discriminator | 真图 xt vs 假图 x^t | 区分真伪细节 | 提供动态对抗梯度 |
| Fake Score | 假图 x^t | 建模学生当前的分布瑕疵 | 辅助红线做“减法” |
这个算法流程是否让你对 DMD2 的整个系统运作有了全局的把握?