教师模型必须输出logits而非概率,学生模型也需输出同维logits;KL损失用tf.nn.softmax_cross_entropy_with_logits计算,两logits均除以温度T;总损失为alpha×hard_loss+(1−alpha)×kl_loss×T²。
教师模型输出 logits 还是概率?学生模型怎么接?
蒸馏的关键不是“模仿预测结果”,而是模仿教师对各类别的置信度分布——所以教师必须输出 logits(未归一化的原始分),而不是 tf.nn.softmax(logits) 后的概率。学生模型也必须输出同维度的 logits,否则 KL 散度无法计算。
常见错误:教师用 model(x, training=False) 得到 softmax 概率,再拿去算 KL,这会丢失温度缩放信息,梯度变弱;或者学生最后加了 softmax 层,导致 loss 计算时多套了一层非线性。
- 教师前向必须保留
logits,例如:teacher_logits = teacher(x, training=False) - 学生前向只到 logits 层,不加 softmax;训练时用
tf.nn.softmax(student_logits / T)和tf.nn.softmax(teacher_logits / T)算 KL - 温度
T通常设为 3–20,T=1退化为普通交叉熵
KL 散度损失怎么写?别用 tf.keras.losses.KLDivergence
tf.keras.losses.KLDivergence 是单向、非对称的,且默认要求输入是概率分布(sum=1),但蒸馏中要用带温度的 soft target,必须手动实现带温度的 KL,并确保数值稳定。
正确做法是用 tf.nn.softmax_cross_entropy_with_logits 的等价形式:先对教师 logits 做 soft label,再算学生 logits 在该分布下的交叉熵(即 KL 散度 + 常数项)。
- 推荐写法:
kl_loss = tf.nn.softmax_cross_entropy_with_logits(labels=tf.nn.softmax(teacher_logits / T),logits=student_logits / T) - 注意两个 logits 都要除以同一
T,否则量纲错乱 - 该 loss 已自动减去 log-sum-exp,无需额外 clip 或加 epsilon
怎么组合蒸馏 loss 和真实标签 loss?权重和比例怎么调?
纯蒸馏 loss 容易让模型忽略真实标签(尤其小样本或类别不平衡时),必须混合监督 loss(如 sparse categorical crossentropy)。但两者的量级差异大,直接相加会导致一方主导。
云境标书AI
云境标书AI官网,国内领先招投标AI工具,简单3步,千页标书一键全文。
典型做法是加权求和:total_loss = alpha * hard_loss + (1 - alpha) * kl_loss * T^2。其中 T^2 是关键补偿项——因为 KL 对 logits 缩放后敏感度下降,乘上 T^2 能大致恢复梯度幅值。
-
alpha一般取 0.1–0.5;任务越难、教师越强,alpha可越小 - 务必验证
hard_loss和kl_loss * T^2的数值在同量级(比如都在 1–5 左右),否则调整alpha或T - 不要在验证集上用 soft label 做评估——测试时学生模型就用原始 logits + argmax
TensorFlow 2.x 中如何避免 eager mode 下的梯度追踪污染?
教师模型若参与梯度计算(比如误写成 with tf.GradientTape() as tape: 并把 teacher 放进 tape),会导致显存暴涨甚至 OOM。教师必须明确设为不可训练、且不被 tape watch。
- 加载教师模型后立刻执行:
teacher.trainable = False,并在前向时确保它不在tf.GradientTape的 watch 列表中 - 更稳妥写法:用
@tf.function包裹训练 step,并在函数内用teacher(x, training=False)——TF 会自动跳过不可训练变量的梯度路径 - 调试时可打印
tape.watched_variables(),确认列表里只有学生模型参数
蒸馏真正难的不是公式,而是 logits 归一化方式、温度补偿系数、loss 量级对齐这三个点——漏掉任意一个,模型性能可能比不蒸馏还差。
就爱读