)
从图像到LaTeX深入拆解数学公式识别中的Attention机制与位置编码数学公式识别一直是计算机视觉领域最具挑战性的任务之一。与普通OCR不同公式中的二维空间结构、符号间的复杂嵌套关系以及专业符号的多样性使得传统方法难以取得理想效果。本文将聚焦于公式识别中的核心难题——如何保留二维空间信息并深入解析Attention机制与位置编码Positional Embedding如何协同解决这一难题。1. 数学公式识别的独特挑战数学公式识别之所以困难根源在于其二维结构的特殊性。在普通文本OCR中CRNNCNNRNN架构通过CNN提取局部特征后用RNN处理一维序列就能取得不错效果。但这种方法对公式识别存在三个致命缺陷空间信息丢失将二维图像扁平化为序列时上下标、分式等位置关系被破坏长程依赖缺失RNN难以捕捉跨多行的括号匹配等远距离关系结构歧义相同符号在不同位置可能代表不同含义如a_i与a^i传统方法尝试用规则或语法树解决这些问题但泛化能力有限。直到Attention机制的引入才为这些问题提供了端到端的解决方案。关键观察公式识别的本质是将二维空间结构转换为具有层次关系的标记序列这需要模型同时理解局部特征和全局结构。2. Attention机制在公式识别中的革新2.1 从Seq2Seq到Transformer典型的公式识别模型采用Encoder-Decoder架构# 简化版模型结构 encoder CNNBackbone() # 提取视觉特征 decoder TransformerDecoder() # 生成LaTeX标记与传统Seq2Seq不同公式识别中的Attention机制需要处理两个维度的信息空间Attention关注图像中相关区域的位置关系语义Attention理解符号间的数学逻辑关系2.2 二维Attention的实现在图像到LaTeX任务中Attention权重计算需要考虑二维坐标。以下是简化的实现逻辑def spatial_attention(query, key, value, pos_emb): # query: 当前解码位置 [batch, dim] # key/value: 编码器输出 [batch, h, w, dim] # pos_emb: 位置编码 [h, w, dim] # 融合内容与位置信息 key key pos_emb query query.unsqueeze(1).unsqueeze(1) # 计算二维注意力权重 attn_scores torch.matmul(query, key.transpose(-1, -2)) attn_weights F.softmax(attn_scores, dim-1) return torch.matmul(attn_weights, value)这种设计使得模型可以同时关注内容相关性符号的视觉特征匹配度位置相关性符号间的空间距离关系3. 位置编码的数学原理与实现3.1 为什么需要位置编码当CNN特征图被展平输入Decoder时行列位置信息会丢失。位置编码通过注入空间坐标信息解决这一问题。其核心要求是唯一性每个位置有独特编码相对性能表示位置间的距离关系可扩展性适应不同尺寸的输入3.2 正弦/余弦位置编码Transformer使用的正弦编码具有理想的数学性质PE(pos,2i) sin(pos/10000^(2i/dmodel)) PE(pos,2i1) cos(pos/10000^(2i/dmodel))扩展到二维情况时需要对行、列分别编码def get_2d_pos_embed(height, width, dim): # 行编码 row_pos torch.arange(height).unsqueeze(1) row_emb position_encoding(row_pos, dim//2) # 列编码 col_pos torch.arange(width).unsqueeze(1) col_emb position_encoding(col_pos, dim//2) # 合并行列编码 pos_emb torch.cat([ row_emb.unsqueeze(1).repeat(1,width,1), col_emb.unsqueeze(0).repeat(height,1,1) ], dim-1) return pos_emb # [H, W, D]3.3 位置编码的可视化分析下表对比了不同位置编码方式的特性编码类型唯一性相对位置可训练适用场景正弦编码✔✔✘固定尺寸输入可学习参数✔✘✔小尺寸输入相对位置偏置✘✔✔自注意力机制混合编码✔✔✔变尺寸输入在公式识别中正弦编码因其能处理不同长宽比的输入而成为首选。4. 实战改进的Positional Embedding实现4.1 多尺度位置编码原始正弦编码在极端位置如10000会出现数值不稳定。改进方案采用对数间隔的频率def add_timing_signal_nd(x, min_scale1.0, max_scale1e4): Tensor2Tensor库中的实现改进 num_dims len(x.shape) - 2 channels x.shape[-1] # 对数间隔的频率 num_timescales channels // (num_dims * 2) log_inc math.log(max_scale/min_scale) / (num_timescales-1) inv_timescales min_scale * torch.exp( torch.arange(num_timescales).float() * -log_inc) # 为每个维度生成编码 signals [] for dim in range(num_dims): length x.shape[dim1] pos torch.arange(length).float() scaled_time pos.unsqueeze(1) * inv_timescales.unsqueeze(0) signal torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim1) signals.append(signal) # 合并所有维度的编码 return x combine_signals(signals)4.2 位置编码与Attention的协同在解码阶段位置信息参与注意力权重的三种方式绝对位置注入将位置编码加到Key/Query相对位置偏置在注意力得分中加入位置关系项位置感知卷积用卷积操作隐式编码位置实验表明在公式识别任务中方式1和3的组合效果最佳方法组合CROHME数据集准确率推理速度(FPS)绝对位置内容Attention58.7%12.3相对位置偏置56.2%9.8位置卷积内容Attention60.1%15.65. 前沿进展与优化方向当前最先进的公式识别系统已采用更复杂的位置感知机制动态位置编码根据图像内容自适应调整位置权重层次化位置建模分别处理字符级和结构级位置关系几何关系编码显式建模符号间的几何约束如对齐、包含等一个值得关注的趋势是将位置编码与图神经网络结合用图结构显式表示公式的二维关系。这种混合架构在复杂公式识别上比纯Attention模型有显著提升。