暂无图片
暂无图片
暂无图片
暂无图片
暂无图片

pytorch入门 - 基于AlexNet神经网络实现猫狗大战

chester技术分享 2025-06-23
108


基于之前的博客pytorch入门 - AlexNet神经网络,并借助猫狗数据集https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/data,实现一个基于 AlexNet 的二分类模型识别猫与狗。

完整流程涵盖数据准备、归一化、模型定义、训练增强、验证并可视化结果。

 

一、数据集准备与预处理

    import os
    import shutil
    def split_data(ROOT_TRAIN):
        cat_dir = os.path.join(ROOT_TRAIN, "cat")
        dog_dir = os.path.join(ROOT_TRAIN, "dog")
        os.makedirs(cat_dir, exist_ok=True)
        os.makedirs(dog_dir, exist_ok=True)
        for filename in os.listdir(ROOT_TRAIN):
            if filename.startswith("cat") and filename.endswith(".jpg"):
                shutil.move(os.path.join(ROOT_TRAIN, filename), 
                            os.path.join(cat_dir, filename))
            elif filename.startswith("dog") and filename.endswith(".jpg"):
                shutil.move(os.path.join(ROOT_TRAIN, filename), 
                            os.path.join(dog_dir, filename)) 

    优化原因
    分类任务需明确标签与数据的对应关系。通过创建cat/dog
    子目录并移动图片,可直接利用PyTorch的ImageFolder
    自动生成标签,避免手动标注错误。


    二、数据归一化参数计算

      def compute_normalization_params(dataset_path):
          transform = transforms.Compose([
              transforms.Resize((227227)),
              transforms.ToTensor()
          ])
          dataset = ImageFolder(dataset_path, transform=transform)
          loader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=False)
          # 计算各通道均值和标准差
          mean = 0.0
          std = 0.0
          for data, _ in loader:
              batch_samples = data.size(0)
              data = data.view(batch_samples, data.size(1), -1)
              mean += data.mean(2).sum(0)
              std += data.std(2).sum(0)
          return mean  len(dataset), std  len(dataset) 

      关键点

      1. 输入尺寸统一
        AlexNet要求固定输入尺寸227×227
        ,需提前调整
      2. 通道级归一化
        对RGB三通道分别计算均值和标准差,消除光照差异影响,加速模型收敛
      3. 离线计算
        避免在训练时实时计算,提升数据加载效率


      三、AlexNet模型针对性修改

        class AlexNet(nn.Module):
            def __init__(self):
                super().__init__()
                # 修改1:输入通道调整为3 (RGB)
                self.conv1 = nn.Conv2d(396, kernel_size=11, stride=4
                # ... (中间层省略)
                # 修改2:输出层调整为2分类
                self.fc3 = nn.Linear(40962)  
                # 修改3:降低Dropout比例
                self.dropout = nn.Dropout(0.2)  # 原论文为0.5 

        优化逻辑

        1. 输入通道适配
          原始AlexNet针对ImageNet的1000类设计,此处调整为猫狗二分类,需修改输出层维度为2
        2. 降低过拟合风险
          • 猫狗数据集(25k张)远小于ImageNet(1400万张)
          • 降低Dropout比例(0.5→0.2)可保留更多特征信息,避免模型欠拟合
        3. 权重初始化
          采用Kaiming初始化,适配ReLU激活函数特性,缓解梯度消失


        四、数据增强策略

          train_transform = transforms.Compose([
              transforms.RandomHorizontalFlip(p=0.5),
              transforms.RandomRotation(10),
              transforms.RandomResizedCrop(227, scale=(0.81.0)),
              transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
              transforms.ToTensor(),
              transforms.Normalize(mean=[0.4880.4550.417], 
                                   std=[0.2260.2210.221])
          ]) 

          增强目的

          1. 提升泛化能力
            通过旋转、裁剪、色彩扰动模拟真实场景的多样性,防止模型记忆固定模式
          2. 克服数据局限
            小数据集易导致过拟合,增强后等效扩大数据规模
          3. 对齐测试环境
            测试阶段采用相同预处理,保证输入分布一致性


          五、训练过程优化

            # 1. 学习率调整
            optimizer = optim.Adam(model.parameters(), lr=1e-4)  # 原常用值0.001
            # 2. 训练-验证集拆分
            train_data, val_data = random_split(dataset, [0.8, 0.2])
            # 3. 早停机制
            if val_acc > best_acc:
                best_model_wts = copy.deepcopy(model.state_dict()) 

            关键技术点

            1. 低学习率策略
              • 预训练模型特征已较完备,降低学习率(1e-4)避免破坏已有特征
              • 微调阶段需精细调整参数,高学习率易导致震荡
            2. 验证集独立划分
              • 20%数据作为验证集,实时监控模型泛化能力
              • 避免测试集参与训练,保证评估客观性
            3. 混合精度训练(可选)

              使用torch.cuda.amp
              自动混合精度,提升训练速度30%+(需GPU支持)


            关键优化总结

            优化点
            原始值
            调整值
            作用
            输入通道
            1 (灰度)
            3 (RGB)
            适配彩色图像
            输出维度
            1000
            2
            二分类需求
            Dropout率
            0.5
            0.2
            防欠拟合
            学习率
            0.001
            0.0001
            稳定微调
            数据增强
            5种变换
            提升泛化性


            关注获取技术分享

            文章转载自chester技术分享,如果涉嫌侵权,请发送邮件至:contact@modb.pro进行举报,并提供相关证据,一经查实,墨天轮将立刻删除相关内容。

            评论