用PyTorch搞定Million-AID遥感数据集:从下载到训练,一个完整的代码仓库搭建指南

发布时间:2026/5/18 14:35:21

用PyTorch搞定Million-AID遥感数据集:从下载到训练,一个完整的代码仓库搭建指南 从零构建PyTorch遥感项目Million-AID数据集全流程实战指南引言在计算机视觉领域遥感图像分析正经历着前所未有的发展机遇。随着卫星和无人机技术的进步高质量遥感数据的获取变得比以往任何时候都更加容易。然而面对海量的遥感数据如何有效地组织、预处理并用于深度学习模型训练仍然是许多研究者和工程师面临的挑战。Million-AID数据集作为当前最大的公开遥感场景分类基准之一包含了超过100万张图像涵盖51个精细场景类别。这个数据集不仅规模庞大还具有复杂的层级分类结构和显著的长尾分布特性为模型训练带来了独特的挑战。本文将带领您从零开始构建一个完整的PyTorch项目来处理Million-AID数据集。不同于简单的代码片段展示我们将重点关注工程化实现和可复现性涵盖从数据集下载、目录结构设计、自定义Dataset类实现到训练流程优化的全流程。无论您是刚开始接触遥感图像分析的研究生还是希望将Million-AID应用于实际项目的算法工程师本指南都将提供可直接落地的解决方案。1. 项目初始化与环境配置1.1 创建项目结构一个良好的项目结构是高效开发的基础。对于遥感图像分类任务我们推荐以下目录组织方式Million-AID-Project/ ├── configs/ # 配置文件 │ └── default.yaml # 默认训练配置 ├── data/ # 数据处理模块 │ ├── __init__.py │ ├── dataset.py # 自定义Dataset类 │ ├── preprocess.py # 数据预处理 │ └── transforms.py # 自定义数据增强 ├── models/ # 模型定义 │ ├── __init__.py │ ├── base_model.py # 基础模型类 │ └── custom_cnn.py # 自定义CNN架构 ├── utils/ # 工具函数 │ ├── logger.py # 日志记录 │ └── metrics.py # 评估指标 ├── scripts/ # 实用脚本 │ ├── download_data.sh # 数据集下载 │ └── visualize.py # 数据可视化 ├── train.py # 训练入口 └── requirements.txt # 依赖列表提示使用tree命令可以快速生成项目目录结构图便于文档记录和团队协作。1.2 环境依赖安装建议使用conda创建独立的Python环境conda create -n million-aid python3.8 conda activate million-aid pip install -r requirements.txt其中requirements.txt应包含以下核心依赖torch1.12.1 torchvision0.13.1 numpy1.21.0 pillow9.0.0 tqdm4.64.0 tensorboard2.10.0 opencv-python4.6.0 scikit-learn1.1.01.3 数据集获取与验证Million-AID数据集可通过官方渠道申请下载。下载完成后建议运行校验脚本确保数据完整性import hashlib import os def check_file_integrity(filepath, expected_md5): 验证文件MD5校验码 hash_md5 hashlib.md5() with open(filepath, rb) as f: for chunk in iter(lambda: f.read(4096), b): hash_md5.update(chunk) return hash_md5.hexdigest() expected_md5 # 示例验证训练集压缩包 train_zip_path Million-AID/train.zip expected_md5 a1b2c3d4e5f6... # 替换为实际MD5值 if check_file_integrity(train_zip_path, expected_md5): print(文件校验通过) else: print(文件损坏请重新下载)2. 高效处理Million-AID数据集2.1 理解数据集结构Million-AID采用三级分类体系目录结构示例如下Million-AID/ ├── train/ │ ├── agriculture_land/ │ │ ├── arable_land/ │ │ │ ├── dry_field/ │ │ │ │ ├── image_001.jpg │ │ │ │ └── ... │ │ │ └── irrigated_field/ │ │ └── permanent_crop/ │ └── commercial_land/ └── test/ └── ... (类似train结构)2.2 实现自定义Dataset类针对这种复杂结构我们需要设计能够自动解析类别标签的Datasetfrom torch.utils.data import Dataset from PIL import Image import os import torchvision.transforms as T class MillionAIDDataset(Dataset): def __init__(self, root_dir, transformNone, modetrain): self.root_dir os.path.join(root_dir, mode) self.transform transform self.classes, self.class_to_idx self._find_classes() self.samples self._make_dataset() def _find_classes(self): 构建类别到索引的映射 classes [] for l1 in os.listdir(self.root_dir): # 第一级类别 l1_path os.path.join(self.root_dir, l1) if not os.path.isdir(l1_path): continue for l2 in os.listdir(l1_path): # 第二级类别 l2_path os.path.join(l1_path, l2) if not os.path.isdir(l2_path): continue for l3 in os.listdir(l2_path): # 第三级类别 l3_path os.path.join(l2_path, l3) if os.path.isdir(l3_path): classes.append(f{l1}/{l2}/{l3}) class_to_idx {cls_name: i for i, cls_name in enumerate(classes)} return classes, class_to_idx def _make_dataset(self): 构建(图像路径, 标签)对的列表 samples [] for cls_name, cls_idx in self.class_to_idx.items(): l1, l2, l3 cls_name.split(/) cls_dir os.path.join(self.root_dir, l1, l2, l3) for img_name in os.listdir(cls_dir): img_path os.path.join(cls_dir, img_name) if os.path.isfile(img_path): samples.append((img_path, cls_idx)) return samples def __len__(self): return len(self.samples) def __getitem__(self, idx): img_path, label self.samples[idx] img Image.open(img_path).convert(RGB) if self.transform: img self.transform(img) return img, label2.3 数据增强策略针对遥感图像特点我们设计专门的增强策略from torchvision import transforms # 基础变换 train_transform transforms.Compose([ transforms.Resize(256), transforms.RandomCrop(224), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ColorJitter(brightness0.2, contrast0.2, saturation0.2), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ]) # 测试集变换应保持确定性 test_transform transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean[0.485, 0.456, 0.406], std[0.229, 0.224, 0.225]) ])3. 处理长尾分布问题3.1 分析数据分布Million-AID存在明显的类别不平衡类别级别样本数量范围代表性类别头部类别30,000-45,000居住区、农田中部类别5,000-15,000工业区、交通设施尾部类别2,000-5,000特殊作物、小众建筑3.2 采样策略对比我们实现多种采样器来处理不平衡问题from torch.utils.data import WeightedRandomSampler def get_weighted_sampler(dataset): 根据类别频率计算样本权重 class_counts torch.zeros(len(dataset.classes)) for _, label in dataset.samples: class_counts[label] 1 class_weights 1. / class_counts sample_weights torch.tensor([class_weights[label] for _, label in dataset.samples]) return WeightedRandomSampler(sample_weights, len(sample_weights)) # 替代方案使用现成的库 from torchsampler import ImbalancedDatasetSampler sampler ImbalancedDatasetSampler(train_dataset)3.3 损失函数调整除了采样策略还可以在损失函数层面处理不平衡import torch.nn as nn import torch.nn.functional as F class FocalLoss(nn.Module): def __init__(self, alphaNone, gamma2.0): super().__init__() self.gamma gamma self.alpha alpha # 可传入类别权重张量 def forward(self, inputs, targets): BCE_loss F.cross_entropy(inputs, targets, reductionnone) pt torch.exp(-BCE_loss) loss (1-pt)**self.gamma * BCE_loss if self.alpha is not None: alpha_t self.alpha[targets] loss alpha_t * loss return loss.mean() # 使用示例 alpha torch.tensor([...]) # 根据类别频率计算 criterion FocalLoss(alphaalpha, gamma2.0)4. 模型训练与评估4.1 基础训练流程实现模块化的训练循环def train_one_epoch(model, loader, optimizer, criterion, device, epoch): model.train() total_loss 0.0 correct 0 total 0 pbar tqdm(loader, descfEpoch {epoch}) for inputs, labels in pbar: inputs, labels inputs.to(device), labels.to(device) optimizer.zero_grad() outputs model(inputs) loss criterion(outputs, labels) loss.backward() optimizer.step() total_loss loss.item() _, predicted outputs.max(1) total labels.size(0) correct predicted.eq(labels).sum().item() pbar.set_postfix({ loss: total_loss/(total/loader.batch_size), acc: 100.*correct/total }) return total_loss/len(loader), correct/total4.2 多尺度训练技巧针对遥感图像特点实现多尺度训练class MultiScaleTransform: 在训练过程中随机选择不同尺度 def __init__(self, scales): self.scales scales self.transforms [ transforms.Compose([ transforms.Resize(scale), transforms.RandomCrop(224), # 其他变换... ]) for scale in scales ] def __call__(self, img): t random.choice(self.transforms) return t(img) # 使用示例 train_transform MultiScaleTransform(scales[256, 288, 320])4.3 模型评估指标除了准确率遥感任务中常用的评估指标from sklearn.metrics import confusion_matrix, classification_report def evaluate(model, loader, device, classes): model.eval() all_preds [] all_labels [] with torch.no_grad(): for inputs, labels in loader: inputs inputs.to(device) outputs model(inputs) _, preds torch.max(outputs, 1) all_preds.extend(preds.cpu().numpy()) all_labels.extend(labels.numpy()) # 计算混淆矩阵 cm confusion_matrix(all_labels, all_preds) print(混淆矩阵:\n, cm) # 分类报告 print(\n分类报告:) print(classification_report( all_labels, all_preds, target_namesclasses )) return cm5. 工程优化技巧5.1 加速数据加载使用torchdata库实现并行数据加载from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService def get_dataloader(dataset, batch_size, num_workers4, shuffleTrue): rs MultiProcessingReadingService( num_workersnum_workers, worker_init_fnlambda _: torch.manual_seed(42) ) return DataLoader2( dataset, batch_sizebatch_size, shuffleshuffle, reading_servicers )5.2 混合精度训练利用AMP加速训练并减少显存占用from torch.cuda.amp import autocast, GradScaler scaler GradScaler() for inputs, labels in train_loader: inputs, labels inputs.to(device), labels.to(device) optimizer.zero_grad() with autocast(): outputs model(inputs) loss criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()5.3 模型部署优化使用TorchScript导出生产环境可用的模型# 导出为TorchScript model.eval() example_input torch.rand(1, 3, 224, 224).to(device) traced_script torch.jit.trace(model, example_input) traced_script.save(million_aid_model.pt) # 量化模型减小体积 quantized_model torch.quantization.quantize_dynamic( model, {torch.nn.Linear}, dtypetorch.qint8 )在实际项目中我们发现将数据预处理逻辑也包含在TorchScript中能显著提升端到端性能。可以通过自定义nn.Module包装整个预处理推理流程来实现这一点。

相关新闻