python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > Pytorch知识点

Pytorch相关知识介绍与应用

作者:明湖小虾

最近又重拾了机器学习的相关技术,在本科毕设的阶段下学习使用了Tensorflow 2.x工具,当时也是不求甚解,直接拿来用了,但现在已经有充足的时间、精力和基础知识来重新学习一下

前言

目前机器学习框架有两大方向,Pytorch和Tensorflow 2。对于机器学习的小白的我来说,直观的感受是Tensorflow的框架更加傻瓜式,在这个框架下只需要定义神经网络的结构、输入和输出,然后直接使用其框架下的各种框架函数即可。而对于Pytorch来说,则使用者能操作、定义的细节更多,但与此同时使用难度也会更高。

通过各种资料也显示,在学术研究范围内,越来越多的人使用Pytorch,其实Tensorflow也不错,但对于普通小白来说入手更快,应用也更快。本着全面发展,多尝试的心态,开始Pytorch学习。

小编将从自身的理解习惯开始不断更新这篇博文:

1.Pytorch简介

Pytorch就是一个神经网络框架,使用Pytorch可以跳过很多不必要的底层工作,很多通用的方法、数据结构都已经实现供我们调用,从而可以让我们将精力集中在改进数据质量、网络结构和评估方法上去。

使用和训练神经网络从思考顺序上来说无非就三个阶段:

1)构思神经网络的输入、输出和网络结构,其中输入输出非常关键。

2)训练数据集(粗糙的原始数据)。

3)如何将训练数据集转换成神经网络能够接受并且能够正确输出的结构。

4)训练神经网络并进行预测。

2.Pytorch定义神经网络的输入输出和结构

使用Pytorch定义神经网络非常通用的格式:

class NN(nn.Module):
    def __init__(self):
        super(NN,self).__init__()#继承tOrch中已经写好的类,包含神经网络其余所有通用必要方法函数。
        self.flatten=nn.Flatten()#加入展平函数。
        self.net=nn.Sequential(#调用Sequential方法定义神经网络。
            nn.Linear(100*3,100*3),
            nn.ReLU(),
            nn.Linear(100*3,100*3),
            nn.ReLU(),
            nn.Linear(100*3,27)
        )
    def forward(self,x):#自定义神经网络的前向传播函数,本文使用了正常的前向传播函数,但最终的结果给出三个输出。
        result=self.net(x)
        r1=result[:9]
        r2=result[9:18]
        r3=result[18:27]    
        return [r1,r2,r3]
 

到这里基本上已经定义了自己的神经网络了,输入为100*3=200个数据、输出为27个数据。那么问题来了,怎么把数据输入进去呢?

3.Pytorch神经网络的数据格式-tensor

对于编程小白、机器学习小白的我或者大家来说,tensor的直接定义不好理解。

tensor表面上只进行了存储,但实际上它包含了很多中方法,直接使用tensor.Method()调用相关方法即可,而省去了自己来定义函数,再操作数据结构。并且在Pytorch进行训练时,也会在其内部调用这些方法,所以就需要我们使用这些数据结构来作为Pytorch神经网的输入,并且神经网络的输出也是tensor形式,numpy array 和 list 和 tensor 的转换其实就是数据相同,但集成了不同方法的数据结构。

那么下面就是输入数据的定义。train_data和labels都是我们使用python方法写出的list。

#train_data、labels都是list,经过list->ndarray->tensor的转换过程。
train_data=torch.tensor(np.array(train_data)).to(torch.float32).to(device)
labels=torch.tensor(np.array(labels)).to(torch.float32).to(device)

4.神经网络进行预测

使用神经网络进行预测(前向传播)、计算损失函数、反向传播更新梯度

1)进行前向传播

#train_data[0]即为训练数据的第一条输入数据。
prediction=model(train_data[0])

2)计算损失

#定义优化器
optim=torch.optim.SGD(model.parameters(),lr=1e-2,momentum=0.9)
# 定义自己的loss 
loss=(prediction[0]-labels[0]).sum()+(prediction[1]-labels[1]).sum()+(prediction[2]-labels[2]).sum()
#反向传播
optim.zero_grad()#清除上一次的静态梯度,防止累加。
loss.backward()#计算反向传播梯度。
optim.step()#进行一次权值更新。
 

此处的计算损失和权值依据输入数据更新一次的结果,由此加入一个循环,便可以实现神经网络的训练过程。

3) 训练网络

在正式进入训练网络之前,我们还需要了解一个叫做Batch的东西,如果我们将数据一个一个送进去训练,那么神经网络训练的速度将是十分缓慢的,因此我们每次可以丢进去很多数据让神经网络进行预测,通过计算总体的损失就可以让梯度更快地下降。但训练数据有时又很巨大,所以就需要将整个训练数据打包成一批一批的进入训练,并重复若干次,每训练整个数据一次,会经历若干个batch,这一过程称为一个epoch。

所以为了使网络预测结果更快地收敛,即更快地训练神经网络,我们需要首先对数据进行打包。

import torch.utils.data as Data
bath=50#每批次大小
loader=Data.DataLoader(#制作数据集,只能由cpu读取
    dataset = train_data_set,
    batch_size=bath,#每批次包含数据条数
    shuffle=True,#是否打乱数据
    num_workers=1,#多少个线程搬运数据
)

然后,我们就可以进行神经网络的训练了:

pstep=2#每个多少个批次就输出一次结果
for epoch in range(1000):
    running_loss=0.0
    for step,(inps,labs) in enumerate(loader):
        #取出数据并搬运至GPU进行计算。
        labs=labs.to(device)
        inps=inps.to(device) 
        outputs = model(inps)#将数据输入进去并进行前向传播
        loss=loss_fn(outputs,labs)#损失函数的定义
        optimizer.zero_grad()#清楚上一次的静态梯度,防止累加。
        loss.backward()#反向传播更新梯度
        optimizer.step()#进行一次优化。
        running_loss += loss.item()#不加item()会造成内存堆叠
        size=len(labs)*3
        correct=0
        #print("outputs:",outputs.argmax(-1),"\nlabs:",labs.argmax(-1))
        #逐个判断计算准确率
        correct+=(outputs.argmax(-1)==labs.argmax(-1)).type(torch.float).sum().item()
        if step % pstep == pstep-1:    # print every 10 mini-batches
            print('[%d, %5d] loss: %.3f correct:%.3f' %
                  (epoch + 1, step + 1, running_loss / pstep,correct/size))
            if correct/size>1:#错误检查
                print("outputs:",outputs.argmax(-1),"\nlabs:",labs.argmax(-1),"\ncorrect:",correct,"\nSize:",size)
            running_loss = 0.0
#保存模型
torch.save(model.state_dict(), "model.pth")
print("Saved PyTorch Model State to model.pth")

到此这篇关于Pytorch相关知识介绍与应用的文章就介绍到这了,更多相关Pytorch知识点内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文