PyTorch零基础入门之逻辑斯蒂回归
作者:山顶夕景
PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序,它是一个可续计算包,提供两个高级功能:1、具有强大的GPU加速的张量计算(如NumPy)。2、包含自动求导系统的深度神经网络
学习总结
(1)和上一讲的模型训练是类似的,只是在线性模型的基础上加个sigmoid,然后loss函数改为交叉熵BCE函数(当然也可以用其他函数),另外一开始的数据y_data也从数值改为类别0和1(本例为二分类,注意x_data
和y_data
这里也是矩阵的形式)。
一、sigmoid函数
logistic function是一种sigmoid函数(还有其他sigmoid函数),但由于使用过于广泛,pytorch默认logistic function叫为sigmoid函数。还有如下的各种sigmoid函数:
二、和Linear的区别
逻辑斯蒂和线性模型的unit区别如下图:
sigmoid
函数是不需要参数的,所以不用对其初始化(直接调用nn.functional.sigmoid
即可)。
另外loss函数从MSE改用交叉熵BCE:尽可能和真实分类贴近。
如下图右方表格所示,当 y ^ \hat{y} y^越接近y时则BCE Loss值越小。
三、逻辑斯蒂回归(分类)PyTorch实现
# -*- coding: utf-8 -*- """ Created on Mon Oct 18 08:35:00 2021 @author: 86493 """ import torch import torch.nn as nn import matplotlib.pyplot as plt import torch.nn.functional as F import numpy as np # 准备数据 x_data = torch.Tensor([[1.0], [2.0], [3.0]]) y_data = torch.Tensor([[0], [0], [1]]) losslst = [] class LogisticRegressionModel(nn.Module): def __init__(self): super(LogisticRegressionModel, self).__init__() self.linear = torch.nn.Linear(1, 1) def forward(self, x): # 和线性模型的网络的唯一区别在这句,多了F.sigmoid y_predict = F.sigmoid(self.linear(x)) return y_predict model = LogisticRegressionModel() # 使用交叉熵作损失函数 criterion = torch.nn.BCELoss(size_average = False) optimizer = torch.optim.SGD(model.parameters(), lr = 0.01) # 训练 for epoch in range(1000): y_predict = model(x_data) loss = criterion(y_predict, y_data) # 打印loss对象会自动调用__str__ print(epoch, loss.item()) losslst.append(loss.item()) # 梯度清零后反向传播 optimizer.zero_grad() loss.backward() optimizer.step() # 画图 plt.plot(range(1000), losslst) plt.ylabel('Loss') plt.xlabel('epoch') plt.show() # test # 每周学习的时间,200个点 x = np.linspace(0, 10, 200) x_t = torch.Tensor(x).view((200, 1)) y_t = model(x_t) y = y_t.data.numpy() plt.plot(x, y) # 画 probability of pass = 0.5的红色横线 plt.plot([0, 10], [0.5, 0.5], c = 'r') plt.xlabel('Hours') plt.ylabel('Probability of Pass') plt.grid() plt.show()
可以看出处于通过和不通过的分界线是Hours=2.5。
Reference
到此这篇关于PyTorch零基础入门之逻辑斯蒂回归的文章就介绍到这了,更多相关PyTorch 逻辑斯蒂回归内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!