)
胶囊网络实战用TensorFlow 2.x从零搭建CapsNet附MNIST代码在深度学习领域卷积神经网络CNN长期占据主导地位但它在处理空间层次结构时存在固有缺陷。2017年深度学习先驱Geoffrey Hinton提出胶囊网络Capsule Network通过向量神经元和动态路由机制显著提升了模型对物体空间关系的理解能力。本文将带您从零实现一个完整的CapsNet模型包含MNIST数据处理、动态路由算法实现等核心环节并提供可直接运行的TensorFlow 2.x代码。1. 环境准备与数据加载1.1 安装依赖库确保使用Python 3.7环境并安装以下依赖pip install tensorflow2.8.0 matplotlib numpy1.2 MNIST数据处理MNIST数据集包含60,000张28x28的手写数字图像。我们使用TensorFlow内置接口加载数据并进行标准化处理import tensorflow as tf # 加载MNIST数据集 (x_train, y_train), (x_test, y_test) tf.keras.datasets.mnist.load_data() # 数据预处理 def preprocess(images, labels): images tf.expand_dims(images, -1) # 增加通道维度 images tf.cast(images, tf.float32) / 255.0 # 归一化 return images, labels # 创建数据管道 train_ds tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_ds train_ds.map(preprocess).shuffle(10000).batch(128)2. 胶囊网络核心组件实现2.1 PrimaryCaps层构建PrimaryCaps层将传统标量神经元输出转换为向量形式的胶囊class PrimaryCaps(tf.keras.layers.Layer): def __init__(self, caps_dim8, n_caps32, kernel_size9, strides2): super().__init__() self.caps_dim caps_dim self.n_caps n_caps self.conv tf.keras.layers.Conv2D( filterscaps_dim * n_caps, kernel_sizekernel_size, stridesstrides, activationrelu ) def call(self, inputs): # [batch, 20, 20, 256] - [batch, 6, 6, 256] outputs self.conv(inputs) # 重塑为胶囊格式 [batch, 6, 6, 32, 8] outputs tf.reshape(outputs, [ tf.shape(outputs)[0], -1, self.n_caps, self.caps_dim ]) # 应用squash激活函数 return self.squash(outputs) def squash(self, vectors): norm tf.norm(vectors, axis-1, keepdimsTrue) return (norm / (1 norm**2)) * vectors2.2 动态路由算法实现动态路由是胶囊网络的核心机制决定低层胶囊如何将信息传递给高层胶囊def routing(u_hat, b_ij, iterations3): u_hat: 低层胶囊预测向量 [batch, 1152, 10, 16] b_ij: 初始对数先验 [batch, 1152, 10, 1] for i in range(iterations): # 计算耦合系数c_ij c_ij tf.nn.softmax(b_ij, axis2) # 计算高层胶囊输入s_j s_j tf.reduce_sum(c_ij * u_hat, axis1, keepdimsTrue) # 应用squash激活函数 v_j squash(s_j) # 更新对数先验b_ij if i iterations - 1: agreement tf.reduce_sum(u_hat * v_j, axis-1, keepdimsTrue) b_ij agreement return tf.squeeze(v_j, axis1)3. 完整CapsNet模型构建3.1 编码器结构编码器由卷积层、PrimaryCaps层和DigitCaps层组成class CapsNet(tf.keras.Model): def __init__(self): super().__init__() # 初始卷积层 self.conv1 tf.keras.layers.Conv2D( filters256, kernel_size9, strides1, activationrelu, input_shape(28,28,1) ) # PrimaryCaps层 self.primary_caps PrimaryCaps() # DigitCaps层参数 self.digit_caps_dim 16 self.n_digit_caps 10 self.W tf.Variable( initial_valuetf.random_normal_initializer(stddev0.1)( shape[1, 1152, self.n_digit_caps, self.digit_caps_dim, 8] ), trainableTrue ) def call(self, inputs): # 通过卷积层 [batch, 28,28,1] - [batch,20,20,256] x self.conv1(inputs) # PrimaryCaps层 [batch,20,20,256] - [batch,1152,8] u self.primary_caps(x) # 计算预测向量u_hat [batch,1152,10,16] u tf.expand_dims(u, axis2) # [batch,1152,1,8] u tf.expand_dims(u, axis3) # [batch,1152,1,1,8] u_hat tf.reduce_sum(self.W * u, axis-1) # 动态路由 b_ij tf.zeros([tf.shape(inputs)[0], 1152, self.n_digit_caps, 1]) v_j routing(u_hat, b_ij) return v_j3.2 解码器设计与重建损失解码器用于正则化训练过程通过胶囊向量重建原始图像class Decoder(tf.keras.layers.Layer): def __init__(self): super().__init__() self.dense1 tf.keras.layers.Dense(512, activationrelu) self.dense2 tf.keras.layers.Dense(1024, activationrelu) self.dense3 tf.keras.layers.Dense(784, activationsigmoid) def call(self, inputs, y_true): # 仅使用正确类别的胶囊向量 mask tf.one_hot(y_true, depth10) masked tf.reduce_sum(inputs * mask[:, None], axis1) # 通过全连接层重建图像 x self.dense1(masked) x self.dense2(x) x self.dense3(x) return tf.reshape(x, [-1, 28, 28, 1])4. 模型训练与评估4.1 自定义损失函数胶囊网络使用边缘损失和重建损失的组合class CapsuleLoss(tf.keras.losses.Loss): def __init__(self, m_plus0.9, m_minus0.1, lambda_0.5): super().__init__() self.m_plus m_plus self.m_minus m_minus self.lambda_ lambda_ def call(self, y_true, y_pred): # 计算边缘损失 L y_true * tf.square(tf.maximum(0., self.m_plus - y_pred)) \ self.lambda_ * (1 - y_true) * tf.square(tf.maximum(0., y_pred - self.m_minus)) return tf.reduce_mean(tf.reduce_sum(L, axis1))4.2 训练流程实现配置自定义训练循环以支持复杂损失计算# 初始化模型和优化器 model CapsNet() decoder Decoder() optimizer tf.keras.optimizers.Adam(0.001) loss_fn CapsuleLoss() tf.function def train_step(images, labels): with tf.GradientTape() as tape: # 前向传播 caps_output model(images) # 计算边缘损失 y_true tf.one_hot(labels, depth10) caps_loss loss_fn(y_true, tf.norm(caps_output, axis-1)) # 计算重建损失 reconstructed decoder(caps_output, labels) recon_loss tf.reduce_mean( tf.square(images - reconstructed) ) # 总损失 total_loss caps_loss 0.0005 * recon_loss # 反向传播 grads tape.gradient(total_loss, model.trainable_variables decoder.trainable_variables) optimizer.apply_gradients(zip(grads, model.trainable_variables decoder.trainable_variables)) return caps_loss, recon_loss4.3 可视化训练结果训练过程中可以定期输出重建图像直观评估模型表现import matplotlib.pyplot as plt def plot_reconstructions(model, decoder, test_images, test_labels, n_samples5): # 获取模型预测 caps_output model.predict(test_images[:n_samples]) reconstructions decoder(caps_output, test_labels[:n_samples]).numpy() # 绘制对比图 plt.figure(figsize(n_samples * 2, 4)) for i in range(n_samples): # 原始图像 plt.subplot(2, n_samples, i 1) plt.imshow(test_images[i].squeeze(), cmapgray) plt.axis(off) # 重建图像 plt.subplot(2, n_samples, i 1 n_samples) plt.imshow(reconstructions[i].squeeze(), cmapgray) plt.axis(off) plt.show()在实际项目中我发现动态路由的迭代次数对模型性能影响显著。经过多次实验3次迭代通常能在训练效率和模型精度间取得良好平衡。另一个关键点是重建损失的权重系数——过大会干扰胶囊学习空间特征过小则失去正则化效果。建议从0.0005开始逐步调整。