)
# 30行代码玩转YOLO v1核心：用PyTorch Lightning实现目标检测极简Demo

当目标检测遇上现代深度学习框架，一切都变得简单起来。本文将带你用PyTorch Lightning以约30行核心代码实现YOLO v1的核心思想：无需深陷论文公式，直接动手体验“一次扫描”（You Only Look Once）的检测魅力。这个迷你项目专为想快速理解YOLO本质的实践者设计，我们将在Google Colab环境中完成从数据准备到结果可视化的全流程。

## 1. 准备工作：环境与数据

首先确保已安装PyTorch Lightning和必要的可视化工具（数据生成还需要OpenCV，Colab中已预装）：

```bash
pip install pytorch-lightning torchvision matplotlib opencv-python
```

我们将使用一个极简的合成数据集来演示YOLO的核心机制。这个数据集包含100张256×256像素的图片，每张图片在中心附近的随机位置放置一个随机颜色的圆形或方形；标签信息包括物体类别和边界框坐标（按YOLO格式写入物体中心所在的网格单元）。

```python
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

IMAGE_SIZE = 256


class SimpleShapesDataset(Dataset):
    """Synthetic detection data: one randomly coloured circle or square per image.

    Each item is ``(image, label)``:
      * image: float32 tensor of shape (3, 256, 256) with values in [0, 1];
      * label: float32 tensor of shape (grid_size, grid_size, 5 + num_classes).
        Only the cell containing the object's centre is filled, with
        ``[x, y, w, h, 1.0, one-hot class]``: x, y are the centre's offsets inside
        that cell (in [0, 1)), w, h are relative to the whole image.

    A new random image is generated on every access, so ``dataset[i]`` is not
    reproducible (this acts as free data augmentation).
    """

    def __init__(self, num_samples=100, grid_size=7):
        self.num_samples = num_samples
        self.grid_size = grid_size
        self.classes = ["circle", "square"]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = np.zeros((IMAGE_SIZE, IMAGE_SIZE, 3), dtype=np.float32)
        label = np.zeros((self.grid_size, self.grid_size, 5 + len(self.classes)), dtype=np.float32)

        # Random shape, a position near the centre, and a half-size (radius / half side).
        shape_type = int(np.random.randint(0, 2))
        center_x = IMAGE_SIZE // 2 + int(np.random.randint(-64, 64))
        center_y = IMAGE_SIZE // 2 + int(np.random.randint(-64, 64))
        size = int(np.random.randint(20, 50))
        color = tuple(np.random.rand(3).tolist())

        if shape_type == 0:  # circle
            cv2.circle(image, (center_x, center_y), size, color, -1)
        else:  # square
            cv2.rectangle(image, (center_x - size, center_y - size),
                          (center_x + size, center_y + size), color, -1)

        # YOLO-format label, written into the cell that contains the object's centre.
        cell = IMAGE_SIZE / self.grid_size
        grid_x = min(int(center_x // cell), self.grid_size - 1)
        grid_y = min(int(center_y // cell), self.grid_size - 1)
        label[grid_y, grid_x, 0:4] = [
            center_x / cell - grid_x,
            center_y / cell - grid_y,
            # `size` is a radius / half side length, so the box is 2 * size wide and tall.
            2 * size / IMAGE_SIZE,
            2 * size / IMAGE_SIZE,
        ]
        label[grid_y, grid_x, 4] = 1.0               # confidence
        label[grid_y, grid_x, 5 + shape_type] = 1.0  # one-hot class
        return torch.from_numpy(image).permute(2, 0, 1), torch.from_numpy(label)
```

## 2. 构建简化版YOLO网络

YOLO v1的核心是一个将图像划分为S×S网格的卷积网络：每个网格单元预测B个边界框（每个框包含x、y、w、h和置信度共5个值）以及C个类别概率，因此输出张量的形状为S×S×(B·5+C)。我们用PyTorch Lightning简化实现：

```python
import pytorch_lightning as pl
import torch
import torch.nn as nn


class TinyYOLO(pl.LightningModule):
    """Minimal YOLO v1-style detector: a small CNN whose output is reshaped to an S x S grid.

    Each grid cell predicts ``num_boxes`` boxes ``[x, y, w, h, conf]`` followed by
    ``num_classes`` class scores, so ``forward`` returns a tensor of shape
    ``(N, grid_size, grid_size, num_boxes * 5 + num_classes)``.
    """

    def __init__(self, grid_size=7, num_boxes=2, num_classes=2):
        super().__init__()
        self.grid_size = grid_size
        self.num_boxes = num_boxes
        self.num_classes = num_classes
        # Simplified backbone.
        self.net = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            # Four 2x poolings leave a 16x16 map for a 256x256 input; pool it down to
            # the grid size so the Linear layer's input really is 128 * S * S.
            nn.AdaptiveAvgPool2d(grid_size),
            nn.Flatten(),
            nn.Linear(128 * grid_size**2, grid_size * grid_size * (num_boxes * 5 + num_classes)),
        )

    def forward(self, x):
        batch_size = x.size(0)
        out = self.net(x)
        return out.view(batch_size, self.grid_size, self.grid_size,
                        self.num_boxes * 5 + self.num_classes)
```

注意：256×256的输入经过4次2倍池化后只剩16×16的特征图，因此在展平之前用`AdaptiveAvgPool2d`把它压缩到S×S，使全连接层的输入维度恰好是`128 * S * S`。

这个简化版网络保留了YOLO的三个关键设计：

- **网格划分**：输出张量被整理成S×S网格，每个单元对应输入图像中的一块区域；
- **统一预测**：每个网格同时预测边界框和类别概率；
- **端到端训练**：直接从像素到检测结果。

## 3. 实现YOLO风格损失函数

YOLO的损失函数结合了定位误差、置信度误差和分类误差。我们实现一个简化版本（与原论文一样使用平方误差，并用λ_coord、λ_noobj平衡各项）：

```python
import torch


def yolo_loss(preds, targets, num_boxes=2, num_classes=2, lambda_coord=5.0, lambda_noobj=0.5):
    """Simplified YOLO v1 sum-squared-error loss, averaged over the batch.

    Args:
        preds: (N, S, S, num_boxes * 5 + num_classes) raw network output, laid out as
            ``[x, y, w, h, conf] * num_boxes`` followed by the class scores.
        targets: (N, S, S, 5 + num_classes) labels as produced by SimpleShapesDataset:
            ``[x, y, w, h, conf, one-hot class]``.
        num_boxes: Boxes predicted per cell (B).
        num_classes: Number of classes (C).
        lambda_coord: Weight of the localisation term.
        lambda_noobj: Weight of the confidence term for predictors without an object.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If the channel counts of ``preds`` / ``targets`` do not match
            ``num_boxes`` and ``num_classes``.
    """
    if preds.shape[-1] != num_boxes * 5 + num_classes or targets.shape[-1] != 5 + num_classes:
        raise ValueError(
            f"expected {num_boxes * 5 + num_classes} prediction channels and "
            f"{5 + num_classes} target channels, got {preds.shape[-1]} and {targets.shape[-1]}"
        )
    batch_size = preds.shape[0]
    boxes = preds[..., : num_boxes * 5].reshape(*preds.shape[:-1], num_boxes, 5)  # (N, S, S, B, 5)
    cls_pred = preds[..., num_boxes * 5 :]                                         # (N, S, S, C)

    obj_mask = targets[..., 4] > 0  # (N, S, S): cells that contain an object

    # Squared coordinate error of every predictor against its cell's ground-truth box.
    box_err = ((boxes[..., :4] - targets[..., None, :4]) ** 2).sum(-1)  # (N, S, S, B)
    # The predictor with the smallest error is "responsible" for the object.
    # (YOLO v1 picks the highest IoU; squared error keeps this demo short.)
    best = torch.nn.functional.one_hot(box_err.argmin(-1), num_boxes).bool()
    responsible = best & obj_mask[..., None]  # (N, S, S, B)

    # Localisation: only the responsible predictor in cells that contain an object.
    loc_loss = box_err[responsible].sum()

    # Confidence: responsible predictors -> 1, all others -> 0 (down-weighted so the
    # many empty cells do not dominate training).
    conf_err = (torch.sigmoid(boxes[..., 4]) - responsible.float()) ** 2
    conf_loss = conf_err[responsible].sum() + lambda_noobj * conf_err[~responsible].sum()

    # Classification: only in cells that contain an object.
    cls_loss = ((cls_pred - targets[..., 5:]) ** 2).sum(-1)[obj_mask].sum()

    return (lambda_coord * loc_loss + conf_loss + cls_loss) / batch_size


class TinyYOLO(TinyYOLO):
    """Add the training logic to the network defined in section 2.

    Subclassing under the same name lets this notebook cell run on its own
    without repeating the network definition.
    """

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = yolo_loss(pred, y, num_boxes=self.num_boxes, num_classes=self.num_classes)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```

关键点说明：

- **定位损失**：只对包含目标的网格中“负责”该目标的那个预测框计算边界框坐标误差（原论文按IoU最大来选择负责框，这里用坐标误差最小来近似，以保持代码简短）；
- **置信度损失**：衡量预测框包含目标的可信度：负责框的置信度逼近1，其余框逼近0（乘以λ_noobj降低权重，避免大量空网格主导训练）；
- **分类损失**：只在包含目标的网格上，让预测的类别概率逼近正确类别。

## 4. 训练与可视化结果

现在我们可以训练模型并可视化检测结果了：

```python
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

# Data
dataset = SimpleShapesDataset()
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Training. `gpus=` was removed in Lightning 2.0; `accelerator="auto"` uses a GPU
# when one is available and falls back to the CPU otherwise.
model = TinyYOLO()
trainer = pl.Trainer(max_epochs=20, accelerator="auto", devices=1, log_every_n_steps=1)
trainer.fit(model, loader)


def visualize_detection(image, pred, class_names=None, num_boxes=2, conf_threshold=0.5):
    """Plot ``image`` with every predicted box whose confidence exceeds ``conf_threshold``.

    Args:
        image: (3, H, W) float tensor, as returned by SimpleShapesDataset.
        pred: (S, S, num_boxes * 5 + num_classes) raw TinyYOLO output for that image.
        class_names: Labels indexed by class id; defaults to the class indices.
        num_boxes: Boxes predicted per cell (B), matching the model.
        conf_threshold: Minimum sigmoid confidence for a box to be drawn.
    """
    img_h, img_w = image.shape[1:]
    fig, ax = plt.subplots(1)
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())

    grid_size = pred.shape[0]
    cell_w, cell_h = img_w / grid_size, img_h / grid_size
    num_classes = pred.shape[-1] - num_boxes * 5
    if class_names is None:
        class_names = [str(k) for k in range(num_classes)]

    for i in range(grid_size):          # grid row    -> y
        for j in range(grid_size):      # grid column -> x
            cell = pred[i, j]
            # Class scores follow all the box predictions in each cell.
            cls_idx = int(torch.argmax(cell[num_boxes * 5:]))
            for b in range(num_boxes):
                # The loss trains confidence through a sigmoid, so decode it the same way.
                conf = torch.sigmoid(cell[b * 5 + 4]).item()
                if conf <= conf_threshold:
                    continue
                # x, y are offsets inside the cell; w, h are relative to the whole image.
                x, y, w, h = cell[b * 5 : b * 5 + 4].tolist()
                cx, cy = (j + x) * cell_w, (i + y) * cell_h
                bw, bh = w * img_w, h * img_h
                ax.add_patch(plt.Rectangle((cx - bw / 2, cy - bh / 2), bw, bh,
                                           linewidth=2, edgecolor="r", facecolor="none"))
                ax.text(cx, cy, class_names[cls_idx], color="white", backgroundcolor="red")
    plt.show()


# Try one sample (switch to inference mode and keep the input on the model's device).
model.eval()
test_image, _ = dataset[0]
with torch.no_grad():
    pred = model(test_image.unsqueeze(0).to(model.device))[0].cpu()
visualize_detection(test_image, pred, class_names=dataset.classes)
```

这个极简实现捕捉了YOLO v1的三大核心思想：

- **网格划分检测**：将图像划分为S×S网格，每个网格负责检测中心落在该区域的目标；
- **统一预测**：每个网格同时预测边界框坐标、置信度和类别概率；
- **端到端训练**：直接从图像像素到检测结果，无需复杂的区域提议（region proposal）。