深入理解numpy广播机制
作者:UQI-LIUWJ
广播(Broadcast)是 numpy 对不同形状(shape)的数组进行数值计算的方式,本文深入理解numpy广播机制,具有一定的参考价值,感兴趣的可以了解一下
1 广播规则
- 如果两个数组的维度数不相同,那么小维度数组的形状会在左边补1。
- 如果两个数组在某个维度上的大小不匹配,并且其中一个数组在该维度上的大小为1,则该数组会沿着这个维度扩展以匹配另一个数组的大小。
- 如果在任何维度上大小都不匹配并且没有一个大小为1,那么会引发错误。
2 举例
2.1 基本广播
import numpy as np a = np.array([1, 2, 3]) b = 2 print(a * b) # =[2 4 6]
在这里,b
被广播到与 a
相同的大小,然后进行乘法。
2.2 维度不同
a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) b = np.array([1, 0, 1]) print(a + b) ''' [[ 2 2 4] [ 5 5 7] [ 8 8 10]] '''
- a的维度是[3,3],b是3
- b的维度首先在左边补1(变成[1,3])
- [[1,0,1]]
- 然后b行复制,变成[3,3]
- [[1,0,1], [1,0,1], [1,0,1]]
- 然后两个[3,3]的矩阵相加即可
2.3 两个数组都需要广播
a = np.array([[1], [2], [3]]) b = np.array([1, 2, 3]) print(a + b) ''' [[2 3 4] [3 4 5] [4 5 6]] '''
在这里,a
的形状是 (3,1),b
的形状是 (3,)。
a
被广播到 (3,3),b
也被广播到 (3,3),然后它们进行加法。
2.4 不兼容的形状
a = np.array([1, 2, 3]) b = np.array([1, 2]) print(a + b) # 这将引发错误,因为形状不兼容
到此这篇关于深入理解numpy广播机制的文章就介绍到这了,更多相关numpy广播内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!