DeepLabv3+ 实现双输出任务(分割+分类)

发布时间:2026/5/22 20:20:41

# DeepLabv3+ with dual outputs (segmentation + classification)
#
# 1. Introduction
#    DeepLabv3+ is a classic semantic-segmentation model. To meet the needs of
#    a real project, its network structure was modified so that it supports a
#    dual-output task: it outputs a pixel-level segmentation result and an
#    image-level classification result at the same time.
#
# 2. Code changes
# 2.1 Network structure changes
#    nets/deeplabv3_plus.py: modify the DeepLab class and add a
#    classification head.

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2


class MobileNetV2(nn.Module):
    """MobileNetV2 feature extractor with optional dilation of its late stages.

    Wraps the project's ``mobilenetv2`` model, keeps every ``features`` stage
    except the last one, and rewrites convolutions in the later stages
    according to ``downsample_factor``.

    Args:
        downsample_factor: ``8`` rewrites stages 7-13 with ``dilate=2`` and
            stages 14 onwards with ``dilate=4``; ``16`` rewrites only stages
            14 onwards with ``dilate=2``. Any other value leaves the stages
            untouched (there is no fallback branch).
        pretrained: Forwarded unchanged to ``mobilenetv2``.
    """

    def __init__(self, downsample_factor=8, pretrained=True):
        super().__init__()

        model = mobilenetv2(pretrained)
        self.features = model.features[:-1]

        self.total_idx = len(self.features)
        # Stage indices used as boundaries for the dilation rewrite below.
        self.down_idx = [2, 4, 7, 14]

        if downsample_factor == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif downsample_factor == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )

    def _nostride_dilate(self, m, dilate):
        """Rewrite one module in place; intended as an ``nn.Module.apply`` callback.

        Only modules whose class name contains ``'Conv'`` are touched:

        * stride ``(2, 2)`` becomes ``(1, 1)``; if the kernel is 3x3, dilation
          and padding are set to ``dilate // 2``;
        * any other stride with a 3x3 kernel gets dilation and padding set to
          ``dilate``.
        """
        if "Conv" not in m.__class__.__name__:
            return

        if m.stride == (2, 2):
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                half = dilate // 2
                m.dilation = (half, half)
                m.padding = (half, half)
        elif m.kernel_size == (3, 3):
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x):
        """Return ``(low_level_features, x)``.

        ``low_level_features`` is the output of the first four feature stages;
        ``x`` is that tensor passed through the remaining stages.
        """
        low_level_features = self.features[:4](x)
        x = self.features[4:](low_level_features)
        return low_level_features, x


# -----------------------------------------#
#   ASPP feature-extraction module.
#   Extracts features with dilated convolutions
#   that use different dilation rates.
# -----------------------------------------#
class ASPP(nn.Module):
    """Atrous spatial pyramid pooling over a single feature map.

    Five parallel branches are computed from the same input and fused with a
    1x1 convolution:

    1. 1x1 convolution;
    2. 3x3 convolution, dilation ``6 * rate``;
    3. 3x3 convolution, dilation ``12 * rate``;
    4. 3x3 convolution, dilation ``18 * rate``;
    5. global average pooling, 1x1 convolution, bilinear upsampling back to
       the input's spatial size.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of channels produced by every branch and by the
            fused output.
        rate: Multiplier applied to the base dilation rates.
        bn_mom: ``momentum`` passed to every ``BatchNorm2d``.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super().__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # For the 3x3 branches padding equals dilation, so the spatial size of
        # the input is preserved.
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)

        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the five branches on ``x`` and return the fused feature map."""
        # Unpacking four values also rejects inputs that are not 4-D.
        _, _, row, col = x.size()

        # -----------------------------------------#
        #   Five branches in total
        # -----------------------------------------#
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)

        # -----------------------------------------#
        #   Fifth branch: global average pooling + conv
        # -----------------------------------------#
        global_feature = torch.mean(x, dim=2, keepdim=True)
        global_feature = torch.mean(global_feature, dim=3, keepdim=True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        global_feature = F.interpolate(
            global_feature, size=(row, col), mode='bilinear', align_corners=True
        )

        # -----------------------------------------#
        #   Stack the outputs of the five branches,
        #   then fuse them with a 1x1 convolution.
        # -----------------------------------------#
        feature_cat = torch.cat(
            [conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1
        )
        result = self.conv_cat(feature_cat)
        return result


# ============== New dual-task model: segmentation + classification ==============
class DeepLab(nn.Module):
    """DeepLabv3+ variant with a segmentation head and a classification head.

    Args:
        num_classes: Number of segmentation classes (output channels of
            ``cls_conv``).
        num_classes_classify: Number of image-level classification classes
            (output features of ``classification_head``).
        backbone: ``"mobilenet"`` or ``"xception"``.
        pretrained: Forwarded to the backbone constructor.
        downsample_factor: Forwarded to the backbone; also sets the ASPP rate
            to ``16 // downsample_factor``.

    Raises:
        ValueError: If ``backbone`` is neither ``"mobilenet"`` nor
            ``"xception"``.
    """

    def __init__(
        self,
        num_classes,  # number of segmentation classes
        num_classes_classify=2,  # number of classification classes
        backbone="mobilenet",
        pretrained=True,
        downsample_factor=16
    ):
        super().__init__()
        if backbone == "xception":
            self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 2048
            low_level_channels = 256
        elif backbone == "mobilenet":
            self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)
            in_channels = 320
            low_level_channels = 24
        else:
            raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))

        self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16 // downsample_factor)

        # Projects the low-level backbone features to 48 channels.
        self.shortcut_conv = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        # Takes 48 + 256 input channels and produces 256 channels.
        self.cat_conv = nn.Sequential(
            nn.Conv2d(48 + 256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),

            nn.Conv2d(256, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
        )

        # Original segmentation head.
        self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)

        # ==================== New: image classification head ====================
        # Global average pool -> flatten -> dropout -> 256-512-N MLP.
        # NOTE(review): the head expects a 256-channel feature map; which
        # tensor is fed to it is decided in forward(), which is cut off in
        # this excerpt - confirm against the full file.
        self.num_classes_classify = num_classes_classify
        self.classification_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.5),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes_classify)
        )
        # ========================================================================

    def forward(self, x):
        H, W = x.size(2), x.size(3)
        # NOTE(review): the source excerpt ends here. The remainder of
        # forward() is not visible and has deliberately NOT been
        # reconstructed; restore it from the full
        # nets/deeplabv3_plus.py before using this class.

相关新闻