python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > pytorch tensor合并与分割

pytorch tensor合并与分割方式

作者:wyw0000

这篇文章主要介绍了pytorch tensor合并与分割方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教

1. cat

torch.cat(tensors, dim=0, *, out=None) → Tensor

在指定维度上,连接给定tensor序列或empty,除连接的dimension外,所有得的ensor必须有相同的shape

参数:

输出:

连接后的tensor

上图分别是在列和行两个维度连接后的结果

2. stack

创建新维度来连接张量序列

torch.stack(tensors, dim=0, *, out=None) → Tensor

参数:

输出:

连接后的tensor

注意:cat和stack的区别

stack连接的tensor必须具有相同的size,否则报错,cat是除连接的维度外,其他维度shape必须相同

如下示例:

3. split

把一个tensor切分成块,每个块是原tensor的一部分

torch.split(tensor, split_size_or_sections, dim=0)

参数:

输出:

Tuple[Tensor, …]

示例:

4. chunk

强制将一个tensor切分成指定数量的块,每个块是原tensor的一部分

torch.chunk(input, chunks, dim=0) → List of Tensors

参数:

输出:

切分后的list

示例:

注意:split与chunk的区别

区别主要是第二个参数,split第二个参数切分块的size,而chunk是切分块的数量

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。 

您可能感兴趣的文章:
阅读全文