pytouch 学习(2):使用 pytouch 实现 CNN 手写字体识别

步骤如下:

  1. 导入所需的包
  2. 处理数据集
  3. 导入数据集
  4. 定义网络结构
  5. 定义损失器和优化器
  6. 训练
  7. 测试
  8. 优化

1、导入所需包

import torch 
from torch.utils import data # 获取迭代数据,处理 数据
from torch.autograd import Variable # 获取变量
import torchvision #包由流行的数据集、模型架构和常见的计算机视觉图像转换组成
from torchvision.datasets import mnist # 获取数据集,手写字体库
import matplotlib.pyplot as plt # 生成图形

2、数据处理

将获取到的数据集转换为 pytouch 可用的 tensor 类型

然后用平均值和标准差归一化图像。

# 数据集的预处理
data_tf = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(), # 数据类型转换
        torchvision.transforms.Normalize([0.5],[0.5]) # 数据归一化,也就是把图片变成 (-1,1)的范围
    ]
)

data_path = r'数据集路径'   # 获取数据集
train_data = mnist.MNIST(data_path,train=True,transform=data_tf,download=False)
test_data = mnist.MNIST(data_path,train=False,transform=data_tf,download=False)

第二个参数:train 表示该数据集是否为训练数据。 第四个参数,如果数据集不在 本机,可以直接下载。

3、导入数据集

下载完成后,导入数据,使用 data.DataLoader 函数

train_loader = data.DataLoader(train_data,batch_size=128,shuffle=True)
test_loader = data.DataLoader(test_data,batch_size=100,shuffle=True)

定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
num_workers=0, collate_fn=default_collate, pin_memory=False,
drop_last=False)

dataset:加载的数据集(Dataset对象) 
batch_size:每个batch有多少个样本,就相当于与把 数据集分成了N份,每份有这么多,进行并行处理,加快计算速度。
shuffle::是否将数据打乱 
sampler: 样本抽样,后续会详细介绍 
num_workers:使用多进程加载的进程数,0代表不使用多进程 
collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可 
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些 
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

4、定义网络结构

输出图形边框:output_shape = (image_shape-filter_shape+2*padding)/stride + 1

# 定义网络结构
class CNNnet(torch.nn.Module):
    def __init__(self):
        super(CNNnet,self).__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, #输入通道数
                            out_channels=16,# 输出 通道数,一般为16证书被,按计算机计算特性设计
                            kernel_size=3,# 卷积核大小 3*3
                            stride=2, # 步长
                            padding=1),# 边框 0 填充数量
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU()
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(16,32,3,2,1),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU()
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv2d(32,64,3,2,1),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU()
        )
        self.conv4 = torch.nn.Sequential(
            torch.nn.Conv2d(64,64,2,2,0),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU()
        )
        self.mlp1 = torch.nn.Linear(2*2*64,100)
        self.mlp2 = torch.nn.Linear(100,10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.mlp1(x.view(x.size(0),-1))
        x = self.mlp2(x)
        return x
model = CNNnet()
print(model)

输出结构如下:

CNNnet(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (conv4): Sequential(
    (0): Conv2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (mlp1): Linear(in_features=256, out_features=100, bias=True)
  (mlp2): Linear(in_features=100, out_features=10, bias=True)
)

5、定义损失器和优化器

算法:torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)[source] 梯度算法

params:待优化参数

lr: 学习率,默认为 0.001

batas:用于计算梯度以及梯度平方的运行平均值的系数(默认:0.9,0.999)

eps:为了增加数值计算的稳定性而加到分母里的项 默认:1e-8

weight_decay :权重衰减(L2惩罚) 默认为0

loss_func = torch.nn.CrossEntropyLoss()# Pytorch常用的交叉熵损失函数
opt = torch.optim.Adam(model.parameters(),lr=0.001) 

parameters():构建好神经网络后,网络的参数都保存在parameters()函数当中

使用损失函数的步骤:

  • 获取损失:loss = loss_func(out,batch_y)
  • 清空上一步残余更新参数:opt.zero_grad()
  • 误差反向传播:loss.backward()
  • 将参数更新值施加到netparmeters上:opt.step()

6、训练

注意,此处有一个模型保存的步骤,torch.save()


# 进行训练

loss_count = []
for epoch in range(2):
    for i,(x,y) in enumerate(train_loader):
        batch_x = Variable(x) # torch.Size([128, 1, 28, 28])
        batch_y = Variable(y) # torch.Size([128])
        # 获取最后输出
        out = model(batch_x) # torch.Size([128,10])
        # 获取损失
        loss = loss_func(out,batch_y)
        # 使用优化器优化损失
        opt.zero_grad()  # 清空上一步残余更新参数值
        loss.backward() # 误差反向传播,计算参数更新值
        opt.step() # 将参数更新值施加到net的parmeters上 梯度下降进一步?
        if i%20 == 0:
            loss_count.append(loss)
            print('{}:t'.format(i), loss.item())
            torch.save(model,r'G:/vs/numbe_select/log_CNN/model.txt')
        if i % 100 == 0:
            for a,b in test_loader:
                test_x = Variable(a)
                test_y = Variable(b)
                out = model(test_x)
                # print('test_out:t',torch.max(out,1)[1])
                # print('test_y:t',test_y)
                accuracy = torch.max(out,1)[1].numpy() == test_y.numpy()
                print('accuracy:t',accuracy.mean()) # 求均值,一定程度上反映了结果的准确度。
                break
plt.figure('PyTorch_CNN_Loss')
plt.plot(loss_count,label='Loss')
plt.legend()
plt.show()

结果如下 :

损失图

7、测试

# 测试网络
model = torch.load(r'C:UserslievDesktopmyprojectyin_testlog_CNN')
# 加载前面训练好的模型,进行测试
accuracy_sum = []
for i,(test_x,test_y) in enumerate(test_loader):
    test_x = Variable(test_x)
    test_y = Variable(test_y)
    out = model(test_x)
    # print('test_out:t',torch.max(out,1)[1])
    # print('test_y:t',test_y)
    accuracy = torch.max(out,1)[1].numpy() == test_y.numpy()
    accuracy_sum.append(accuracy.mean())
    print('accuracy:t',accuracy.mean())

print('总准确率:t',sum(accuracy_sum)/len(accuracy_sum))
# 精确率图
print('总准确率:t',sum(accuracy_sum)/len(accuracy_sum))
plt.figure('Accuracy')
plt.plot(accuracy_sum,'o',label='accuracy')
plt.title('Pytorch_CNN_Accuracy')
plt.legend()
plt.show()

8、优化

1、全连接第一层增加ReLU激活函数:提高了0.02

accuracy:	 0.98
accuracy:	 0.99
accuracy:	 0.98
accuracy:	 0.99
总准确率:	 0.9872999999999992

2、去掉批量归一化:降低了0.01

accuracy:	 0.97
accuracy:	 0.97
accuracy:	 0.92
总准确率:	 0.9746999999999996

3、使用LeakyReLU激活函数:降低0.01

accuracy:	 0.97
accuracy:	 0.98
accuracy:	 1.0
总准确率:	 0.9848999999999997

4、使用PReLU激活函数:提升0.01

accuracy:	 0.97
accuracy:	 1.0
accuracy:	 1.0
accuracy:	 0.97
总准确率:	 0.9867999999999998

转载自:https://blog.csdn.net/qq_34714751/article/details/85610966

未经允许不得转载:书生吴小帅 » pytouch 学习(2):使用 pytouch 实现 CNN 手写字体识别

赞 (12)