numpy.stack()比较难理解。其文档中的一个例子,如何理解axis=2的情况呢?
这个例子是用10 * 3 * 4张量。为了简化,我们先研究2 * 3 * 4。
2 * 3 * 4张量可以表示成如下。np.stack(axis=2)相当于在axis=2即第三维度的元素之间堆叠,图中用线连接起来就表示堆叠
因为图中的线连起来只有两个元素,所以堆叠的结果:
所以结果是一个3 * 4 * 2的张量。
同理可知10 * 3 * 4在stack(axis=2)时结果是3 * 4 * 10。