TensorFlow中的多维tensor运算(tf.tensordot)

本文会重点介绍关于TensorFlow中的tf.tensordot函数,但是在详细介绍这一函数之前,还会对其他矩阵乘法相关的函数进行简要说明。

1. tf.multiply
tf.multiply的操作等同于*的操作,即计算两个矩阵的按元素乘法。也就是求两个矩阵的哈达玛积(Hadamard product)。

1
2
3
4
5
a = tf.constant([1, 2, 3, 4, 5, 6], shape=(2, 3))
b = tf.constant([1, 2, 3, 4, 5, 6], shape=(2, 3))
c = tf.multiply(a, b)
print(c)
print(c == (a * b))

result1
2. tf.matmul
tf.matmul即是标准的矩阵运算函数,其要求参与运算的两个矩阵必须满足特定的行列关系。

1
2
3
4
a = tf.constant([1, 2, 3, 4, 5, 6], shape=(2, 3))
b = tf.constant([1, 2, 3, 4, 5, 6], shape=(3, 2))
c = tf.matmul(a, b)
print(c)

result2
3. tf.tensordot

函数参数

  • a: float32或float64类型的Tensor

  • b: 与a相同类型的Tensor

  • axes: 该参数用来表明a、b张量沿哪些轴进行收缩(收缩成一个轴)。该参数既可以是一个整数N,也可以是两个列表来指明要进行收缩的轴。

  • name: 操作的名称(可选)

接下来通过具体的例子来说明tensordot的操作,尤其是参数axes的含义。

1
2
3
4
5
6
7
a = tf.constant([1, 2, 3, 4, 5, 6], shape=(2, 3))
b = tf.constant([1, 2, 3, 4, 5, 6], shape=(3, 2))
# axes=1就表明,对a的最后一维和b的第一维进行收缩,此时就和标准的矩阵乘法是一样的。
c = tf.tensordot(a, b, axes=1)
# 当axes传入的是列表时,列表中的值将作为索引,分别指定两个张量按照哪些轴进行收缩。因此axes=[[1], [0]]表示按照a的第1号轴和按照b的第0号轴进行收缩。此时也是和标准的矩阵乘法是一样的。
d = tf.tensordot(a, b, axes=[[1], [0]])
print(c)

接下来是复杂一些的例子。

1
2
3
4
5
6
7
8
9
10
11
12
13
a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 
11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24], shape=[2, 3, 4])

b = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], shape=[4, 3])

c = tf.tensordot(a, b, axes=2)
# 对a的后两个轴进行Flatten,此时a.shape=(2, 12)
# 对b的前两个轴进行Flatten,此时b.shape=(12,)
# 因此最终c.shape=(2,)
d = tf.tensordot(a, b, axes=[[1, 2], [0, 1]])
# 该结果与axes=2是一样的
print(c)

result3

最后需要强调的一点就是,不管两个按照哪些轴进行收缩,都必须保证两个张量按照指定轴收缩后的纬度值是一样的。