PyTorch之torch.matmul函数的使用及说明
作者:Midsummer-逐梦
PyTorch的torch.matmul是一个强大的矩阵乘法函数,支持不同维度张量的乘法运算,包括广播机制。提供了矩阵乘法的语法,参数说明,以及使用示例,帮助理解其应用方式和乘法规则
一、简介
torch.matmul
用于两维或更高维张量的矩阵乘法操作。
它支持广播机制,并且能够处理不同形状和维度的张量,适用于广泛的应用场景。
二、语法
torch.matmul
函数的基本语法如下:
torch.matmul(input, other, *, out=None)
三、参数
input
:第一个输入张量。other
:第二个输入张量,与input
进行矩阵乘法。out
(可选):存储输出结果的张量。
四、示例
下面通过几个简单的例子来演示 torch.matmul
的用法。
示例 1:二维矩阵乘法
import torch # 创建两个二维张量 a = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([[5, 6], [7, 8]]) # 使用 torch.matmul 进行矩阵乘法 result = torch.matmul(a, b) print(a) print(b) print("二维矩阵乘法结果:") print(result)
输出:
tensor([[1, 2],
[3, 4]])
tensor([[5, 6],
[7, 8]])
二维矩阵乘法结果:
tensor([[19, 22],
[43, 50]])
在这个例子中,torch.matmul
对两个二维张量进行了标准的矩阵乘法。
示例 2:高维张量乘法
import torch # 创建两个高维张量 a = torch.randn(2, 3, 4) b = torch.randn(2, 4, 5) # 使用 torch.matmul 进行高维张量乘法 result = torch.matmul(a, b) print("高维张量乘法结果的形状:") print(result.shape)
输出:
高维张量乘法结果的形状:
torch.Size([2, 3, 5])
在这个例子中,torch.matmul
对两个高维张量进行了矩阵乘法,并且结果张量的形状是 [2, 3, 5]
,符合矩阵乘法的规则。
没有了解过的童鞋可能对这里的乘法规则有所迷惑,因此解释一下:对于高维($\geq$3维度)矩阵乘法,只要保持最后两个维(低二维)的矩阵满足普通矩阵乘法规则,高维的各维度保持相等或对应维度中有一个为1即可。
在这里第2维为(3,4)与(4,5)满足普通矩阵乘法要求,然后高维相等直接对应位置矩阵相乘即可。
下面的广播机制是高维为1的情况,此时会触发广播机制完成高维矩阵的乘法。
示例 3:广播机制
import torch # 创建两个可以广播的张量 a = torch.randn(2, 3, 4) b = torch.randn(4, 5) # 使用 torch.matmul 进行广播机制的矩阵乘法 result = torch.matmul(a, b) print("广播机制下的矩阵乘法结果的形状:") print(result.shape)
输出:
广播机制下的矩阵乘法结果的形状:
torch.Size([2, 3, 5])
在这个例子中,b
张量的形状是 [4, 5]
,通过广播机制,与 a
张量的形状 [2, 3, 4]
进行了兼容,并得到了结果张量的形状 [2, 3, 5]
。
总结
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。