首页

源码搜藏网

首页 > 开发教程 > ajax教程 >

pytorch中的广播语义

创建时间:2022-03-09 19:25  

pytorch的广播语义(broadcasting semantics),和numpy的很像,所以可以先看看numpy的文档

1、什么是广播语义?

官方文档有这样一个解释:

In short, if a PyTorch operation supports broadcast, then its Tensor arguments can be automatically expanded to be of equal sizes (without making copies of the data).

这句话的意思大概是:简单的说,如果一个pytorch操作支持广播,那么它的Tensor参数可以自动的扩展为相同的尺寸(不需要复制数据)。

按照我的理解,应该是指算法计算过程中,不同的Tensor如果size不同,但是符合一定的规则,那么可以自动的进行维度扩展,来实现Tensor的计算。在维度扩展的过程中,并不是真的把维度小的Tensor复制为和维度大的Tensor相同,因为这样太浪费内存了。

2、广播语义的规则

首先来看标准的情况,两个Tensor的size相同,则可以直接计算:

x = torch.empty((4, 2, 3))
y = torch.empty((4, 2, 3))
print((x+y).size())

输出:

torch.Size([4, 2, 3])

但是,如果两个Tensor的维度并不相同,pytorch也是可以根据下面的两个法则进行计算:

第一个规则要求每个参与计算的Tensor至少有一个维度,第二个规则是指在维度迭代时,从最后一个维度开始,可以有三种情况:

3、不符合广播语义的例子

x = torch.empty((0, ))
y = torch.empty((2, 3))
print((x + y).size())

输出:

RuntimeError: The size of tensor a (0) must match the size of tensor b (3) at non-singleton dimension 1

这里,不满足第一个规则“每个参与计算的Tensor至少有一个维度”。

x = torch.empty(5, 2, 4, 1)
y = torch.empty(3, 1, 1)
print((x + y).size())

输出:

RuntimeError: The size of tensor a (2) must match
the size of tensor b (3) at non-singleton dimension 1

这里,不满足第二个规则,因为从最后的维度开始迭代的过程中,倒数第三个维度:x是2,y是3。这并不符合第二条规则的三种情况,所以不能使用广播语义。

4、符合广播语义的例子

x = torch.empty(5, 3, 4, 1)
y = torch.empty(3, 1, 1)
print((x + y).size())

输出:

torch.Size([5, 3, 4, 1])

x是四维的,y是三维的,从最后一个维度开始迭代:

到此这篇关于pytorch中的广播语义的文章就介绍到这了,更多相关pytorch广播语义内容请搜索源码搜藏网以前的文章或继续浏览下面的相关文章希望大家以后多多支持源码搜藏网!

上一篇:python数组排序办法之sort、sorted和argsort详细介绍
下一篇:Python的ini配置文件你知道吗

相关内容

热门推荐