python

关注公众号 jb51net

关闭
首页 > 脚本专栏 > python > NumPy np.where() 用法

详解NumPy中np.where() 的两种神奇用法

作者:司徒轩宇

np.where()是 NumPy 中用于条件选择和元素定位的核心函数,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

在数据科学和数值计算的世界里,NumPy 就像是一把瑞士军刀,而 np.where() 无疑是其中最锋利的工具之一。今天,我们将深入探索这个功能强大的函数,学会如何用它优雅地处理条件逻辑和数据选择。

什么是 np.where()?

简单来说,np.where() 是 NumPy 中用于条件选择和元素定位的核心函数。它有两种主要用法:

让我们通过实例来探索这两种用法!

用法一:三元条件替换(条件 ? 值1 : 值2)

这是 np.where() 最常用的形式,语法为:np.where(condition, x, y)

基础示例

import numpy as np

# 创建示例数组
temperatures = np.array([22, 28, 15, 32, 18, 25])

# 标记高温和低温
result = np.where(temperatures > 25, "高温", "舒适")
print(result)
# 输出:['舒适' '高温' '舒适' '高温' '舒适' '舒适']

实际应用:成绩分类

scores = np.array([75, 92, 58, 81, 45, 67, 88])

# 根据分数分类
grade = np.where(scores >= 90, "A",
          np.where(scores >= 80, "B",
            np.where(scores >= 70, "C",
              np.where(scores >= 60, "D", "F"))))

print(grade)
# 输出:['C' 'A' 'F' 'B' 'F' 'D' 'B']

多条件组合

data = np.array([12, 25, 7, 18, 30, 5, 22])

# 组合条件:大于10且小于20
result = np.where((data > 10) & (data < 20), data, 0)
print(result)  # 输出:[12  0  0 18  0  0  0]

# 使用 | 表示 OR 条件
result = np.where((data < 10) | (data > 20), data, -1)
print(result)  # 输出:[ -1  25   7  -1  30   5  22]

用法二:定位元素索引

当我们只提供条件参数时,np.where() 会返回满足条件元素的索引。

语法:np.where(condition)

一维数组示例

arr = np.array([0, 5, 0, 8, 0, 3, 0])

# 找到非零元素的索引
non_zero_indices = np.where(arr != 0)
print(non_zero_indices)  # 输出:(array([1, 3, 5]),)

# 提取非零值
print(arr[non_zero_indices])  # 输出:[5 8 3]

二维数组示例

matrix = np.array([[1, 0, 4],
                   [0, 5, 0],
                   [7, 0, 9]])

# 找到值大于3的元素位置
rows, cols = np.where(matrix > 3)

print("行索引:", rows)    # 输出:[0 1 2 2]
print("列索引:", cols)    # 输出:[2 1 0 2]
print("对应值:", matrix[rows, cols])  # 输出:[4 5 7 9]

实际应用:图像处理

# 创建一个简单的图像矩阵 (5x5)
image = np.array([[120, 130, 40, 200, 210],
                  [30, 145, 255, 180, 10],
                  [220, 25, 30, 190, 200],
                  [100, 110, 120, 130, 140],
                  [50, 60, 70, 80, 90]])

# 找到高光区域(值>200)
highlight_rows, highlight_cols = np.where(image > 200)

print("高光像素位置:")
for r, c in zip(highlight_rows, highlight_cols):
    print(f"({r}, {c}) - 值: {image[r, c]}")

# 输出:
# (0, 3) - 值: 200
# (0, 4) - 值: 210
# (1, 2) - 值: 255
# (2, 0) - 值: 220
# (2, 3) - 值: 190 -> 注意:190不大于200,实际应为 (2, 4): 200
# 更正:矩阵中 (2,4) 是200,所以应包含

进阶技巧与注意事项

1. 广播机制

np.where() 支持 NumPy 的广播机制,使不同形状的数组能够一起工作:

# 二维条件与一维值组合
condition_2d = np.array([[True, False], [False, True]])
result = np.where(condition_2d, [10, 20], 0)

print(result)
# 输出:
# [[10  0]
#  [ 0 20]]

2. 直接修改满足条件的值

data = np.array([5, 12, 8, 15, 3, 10])

# 将小于10的值替换为0
data[np.where(data < 10)] = 0
print(data)  # 输出:[ 0 12  0 15  0 10]

3. 多维度索引

对于三维或更高维数组,np.where() 同样适用:

# 创建3x3x3数组
cube = np.random.randint(0, 10, (3, 3, 3))

# 找到所有大于8的元素
indices = np.where(cube > 8)

# 输出三维索引
print("维度0:", indices[0])
print("维度1:", indices[1])
print("维度2:", indices[2])

# 访问这些元素
print("满足条件的值:", cube[indices])

性能优势

与 Python 循环相比,np.where() 有显著的性能优势:

import time

large_array = np.random.rand(10**6)

# 使用循环
start = time.time()
result_loop = [x*2 if x > 0.5 else x/2 for x in large_array]
print("循环耗时:", time.time() - start)

# 使用 np.where
start = time.time()
result_np = np.where(large_array > 0.5, large_array*2, large_array/2)
print("np.where耗时:", time.time() - start)

测试结果(可能因机器而异):

循环耗时: 0.45秒
np.where耗时: 0.02秒

到此这篇关于详解NumPy中np.where() 的两种神奇用法的文章就介绍到这了,更多相关NumPy np.where() 用法内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

您可能感兴趣的文章:
阅读全文