python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch torch.nn.Identity()的作用

pytorch之torch.nn.Identity()的作用及解释

作者:会写代码的孙悟空

这篇文章主要介绍了pytorch之torch.nn.Identity()的作用及解释,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

torch.nn.Identity()的作用及解释

class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.
    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)
    Examples::
        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])
    """
    def __init__(self, *args, **kwargs):
        super(Identity, self).__init__()
    def forward(self, input: Tensor) -> Tensor:
        return input

通过阅读源码可以看到,identity模块不改变输入。

直接return input

一种编码技巧吧,比如我们要加深网络,有些层是不改变输入数据的维度的,

在增减网络的过程中我们就可以用identity占个位置,这样网络整体层数永远不变,

看起来可能舒服一些,

可能理解的不到位。。。。

Pytorch-torch.nn.identity()方法

identity模块不改变输入,直接return input

一种编码技巧吧,比如我们要加深网络,有些层是不改变输入数据的维度的,在增减网络的过程中我们就可以用identity占个位置,这样网络整体层数永远不变,

应用:

例如此时:如果此时我们使用了se_layer,那么就SELayer(dim),否则就输入什么就输出什么(什么都不做)

m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
input = torch.randn(128, 20)
output = m(input)
print(output.size()) >> torch.Size([128, 20])

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

您可能感兴趣的文章:
阅读全文