MMD-Critic 方法解析

发布时间:2026/5/19 23:47:02

MMD-Critic 方法解析 原文towardsdatascience.com/the-mmd-critic-method-explained-c6a77f2dbf18?sourcecollection_archive---------7-----------------------#2024-08-27一种强大但鲜为人知的数据总结和可解释性 AI 方法https://medium.com/physboom?sourcepost_page---byline--c6a77f2dbf18--------------------------------https://towardsdatascience.com/?sourcepost_page---byline--c6a77f2dbf18-------------------------------- Matthew Chak·发表于 Towards Data Science ·11 分钟阅读·2024 年 8 月 27 日–尽管 MMD-Critic 方法是一个强大的数据总结工具但它的使用和“覆盖面”却出奇地少。也许这是因为存在一些更简单且更成熟的数据总结方法例如 K-medoids参见 [1] 或更简单地查看 维基百科页面又或者是因为在此之前并没有 Python 包支持该方法直到现在。无论如何原始论文 [2] 中展示的结果比 MMD-Critic 当前的应用更值得广泛使用。因此我将在这里尽可能清晰地解释 MMD-Critic 方法。我还发布了一个 开源 Python 包其中实现了该技术方便你轻松使用。原型和批评在深入探讨 MMD-Critic 方法之前值得讨论一下我们究竟想要实现什么目标。最终我们希望通过一个数据集找到代表性的数据示例原型以及可能干扰我们机器学习模型的边缘案例批评。https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/71c22de7c81736578e26d81da3a21ce2.png来自 [2] 的 MNIST 数据集的原型和批评示例。这种方法可能有很多有用的地方通过查看典型和非典型的示例我们可以获得对数据集的一个非常好的总结视图我们可以通过测试模型在批评样本上的表现来了解它们如何处理边缘案例显而易见这是非常重要的虽然可能没有那么实用我们可以使用原型来创建一个类似 K-means 的自然可解释算法其中最接近新数据点的原型被用来标记该数据点。然后解释就很简单因为我们只需要向用户展示最相似的数据点。更多你可以在这本书的第 6.3 节中找到更多关于该应用的信息以及关于 MMD-Critic 的一个不错的解释但可以简单地说找到这些例子在很多方面都是有用的。MMD-Critic 使我们能够做到这一点。最大均值差异不幸的是我不能声称对最大均值差异MMD有超严格的理解因为这种理解需要有较强的泛函分析背景。如果你有这样的背景你可以在这里找到介绍该度量的论文。简单来说MMD 是一种确定两个概率分布之间差异的方法。正式地对于两个概率分布P和Q我们定义它们的 MMD 为https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/806f69f4f5fb7bdfddc1e2b7e198f203.png两个分布 PQ 的 MMD 公式这里F是任何函数空间——即具有相同定义域和陪域的函数集合。还要注意符号x~P意味着我们将x视为从分布P中抽取的随机变量——即x是由P描述的。因此这个公式找到当X和Y通过我们空间F中的某个函数变换时它们期望值之间的最大差异。这可能有点难以理解但这是一个例子。假设X是Uniform(0, 1)即从 0 到 1 随机选择一个数的分布而Y是Uniform(-1, 1)。我们还假设F是一个包含三个函数的简单函数族——f(x) 0f(x) x和f(x) x²。在我们的空间中对每个函数进行迭代得到在f(x) 0 的情况下当x ~ P时E[f(x)]为 0因为无论选择什么xf(x)都会是 0。对于x ~ Q的情况也一样。因此我们得到一个均值差异为 0。在f(x) x的情况下我们有 E[f(x)] 0.5对于 P 的情况和 0对于 Q 的情况因此我们的均值差异为 0.5。在f(x) x²的情况下我们注意到https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ea72888c6295343ec4a233c883265a0a.png变换函数 f 作用下随机变量x的期望值公式因此在 P 的情况下我们得到https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f96988df2e03471d79f8585d36fc3596.png在分布 P 下f(x)的期望值在 Q 的情况下我们得到https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/25a4e021d3c2c5b36bcaaa3c774de110.png在分布 Q 下f(x)的期望值因此在这种情况下我们的差异也是 0。我们函数空间上的上确界为 0.5所以这就是我们的 MMD。现在你可能会注意到我们 MMD 的一些问题。它似乎高度依赖于我们选择的函数空间并且在大型或无限函数空间中计算起来也显得非常昂贵甚至是不可能的。不仅如此它还要求我们知道分布P和Q这是不现实的。后一个问题很容易解决因为我们可以重写我们的 MMD 度量利用基于数据集的 P 和 Q 的估计https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/2140a40093f3dd0d76850321ceb07100.png使用 P 和 Q 的估计的 MMD这里我们的x是从数据集中抽样得到的P的样本而y是从Q中抽取的样本。前两个问题可以通过一些额外的数学手段来解决。无需过多细节事实证明如果F是一个叫做再生核希尔伯特空间RKHS的东西我们就能预先知道哪个函数将给我们提供 MMD。即它是下面这个函数称为见证函数https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/e46b5d340b5f3db6a69efe75e2b8cc47.png我们在 RKHS 中的最优 f(x)其中k是与 RKHS 相关的核函数内积¹。直观地说这个函数在点x处“见证”了P和Q之间的差异。因此我们只需要选择一个足够具表现力的 RKHS/核函数——通常使用 RBF 核其核函数为https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/f5ea568dd5266a59359ef3961d188f9c.pngRBF 核其中 sigma 是一个超参数这通常能得到相当直观的结果。例如这里是使用 RBF 核估计的见证函数的图以与之前相同的方式进行估计——即用求和替代期望在从Uniform(-0.5, 0.5)和Uniform(-1, 1)两个数据集抽样时的图示https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/fd302a4cc896b43064206d874d7409da.png在不同点上见证函数对于两个均匀分布的值生成上述图形的代码如下importnumpyasnpimportmatplotlib.pyplotaspltimportseabornassnsdefrbf(v1,v2,sigma0.5):returnnp.exp(-(v2-v1)**2/(2*sigma**0.5))defcomp_wit_fn(x,d1,d2):return1/len(d1)*sum([rbf(x,dp)fordpind1])-1/len(d2)*sum([rbf(x,dp)fordpind2])low1,high1-0.5,0.5# Range for the first uniform distributionlow2,high2-1,1# Range for the second uniform distribution# Generate data for the uniform distributionsdata1np.random.uniform(low1,high1,10000)data2np.random.uniform(low2,high2,10000)# Generate a range of x values for which to compute comp_wit_fnx_valuesnp.linspace(min(low1*2,low2*2),max(high1*2,high2*2),100)comp_wit_values[comp_wit_fn(x,data1,data2)forxinx_values]sns.kdeplot(data1,labelfUniform({low1},{high1}),colorblue,fillTrue)sns.kdeplot(data2,labelfUniform({low2},{high2}),colorred,fillTrue)plt.plot(x_values,comp_wit_values,labelWitness Function,colorgreen)plt.xlabel(Value)plt.ylabel(Density / Wit Fn)plt.legend()plt.show()最后的 MMD-Critic 方法MMD-Critic 背后的思路现在相当简单——如果我们想找到k个原型我们需要找到最能匹配原始数据集分布的原型集通过它们的平方 MMD。换句话说我们希望找到一个子集P它的基数是k并且最小化MMD²(F, X, P)。无需过多细节平方 MMD 由以下公式给出https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/0dff52d10b914ee9e6082e9b47cb6292.png平方 MMD 度量假设 X ~ PY ~ Qk 是我们 RKHS F 的核函数在找到这些原型之后我们接着选择那些假设原型分布与数据集分布在最不同的点作为批评点。正如我们之前所见两个分布在某一点的差异可以通过我们的见证函数来衡量所以我们只需要在X和P的上下文中找到最大化其绝对值的点。换句话说我们将批评“得分”定义为https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/d19b02d21f693952776695457da054ee.png一个批评 c 的“得分”或者以更易用的近似形式https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/2a763ae077e95c0041dba13b5c4fab86.png对于批评 c 的近似 S©然后为了找到我们想要的批评数量假设是m个我们只需要找到大小为m的集合C使得它能够最大化https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/8eacec3442f5cc088cb6704f587a7190.png为了促进选择更多样化的批评论文还建议添加一个正则化项鼓励选中的批评尽可能分开。论文中建议的正则化项是对数行列式正则化项尽管这不是必须的。我在这里不会详细讲解因为它不是关键内容但论文建议参考[6]²。因此我们可以实现一个极其天真的MMD-Critic没有批评正则化如下所示不要使用这个importmathimportitertoolsdefeuc_distance(p1,p2):returnmath.sqrt(sum((x-y)**2forx,yinzip(p1,p2)))defrbf(v1,v2,sigma0.5):returnmath.exp(-euc_distance(v1,v2)**2/(2*sigma**0.5))defmmd_sq(X,Y,sigma0.5):sm_xx0forxinX:forx2inX:sm_xxrbf(x,x2,sigma)sm_xy0forxinX:foryinY:sm_xyrbf(x,y,sigma)sm_yy0foryinY:fory2inY:sm_yyrbf(y,y2,sigma)return1/(len(X)**2)*sm_xx \-2/(len(X)*len(Y))*sm_xy \1/(len(Y)**2)*sm_yydefselect_protos(X,n,sigma0.5):min_score,min_submath.inf,Noneforsubsetinitertools.combinations(X,n):new_mmdmmd_sq(X,subset,sigma)ifnew_mmdmin_score:min_scorenew_mmd min_subsubsetreturnmin_subdefcriticism_score(criticism,prototypes,X,sigma0.5):returnabs(1/len(X)*sum([rbf(criticism,x,sigma)forxinX])\-1/len(prototypes)*sum([rbf(criticism,p,sigma)forpinprototypes]))defselect_criticisms(X,P,n,sigma0.5):candidates[cforcinXifcnotinP]max_score,crits-math.inf,[]forsubsetinitertools.combinations(candidates,n):new_scoresum([criticism_score(c,P,X,sigma)forcinsubset])ifnew_scoremax_score:max_scorenew_score critssubsetreturncrits优化 MMD-Critic上面的实现如此不实用以至于当我运行时未能在一个包含 25 个数据点的数据集中在合理的时间内找到 5 个原型。这是因为我们的 MMD 计算是O(max(|X|, |Y|)²)并且遍历每个长度为 n 的子集的时间复杂度是O(C(|X|, n))其中 C 是组合函数这导致了非常可怕的运行时复杂度。忽略使用更高效的计算方法例如使用纯 numpy/numexpr/矩阵计算而不是循环/其他以及缓存重复计算我们在理论层面上可以进行一些优化。首先最明显的性能瓶颈是我们在原型和批评方法中对C(|X|, n)子集的循环。我们可以用一个近似方法来代替循环n次每次贪婪地选择最优原型。这样我们可以将原型选择代码更改为defselect_protos(X,n,sigma0.5):protos[]for_inrange(n):min_score,min_protomath.inf,NoneforcandinX:ifcandinprotos:continuenew_scoremmd_sq(X,protos[cand],sigma)ifnew_scoremin_score:min_scorenew_score min_protocand protos.append(min_proto)returnprotos批评的情况也是类似的。还有一个重要的引理使得这个问题更易于优化。事实证明通过将原型选择问题转化为最小化问题并向成本中添加一个正则化项我们可以通过矩阵运算高效地计算成本函数。我不会在这里详细讨论但你可以查看原始论文了解详情。使用MMD-Critic包进行实验现在我们已经理解了 MMD-Critic 方法终于可以开始试验它了你可以通过运行以下命令来安装它pip install mmd-critic包中的实现比这里展示的要快得多所以不必担心。我们可以使用如下方式运行一个简单的示例使用斑点数据fromsklearn.datasetsimportmake_blobsfrommmd_criticimportMMDCriticfrommmd_critic.kernelsimportRBFKernel n_samples50# Total number of samplescenters4# Number of clusterscluster_std1# Standard deviation of the clustersX,_make_blobs(n_samplesn_samples,centerscenters,cluster_stdcluster_std,n_features2,random_state42)XX.tolist()# MMD critic with the kernel used for the prototypes being an RBF with sigma1,# for the criticisms one with sigma0.025criticMMDCritic(X,RBFKernel(1),RBFKernel(0.025))protos,_critic.select_prototypes(centers)criticisms,_critic.select_criticisms(10,protos)然后绘制点和批评结果得到https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ad6879f3f0b021cff94f57e71b28503f.png绘制找到的原型绿色和批评红色你会注意到我提供了使用单独内核进行原型和批评选择的选项。这是因为我发现批评的结果尤其对 sigma 超参数极其敏感。这是 MMD Critic 方法以及核方法的一项不幸限制。总体而言我发现使用较大的 sigma 值选择原型较小的 sigma 值选择批评效果较好。我们当然也可以使用更复杂的数据集。例如这里是该方法在 MNIST 上的应用³fromsklearn.datasetsimportfetch_openmlimportnumpyasnpfrommmd_criticimportMMDCriticfrommmd_critic.kernelsimportRBFKernel# Load MNIST datamnistfetch_openml(mnist_784,version1)images(mnist[data].astype(np.float32)).to_numpy()/255.0labelsmnist[target].astype(np.int64)criticMMDCritic(images[:15000],RBFKernel(2.5),RBFKernel(0.025))protos,_critic.select_prototypes(40)criticisms,_critic.select_criticisms(40,protos)这为我们提供了以下原型https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/6b5e036ce193ebc12d8db963cca7fcfb.pngMMD Critic 方法为 MNIST 找到的原型。MNIST 在 GPL-3.0 许可协议下可以免费用于商业用途。以及批评意见https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/37ca1ea60129a1dc22e293b08e4442e7.png由 MMD Critic 方法发现的批评意见相当不错吧结论这就是 MMD-Critic 方法的全部内容。核心非常简单使用起来不错除了需要调整 Sigma 超参数。我希望新发布的 Python 包能带来更多的应用。如有任何疑问请联系 mchakcalpoly.edu。所有图片均由作者提供除非另有说明。脚注[1] 如果你曾经学习过 SVM 和核技巧你可能会熟悉 RKHS 和核——在这些方法中使用的核实际上只是某些RKHS 中的内积。最常见的是 RBF 核其相关的 RKHS 函数是一个无限维的平滑函数集。[2] 我没有详细阅读此来源只是简单浏览了一下。它似乎大部分无关紧要而且对数行列式正则化项的实现相对简单。不过如果你想阅读它尽管去看。[3] 出于法律原因你可以在这里找到包含 MNIST 数据集的代码库。该数据集在 GPL-3.0 许可协议下可以免费用于商业用途。参考文献[1]onlinelibrary.wiley.com/doi/book/10.1002/9780470316801[2]proceedings.neurips.cc/paper_files/paper/2016/file/5680522b8e2bb01943234bce7bf84534-Paper.pdf[3]f0nzie.github.io/interpretable_ml-rsuite/proto.html#examples-5[4]jmlr.csail.mit.edu/papers/volume13/gretton12a/gretton12a.pdf[5]www.stat.cmu.edu/~ryantibs/journalclub/mmd.pdf[6]jmlr.org/papers/volume9/krause08a/krause08a.pdf

相关新闻