python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch交叉熵损失函数

pytorch中交叉熵损失函数的使用小细节

作者:Mr_health

这篇文章主要介绍了pytorch中交叉熵损失函数的使用细节,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

目前pytorch中的交叉熵损失函数主要分为以下三类,我们将其使用的要点以及场景做一下总结。

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

①当为标准的二分类时,网络的输出节点为1

②当为非互斥的多分类时,分类个数即为网络的输出节点数

类型三:F.binary_cross_entropy()与torch.nn.BCELoss()

网络的输出节点可以为2,此时概率必须由softmax进行映射。

②当为非互斥的多分类时,分类个数即为网络的输出节点数,此时概率必须由sigmoid进行映射

1.二分类

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

类型三:F.binary_cross_entropy()与torch.nn.BCELoss(),以下两种情况都可以使用:

这里我们以网络输出节点为2为例,由于类型二要求网络的输出节点为1,因此暂时不纳入讨论,主要讨论类型和类型三。

测试代码如下:

(网络输出节点为1的二分类就是目前GAN的实现方式,该方式下类型一的函数不可用,只能采用类型二和类型三,后面将会详细讨论)

softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1],
                    [-1.587,  -0.5907]])
classes = 2
label = torch.tensor([1, 1])
logits = torch.from_numpy(logits).float()
 
#F.cross_entropy
loss1 = F.cross_entropy(logits, label)  
print(loss1)
 
#nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
 
#可以看到,loss1是等于loss2的
 
prob = softmax(logits)  #计算概率
one_hot_label = one_hot(label, classes)
 
#F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob, one_hot_label) #输入概率和one-hot
print(loss3)
 
#torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob, one_hot_label)
print(loss4)
 
#同理,loss3是等于loss4的
 
#手动实现二分类的交叉熵损失
shixian = -torch.mean(torch.sum(one_hot_label * torch.log(prob), axis = 1))  #手动实现
print(shixian)

2.多分类

此时网络输出时多节点,每一个节点代表一个类别。

类型一:F.cross_entropy()与torch.nn.CrossEntropyLoss()

类型二:F.binary_cross_entropy_with_logits()与torch.nn.BCEWithLogitsLoss()

类型三:F.binary_cross_entropy()与torch.nn.BCELoss()

这里我们首先讨论下类型一和类型三,为什么类型三不能用于多分类的互斥任务,只能用于多分类多标签的分类任务?我们来看一段代码,这里有三个类别,两个样本。

softmax = torch.nn.Softmax()
logits = np.array([[0.7, -0.1, 0.2],
                    [-1.587,  -0.5907, 0.3]])
classes = 3
label = torch.tensor([1, 2])
logits = torch.from_numpy(logits).float()
 
### F.cross_entropy
loss1 = F.cross_entropy(logits, label)  
print(loss1)
 
### nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss()
loss2 = criterion(logits, label)
print(loss2)
##loss1 = loss2

上面是采用类型一的两个函数计算而来,loss1 = loss2 = 0.9833

然后我们用类型三的函数来实现,同样将logit通过softmax映射为概率,运行后的结果可以看loss3 =loss4 = 0.5649,不等于类型一的函数的结果的。

prob_softmax = softmax(logits)  #计算概率
one_hot_label = one_hot(label, classes)
 
## F.binary_cross_entropy
loss3 = F.binary_cross_entropy(prob_softmax, one_hot_label) #输入概率和one-hot
print(loss3)
 
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss4 = adversarial_loss(prob_softmax, one_hot_label)
print(loss4)

最后我们再手动实现类型三的损失究竟是怎么得到的:

#手动实现
shixian = -torch.mean(one_hot_label * torch.log(prob_softmax) + (1-one_hot_label) * torch.log(1-prob_softmax))
print(shixian)

可以看出来,F.binary_cross_entropy()与torch.nn.BCELoss()是将网络的每个节点看作是一个二分类的节点来计算交叉熵损失的。

进一步来讨论下类型二和类型三的一致性,代码如下。由于类型二中函数自动将logit通过sigloid函数映射为概率,为了检验一致性性,我门也需要通过sigmoid计算类型三所需要的概率。

最后可以看到下面的输出均为0.6378

sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits)  #计算概率
 
##类型二
##F.binary_cross_entropy_with_logits
loss5 = F.binary_cross_entropy_with_logits(logits, one_hot_label)
print(loss5)
 
##torch.nn.BCEWithLogitsLoss()
BCEWithLogitsLoss = torch.nn.BCEWithLogitsLoss()
loss6 = BCEWithLogitsLoss(logits, one_hot_label)
print(loss6)
 
##类型三
##F.binary_cross_entropy
loss7 = F.binary_cross_entropy(prob_sig, one_hot_label) #输入概率和one-hot
print(loss7)
 
## torch.nn.BCELoss()
adversarial_loss = torch.nn.BCELoss()
loss8 = adversarial_loss(prob_sig, one_hot_label)
print(loss8)
 
#手动实现
shixian = -torch.mean(one_hot_label * torch.log(prob_sig) + (1-one_hot_label) * torch.log(1-prob_sig))
print(shixian)

3. GAN中的实现:二分类

GAN中的判别器出的损失就是典型的最小化二分类的交叉熵损失。但是在实现上,与二分类网络不同。

正因为判别器的输出是一维,类型一的两个函数F.cross_entropy()与torch.nn.CrossEntropyLoss()是没有办法使用的,因为这两个函数要求输入是二维的,即分别在real和fake的logit。因此只能采用类型二或者类型三的函数。

很多GAN网络采用的二分类交叉熵损失函数如下:

#类型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss(logit,y)
#类型三:
adversarial_loss_3 = torch.nn.BCELoss(p,y)

前面我们讲到,类型二和类型三的函数都是将每一个节点视为一个二分类的节点,因此对于每一个给节点,其具体的表达式可以写为:

#类型二:
torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
# 其中logit表示判断为real的logit
# y=1表示real
# y=0表示fake
 
#类型三:
torch.nn.BCELoss(p, y) = - (ylog(p) + (1-y)log(1-p))
# 其中p表示判断为real的概率
# y=1表示real
# y=0表示fake

3.1 判别器损失计算

判别器输出维度为1,输出logit,有两个样本,都为fake图像

logits = np.array([1.2, -0.5])
logits = torch.from_numpy(logits).float()
sigmoid = nn.Sigmoid()
prob_sig = sigmoid(logits)  #计算概率
 
label = torch.tensor([1, 1]).float()
 
#类型二:
adversarial_loss_2 = torch.nn.BCEWithLogitsLoss()
loss_2 = adversarial_loss_2(logits, 1-label)  #因为是fake,需要将y设置为0
print(loss_2)
 
#类型三:
adversarial_loss_3 = torch.nn.BCELoss()
loss_3 = adversarial_loss_3(prob_sig, 1-label) #因为是fake,需要将y设置为0
print(loss_3)
#输出均为0.9687

 通过上述代码可以分析如下:

(1)当样本为fake时,网络输出其为real的logit:

(2)样本为real,网络输出其为real的logit:

GAN网络在更新判别器时,代码一般如下:

criterion = torch.nn.BCELoss()
real_out = D(real_img)  # 将真实图片放入判别器中
d_loss_real = criterion(real_out, 1)  # 真实样本的损失
 
fake_img = G(z)  # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img)  # 判别器判断假的图片,
d_loss_fake = criterion(fake_out, 0)  # 生成样本的损失
 
d_loss = d_loss_real + d_loss_fake  #  两个相加 就是标准的交叉熵损失
 
optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()

3.2 生成器的损失计算

前面判别器处的损失是最小化交叉熵损失:

min - (ylog(p) + (1-y)log(1-p))

那么生成器与之相反就是最大化交叉熵损失:

max - (ylog(p) + (1-y)log(1-p))

因为真实样本于与生成器无关,因此可以转变为min log(1-p)

max - ((1-y)log(1-p)) = min (1-y)log(1-p) = min log(1-p)

上述形式为饱和形式,转变为非饱和如下。

min -log(p)

可以看到上式子在形式上就是将fake图像当作real图像进行优化。

可以这么理解:生成器的作用的就是尽可能生成逼近与real的fake,由于判别器判断的结果p就是表示图像为real的概率,那么生成器就希望p越高越好。而在训练判别器时,判别器对real的优化就是让其p越高越好,即尽可能的区分real和fake。

因此在更新生成器时,fake处的损失与更新判别器在real处的损失在逻辑上是一致的。

criterion = torch.nn.BCELoss()
fake_img = G(z)  # 随机噪声放入生成网络中,生成一张假的图片
fake_out = D(fake_img)  # 判别器判断假的图片,
G_loss = criterion(fake_out, 1)  # 假样本的损失
 
 
optimizer_G.zero_grad()
G_loss .backward()
optimizer_G.step()

3.3 小结

在GAN网络中,由于输出网络只有一个节点,表示图像属于real的logit或者prob,因此一般使用类型二和类型三的损失函数。

两类函数的实现如下:

torch.nn.BCEWithLogitsLoss(logit,y) = - (ylog(sigmoid(logit)) + (1-y)log(1-sigmoid(logit)))
torch.nn.BCELoss(p, y) = - (ylog(prob) + (1-y)log(1-prob))

因为上述实现:

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文