
SAM分割与单目深度估计参考DataWhale具身智能课程every-embodiedSAM与DPT是什么技术组合◦ Segment Anything Model (SAM)提供高质量、实时的交互式图像分割。◦ DPT (Dense Prediction Transformer)从单张图片中估计相对深度图。目标与价值单纯的分割SAM提供物体轮廓但无距离信息单纯的深度图提供距离但物体边界模糊。两者结合可实现3D物体感知使机器人能同时知道“是什么物体”以及“它有多远”。关键原理◦ SAM 由图像编码器、提示编码器和掩码解码器组成能实现实时交互分割。◦ 单目深度模型如DPT通过学习透视、遮挡等上下文线索来推测深度其输出通常是逆深度即值越大越亮表示物体越近值越小越暗表示物体越远。应用场景文档旨在构建一个交互式系统用户点击图片中的物体系统即可实时分割出该物体并计算出其平均相对深度。代码实现交互式分割与深度估计importtorchimportnumpyasnpfromPILimportImageimportmatplotlib.pyplotaspltimportosimporttime# --- 模型导入 ---fromtransformersimportDPTImageProcessor,DPTForDepthEstimationfromsegment_anythingimportsam_model_registry,SamAutomaticMaskGenerator devicecudaiftorch.cuda.is_available()elsecpuprint(f使用设备:{device})# --- 辅助函数显示并保存结果 ---defsave_visualization(image,mask_or_depth,modesam,output_nameoutput.png):plt.figure(figsize(12,8))ifmodedepth:# 并排显示原图 vs 深度图plt.subplot(1,2,1)plt.imshow(image)plt.title(Original Image)plt.axis(off)plt.subplot(1,2,2)plt.imshow(mask_or_depth,cmapinferno)plt.colorbar(labelRelative Depth)plt.title(Depth Estimation)plt.axis(off)elifmodesam:# 叠加显示 SAM Maskplt.imshow(image)axplt.gca()ax.set_autoscale_on(False)# 将 Mask 按面积排序大的在下小的在上sorted_annssorted(mask_or_depth,key(lambdax:x[area]),reverseTrue)img_overlaynp.ones((sorted_anns[0][segmentation].shape[0],sorted_anns[0][segmentation].shape[1],4))img_overlay[:,:,3]0# 透明度初始化foranninsorted_anns:mann[segmentation]color_masknp.concatenate([np.random.random(3),[0.4]])# 随机颜色 0.4 透明度img_overlay[m]color_mask ax.imshow(img_overlay)plt.title(SAM Segmentation)plt.axis(off)plt.savefig(output_name,bbox_inchestight)plt.close()print(f结果已保存至:{output_name})# --- 主流程 ---defmain():# 1. 路径设置rgb_pathimage_d61af3.jpg# 此处替换为实际图片sam_ckptsam_vit_h_4b8939.pthifnotos.path.exists(rgb_path):print(f错误: 找不到图片{rgb_path})return# 加载图片image_pilImage.open(rgb_path).convert(RGB)image_npnp.array(image_pil)# ---------------------------------------------------------# 任务 1: 深度估计 (Depth Estimation)# ---------------------------------------------------------print(\n--- [1/2] 正在运行深度估计 ---)try:depth_processorDPTImageProcessor.from_pretrained(Intel/dpt-large)depth_modelDPTForDepthEstimation.from_pretrained(Intel/dpt-large).to(device)inputsdepth_processor(imagesimage_pil,return_tensorspt).to(device)withtorch.no_grad():outputsdepth_model(**inputs)predicted_depthoutputs.predicted_depth# 插值还原尺寸predictiontorch.nn.functional.interpolate(predicted_depth.unsqueeze(1),sizeimage_pil.size[::-1],modebicubic,align_cornersFalse,).squeeze().cpu().numpy()# 保存深度图结果save_visualization(image_np,prediction,modedepth,output_nameresult_01_depth.png)exceptExceptionase:print(f深度估计失败:{e})# ---------------------------------------------------------# 任务 2: SAM 全图分割 (Segment Anything)# ---------------------------------------------------------print(\n--- [2/2] 正在运行 SAM 分割 ---)ifos.path.exists(sam_ckpt):try:samsam_model_registry[vit_h](checkpointsam_ckpt).to(device)mask_generatorSamAutomaticMaskGenerator(sam)masksmask_generator.generate(image_np)# 保存 SAM 结果 (注意文件名不同避免覆盖)save_visualization(image_np,masks,modesam,output_nameresult_02_sam_seg.png)exceptExceptionase:print(fSAM 分割失败:{e})else:print(f跳过 SAM: 未找到权重文件{sam_ckpt})if__name____main__:main()注意力热图核心知识要点两种“热图”的区分◦ YOLO生成的“检测置信度热图”显示模型认为哪些区域可能存在物体及其置信度是基于模型检测输出的结果可视化。◦ 学术上的“注意力热图”如Grad-CAM揭示模型在做出特定分类决策时其“注意力”聚焦在输入图像的哪些区域。这是一种模型可解释性XAI 技术用于理解神经网络的内部决策依据。注意力热图Grad-CAM的工作原理其核心思想是通过梯度定位对分类决策最重要的图像区域。简要步骤为◦ 正向传播得到目标类别如“杯子”的预测分数。◦ 梯度计算计算该预测分数相对于最后一个卷积层特征图的梯度。梯度值大的位置意味着该处的特征对“判断为杯子”这一决策贡献大。◦ 加权求和用梯度信息对特征图进行加权平均生成一张与原始图像对应的热力图高亮区域即为模型的“注意力焦点”。YOLOv10的核心创新NMS-Free◦ 痛点传统YOLO依赖非极大值抑制NMS 这一后处理步骤来过滤冗余检测框这会增加延迟且阻碍真正的端到端部署。◦ 解决方案YOLOv10提出了一致性双重分配训练策略。▪ 在训练时一个分支执行一对一匹配学习为每个物体只生成一个高质量预测框。▪ 另一个分支执行传统一对多匹配提供丰富的监督信号。▪ 模型被约束使两个分支的输出趋于一致从而学会了在推理时自主抑制冗余框无需NMS。基于YOLOv10的目标检测热力图生成实现代码importosfrompathlibimportPathimportmatplotlib.pyplotaspltimportnumpyasnpfromPILimportImageimportcv2fromultralyticsimportYOLOdefgenerate_and_save_heatmap_for_image(model,img_path,output_dir,target_class_namecup,conf_threshold0.25): 对单张图片进行推理生成检测热力图并保存。 Args: model: 加载好的YOLO模型。 img_path (str): 输入图片路径。 output_dir (str): 输出目录。 target_class_name (str): 需要生成热力图的目标类别名称。 conf_threshold (float): 检测置信度阈值。 # 1. 加载图片img_npcv2.imread(img_path)ifimg_npisNone:print(f 警告: 无法读取图片{img_path}跳过。)returnimg_np_rgbcv2.cvtColor(img_np,cv2.COLOR_BGR2RGB)img_height,img_widthimg_np.shape[:2]# 2. 获取目标类别的IDclass_namesmodel.names# 模型支持的类别名称字典 {id: name}target_cls_idNoneforcls_id,nameinclass_names.items():ifnametarget_class_name:target_cls_idcls_idbreakiftarget_cls_idisNone:print(f 警告: 模型不支持类别 {target_class_name}跳过{img_path}。)return# 3. 模型推理resultsmodel(img_np,verboseFalse,confconf_threshold)# 4. 创建空白热力图并填充heatmapnp.zeros((img_height,img_width),dtypenp.float32)forresultinresults:forboxinresult.boxes:clsint(box.cls)conffloat(box.conf)ifclstarget_cls_id:# 只处理目标类别x1,y1,x2,y2map(int,box.xyxy[0])# 用检测框的置信度值填充其区域cv2.rectangle(heatmap,(x1,y1),(x2,y2),conf,thicknesscv2.FILLED)# 5. 后处理热力图高斯模糊使其更平滑ifheatmap.max()0:# 应用高斯模糊核大小可根据需要调整必须是正奇数kernel_size(51,51)ifmin(img_height,img_width)100else(25,25)kernel_size(kernel_size[0]//2*21,kernel_size[1]//2*21)# 确保为奇数heatmapcv2.GaussianBlur(heatmap,kernel_size,0)# 归一化到 [0, 255] 便于可视化heatmap(heatmap/heatmap.max())*255heatmap_uint8heatmap.astype(np.uint8)# 6. 生成并保存可视化结果图fig,(ax1,ax2)plt.subplots(1,2,figsize(15,6))# 子图1原始图片ax1.imshow(img_np_rgb)ax1.set_title(Original Image)ax1.axis(off)# 子图2热力图imax2.imshow(heatmap_uint8,cmapinferno)ax2.set_title(fHeatmap for {target_class_name})ax2.axis(off)plt.colorbar(im,axax2,fraction0.046,pad0.04)# 保存图片img_stemPath(img_path).stem output_pathos.path.join(output_dir,f{img_stem}_heatmap.png)plt.tight_layout()plt.savefig(output_path,dpi150,bbox_inchestight)plt.close(fig)# 关闭图形以释放内存print(f 已保存:{output_path})defbatch_process_heatmap(image_folder,output_folder,model_weightsyolov10x.pt,target_classcup): 批处理主函数处理一个文件夹中的所有图片。 Args: image_folder (str): 存放输入图片的文件夹路径。 output_folder (str): 保存输出热力图的文件夹路径。 model_weights (str): YOLO模型权重文件名称或路径。 target_class (str): 感兴趣的目标物体类别。 # 创建输出目录os.makedirs(output_folder,exist_okTrue)# 1. 加载模型print(f[1/3] 正在加载模型{model_weights}...)try:modelYOLO(model_weights)exceptExceptionase:print(f 错误: 加载模型失败 -{e})print( 请确保模型文件存在或它将自动从网上下载。)return# 2. 获取图片列表print(f[2/3] 正在扫描图片目录{image_folder}...)valid_extensions(.jpg,.jpeg,.png,.bmp)image_paths[]forfileinos.listdir(image_folder):iffile.lower().endswith(valid_extensions):image_paths.append(os.path.join(image_folder,file))ifnotimage_paths:print( 错误: 未在指定文件夹中找到支持的图片文件。)returnprint(f 找到{len(image_paths)}张待处理图片。)# 3. 批量处理print(f[3/3] 开始批处理目标类别为 {target_class}...)fori,img_pathinenumerate(image_paths):print(f 处理中 ({i1}/{len(image_paths)}):{os.path.basename(img_path)})generate_and_save_heatmap_for_image(model,img_path,output_folder,target_class)print(批处理完成)# --- 使用示例 ---if__name____main__:# 配置参数INPUT_IMAGE_DIR./input_images# 你的输入图片文件夹OUTPUT_RESULT_DIR./heatmap_results# 输出结果文件夹TARGET_OBJECTcup# 你关心的物体类别MODEL_TO_USEyolov10x.pt# 可改为 yolov10n/s/m/l/x 等不同规格# 运行批处理batch_process_heatmap(INPUT_IMAGE_DIR,OUTPUT_RESULT_DIR,MODEL_TO_USE,TARGET_OBJECT)代码核心说明批处理逻辑核心函数 batch_process_heatmap 会遍历 INPUT_IMAGE_DIR 文件夹下所有图片自动处理并将带有热力图的结果保存到 OUTPUT_RESULT_DIR。热力图生成generate_and_save_heatmap_for_image 函数完成了核心工作◦ 加载图片并进行模型推理。◦ 根据 target_class 筛选特定类别的检测框。◦ 用检测框的置信度填充生成初始热力图并通过高斯模糊使其平滑、可视化效果更佳。◦ 将原始图片与热力图并列显示并保存为图片。模型加载代码使用 Ultralytics 的 YOLO 接口。指定 MODEL_TO_USE如 yolov10x.pt首次运行时会自动下载权重文件。可定制性您可以轻松修改目标类别TARGET_OBJECT、模型大小、置信度阈值等参数以适应不同任务。