运行 VM-UNet 踩坑记录

发布时间:2026/5/24 18:03:22

运行 VM-UNet 踩坑记录 运行 VM-UNet 踩坑记录源码地址https://github.com/JCruan519/VM-UNet前言想运行一下 VM-UNet 的开源代码作者的开源资料写的非常之详细我也是按照步骤一步一步来做的环境安装的也非常顺利一跑代码哇塞直接能够运行稳了。但是运行的时候loss一直为0不知道为什么我以为就是简单可能他就是这样的毕竟我也不是写这个代码的想着能够运行起来就非常不错了。我就说让他默认一直运行吧但是到了后面epoch30的时候要进行评价指标的计算了到这个时候我才发现有报错File/code/VM-UNet/engine.py, line144,intest_one_epoch TN, FP, FN, TPconfusion[0,0], confusion[0,1], confusion[1,0], confusion[1,1]IndexError: index1is out of boundsforaxis1with size1出现越界了我就想怎么可能越界啊这不可能的啊后面我去查一下源代码confusionconfusion_matrix(y_true,y_pre)TN,FP,FN,TPconfusion[0,0],confusion[0,1],confusion[1,0],confusion[1,1]这里再进行confusion_matrix运算的时候当y_true,y_pre只有一种值的时候就会出现confusion.shape (1,1)的情况所以推测就是数据出了问题但是数据怎么可能会出问题呢经过调试确实看到所有的y_true,y_pre都为0导致只有一种值confusion访问出现越界可是为什么预测和真实值都是0呢这个问题简直百思不得其解啊我就猜测是模型的错误正好作者也有一个已经训练好的模型那我就去试试训练好的模型也是同样的结果loss0并且要进行预测的时候也是出现上面的越界问题我就去找具体的数据读取在哪里看看数据到底有没有正确读取找到进行test的入口进行单步调试deftest_one_epoch(test_loader,model,criterion,logger,config,test_data_nameNone):# switch to evaluate modemodel.eval()preds[]gts[]loss_list[]withtorch.no_grad():fori,datainenumerate(tqdm(test_loader)):img,mskdata# 调试发现到这里的时候imgmsk 的数据是正确的也就是读取到了正确的数据值img,mskimg.cuda(non_blockingTrue).float(),msk.cuda(non_blockingTrue).float()# 运行完上面那一步以后呢数据全部为0了outmodel(img)losscriterion(out,msk)loss_list.append(loss.item())mskmsk.squeeze(1).cpu().detach().numpy()gts.append(msk)iftype(out)istuple:outout[0]outout.squeeze(1).cpu().detach().numpy()preds.append(out)ifi%config.save_interval0:save_imgs(img,msk,out,i,config.work_diroutputs/,config.datasets,config.threshold,test_data_nametest_data_name)predsnp.array(preds).reshape(-1)gtsnp.array(gts).reshape(-1)y_prenp.where(predsconfig.threshold,1,0)y_truenp.where(gts0.5,1,0)confusionconfusion_matrix(y_true,y_pre)TN,FP,FN,TPconfusion[0,0],confusion[0,1],confusion[1,0],confusion[1,1]accuracyfloat(TNTP)/float(np.sum(confusion))iffloat(np.sum(confusion))!0else0sensitivityfloat(TP)/float(TPFN)iffloat(TPFN)!0else0specificityfloat(TN)/float(TNFP)iffloat(TNFP)!0else0f1_or_dscfloat(2*TP)/float(2*TPFPFN)iffloat(2*TPFPFN)!0else0mioufloat(TP)/float(TPFPFN)iffloat(TPFPFN)!0else0iftest_data_nameisnotNone:log_infoftest_datasets_name:{test_data_name}print(log_info)logger.info(log_info)log_infoftest of best model,loss:{np.mean(loss_list):.4f},miou:{miou},f1_or_dsc:{f1_or_dsc},accuracy:{accuracy},\ specificity:{specificity},sensitivity:{sensitivity},confusion_matrix:{confusion}print(log_info)logger.info(log_info)returnnp.mean(loss_list)好啊好啊好 也算是终于找到了问题的关键经过运算img, msk img.cuda(non_blockingTrue).float(), msk.cuda(non_blockingTrue).float()这一步以后imgmsk 的值就为 0 了。那么现在只需要找原因就好了。显卡与torch版本不匹配问题我知道经过上面的步骤以后tensor的值就会变为 0 所以为了验证是不是原本代码的问题我直接把问题抽出来直接运行importtorch xtorch.tensor([0.123])print(CPU:, x)x_cudax.cuda()print(GPU:, x_cuda)结果CPU: tensor([0.1230])torch.float32 ~/.conda/envs/vmunet/lib/python3.8/site-packages/torch/cuda/__init__.py:155: UserWarning: NVIDIA GeForce RTX5090with CUDA capability sm_120 is not compatible with the current PyTorch installation. The current PyTorchinstallsupports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75 sm_80 sm_86. If you want to use the NVIDIA GeForce RTX5090GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/ warnings.warn(incompatible_device_warn.format(device_name, capability, .join(arch_list), device_name))GPU: tensor([0.],devicecuda:0)torch.float32 同步后GPU: tensor([0.],devicecuda:0)看到了吧问题就是NVIDIA GeForce RTX 5090 with CUDA capability sm_120 is not compatible with the current PyTorch installation. The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75 sm_80 sm_86. If you want to use the NVIDIA GeForce RTX 5090 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/PyTorch 版本与显卡 5090 的 CUDA 能力 sm_120 不匹配卧槽了真离谱啊解决既然知道了是版本不匹配的问题那就是版本经过多方尝试啊最终就是要想适配5090系列显卡**PyTorch的CUDA就要选择CUDA 12.8以上的才行**后面就是重新装配一下VM-UNet的依赖了。代码如下conda create--namevmunetpython3.10-yconda avtivate vmunet pip3installtorch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128 pipinstallpackaging pipinstalltimm#这里就不要按照 README 文件中指定的版本了他会回去追溯版本,导致循环下载会下载很多内容这个很恼火要等很久pipinstallpytest chardet yacs termcolor pipinstallsubmitit tensorboardX pipinstalltriton pipinstallcausal_conv1d-1.6.1cu12torch2.10cxx11abiTRUE-cp310-cp310-linux_x86_64.whl# 这个 whl 包要到官方地址下载与 python与cp后面的数字对应 和 PyTorch要对应cu以及torch的版本 匹配的包并且要注意后面的平台x86_64 官方地址 https://github.com/Dao-AILab/causal-conv1d/releasespipinstallmamba_ssm-2.3.1cu12torch2.10cxx11abiTRUE-cp310-cp310-linux_x86_64.whl# 同样的要遵循上面那一条下载的规则官方地址https://github.com/state-spaces/mamba/releases/pipinstallscikit-learn matplotlib thop h5py SimpleITK scikit-image medpy yacs总结这一顿操作下来感觉搞了一个星期不止的感觉啊给我最大的一个印象就是要学会单步调试单步调试了就能够精准定位问题在哪里知道问题在哪里那么离解决问题就不远了下次再进行这样的操作的时候能够想起来是不是和显卡版本有问题啦

相关新闻