tf.transpose()为转置函数,其中参数perm用来设置需要转置的维度和顺序
img = np.array([
[[1, 2, 3], [4, 5, 6]],
[[7, 8, 9], [10, 11, 12]]
])
# img = img[np.newaxis, :]
l1 = tf.convert_to_tensor(img)
l2 = tf.contrib.layers.flatten(l1)
l3 = tf.transpose(l1, (1, 0, 2))
l4=tf.contrib.layers.flatten(l3)
with tf.Session() as sess:
out = sess.run(l4)
print out, out.shape
img是一个2*2*3 (row*col*channel)的图像矩阵,在内存中的存储顺序为:channel=>col=>row,即从shape的最后一个维度往前开始存储,对应的perm为(0,1,2)
如果进行l3 = tf.transpose(l1, (0, 1, 2))则矩阵不变
如果进行l3 = tf.transpose(l1, (1, 0, 2))则对row和col进行转置,转置后,内存中的存储顺序改为:channel=>row=>col,shape=(2,2,3)
如果进行l3 = tf.transpose(l1, (2, 0, 1))则对先对row和col进行转置,再对col和channel进行转置,内存中的存储顺序改为:col=>row=>channel,shape=(3,2,2)