np.where在多维数组的应用方式
作者:Alex-Leung
这篇文章主要介绍了np.where在多维数组的应用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
np.where在多维数组的应用
函数用途
返回查找的参数,在数组中的索引。
Code
举例:
- 一般卷积神经网络的输入或者输出为一个四维的数组/Tensor。
- 一般为[batch_size, channel, height, width]
下面代码目标是输出所有值为0的数字的索引。
output = [[ [[1, 0, 2], [2, 1, 0], [1, 0, 0]] ]] arr = np.array(output) print(arr.shape) res = np.where(arr==0) print(res)
Output
# print(arr.shape) (1, 1, 3, 3) # print(res) (array([0, 0, 0, 0], dtype=int64), array([0, 0, 0, 0], dtype=int64), array([0, 1, 2, 2], dtype=int64), array([1, 2, 1, 2], dtype=int64))
np.where的输出结果为一个list,里面包含4个ndarray,分别代表四维。
[0, 0, 0, 0] # axis=0 [0, 0, 0, 0] # axis=1 [0, 1, 2, 2] # axis=2 [1, 2, 1, 2] # axis=3
正确读值,从列来看,四个0值的索引分别是
print(arr[0][0][0][1]) # output:0 print(arr[0][0][1][2]) # output:0 print(arr[0][0][2][1]) # output:0 print(arr[0][0][2][2]) # output:0
np.where()用法解析
语法说明
np.where(condition,x,y)
- 当where内有三个参数时,第一个参数表示条件,当条件成立时where方法返回x,当条件不成立时where返回y
np.where(condition)
- 当where内只有一个参数时,那个参数表示条件,当条件成立时,where返回的是每个符合condition条件元素的坐标,返回的是以元组的形式,坐标以tuple的形式给出,通常原数组有多少维,输出的tuple中就包含几个数组,分别对应符合条件元素的各维坐标。
多条件condition
- -&表示与,|表示或。
- 如a = np.where((a>0)&(a<5), x, y),当a>0与a<5满足时,返回x的值,当a>0与a<5不满足时,返回y的值。
- 注意:x, y必须和a保持相同维度,数组的数值才能一一对应。
示例
(1)一个参数
import numpy as np a = np.arange(0, 100, 10) b = np.where(a < 50) c = np.where(a >= 50)[0] print(a) print(b) print(c)
结果如下:
[ 0 10 20 30 40 50 60 70 80 90]
(array([0, 1, 2, 3, 4]),)
[5 6 7 8 9]
说明:
- b是符合小于50条件的元素位置,b的数据类型是tuple
- c是符合大于等于50条件的元素位置,c的数据类型是numpy.ndarray
(2)三个参数
a = np.arange(10) b = np.arange(0,100,10) print(np.where(a > 5, 1, -1)) print(b) print(np.where((a>3) & (a<8),a,b)) c=np.where((a<3) | (a>8),a,b) print(c)
结果如下:
[-1 -1 -1 -1 -1 -1 1 1 1 1]
[ 0 10 20 30 40 50 60 70 80 90]
[ 0 10 20 30 4 5 6 7 80 90]
[ 0 1 2 30 40 50 60 70 80 9]
说明:
- np.where(a > 5, 1, -1) ,满足条件是1,不满足是-1
- np.where((a>3) & (a<8),a,b),满足条件是a ,不满足是b ,a和b的维度相同
注意:
& | 与和或,每个条件一定要用括号,否则报错
c=np.where((a<3 | a>8),a,b)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。