Yuhang He's Blog


TensorFlow: get_shape() vs tf.shape()

It is worth delving into the nuts and bolts of get_shape() and tf.shape(): both can be used to get the shape of a tensor, but careless use of either easily leads to errors. Since both of them operate on a Tensor, it helps to first have a brief understanding of TensorFlow’s Tensor.

What is a Tensor in TensorFlow?

You can think of TensorFlow as dealing with two things: Tensors and Operators. In a nutshell, a Tensor can be treated as an N-dimensional array that stores data of a single type, such as tf.int32 or tf.float32. A tensor has three attributes:

  • name, a string that identifies the tensor within the graph.
  • shape, describing the size of each dimension of the tensor.
  • type, showing what kind of data is stored in the tensor.

Actually, a TensorFlow Tensor has two kinds of shape: a static shape and a dynamic shape. The TensorFlow FAQ says: “In TensorFlow, a tensor has both a static (inferred) shape and a dynamic (true) shape. The static shape can be read using the tf.Tensor.get_shape() method; this shape is inferred from the operations that were used to create the tensor, and may be partially complete. If the static shape is not fully defined, the dynamic shape of a Tensor t can be determined by evaluating tf.shape(t).” Three things can be obtained from the FAQ:

  1. tf.Tensor.get_shape() is a member function of the TensorFlow Tensor class. The shape information it returns is inferred at graph construction time and may be incomplete.
  2. tf.shape() is a TensorFlow operator that returns the dynamic shape of a tensor. Once it is evaluated, the returned shape is always fully defined.
  3. If a Tensor’s shape is fully determined at graph construction time, tf.Tensor.get_shape() and the evaluated tf.shape() describe the same shape.
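As a quick check of the third point, here is a minimal sketch that compares the two for a constant tensor. It assumes a TensorFlow version that ships the tf.compat.v1 module (late 1.x or 2.x); the alias tf1 and the variable names are only illustrative and are not part of the original post.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: tf.compat.v1 is available (TF 1.14+ or 2.x)

with tf.Graph().as_default():
    t = tf.constant(1, shape=[10, 10])

    # Static shape: read directly from the Tensor object, no session needed.
    static_shape = t.get_shape().as_list()

    # Dynamic shape: tf.shape() only adds an op; its value exists after a run.
    dynamic_shape_op = tf.shape(t)
    with tf1.Session() as sess:
        dynamic_shape = sess.run(dynamic_shape_op).tolist()

print(static_shape)   # [10, 10]
print(dynamic_shape)  # [10, 10]
```

Note that the two results have different types before conversion: get_shape() gives a TensorShape, while running tf.shape() gives a NumPy array.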

tf.Tensor.get_shape()

First, let’s test some code:

import tensorflow as tf

# A constant has a fully defined shape at graph construction time.
tensor1 = tf.constant(1, shape=[10, 10])
tensor1_static_shape = tensor1.get_shape()
print(type(tensor1_static_shape))  # <class 'tensorflow.python.framework.tensor_shape.TensorShape'>
print(tensor1_static_shape.as_list())  # [10, 10]

Note that we don’t need a tf.Session() to get the static shape. My understanding is that the static shape is inferred and stored on the Tensor object while the graph is being constructed, and tf.Tensor.get_shape() is an ordinary member function that simply reads it back, so no session is needed to call it. Let’s continue with another piece of code:

import tensorflow as tf

# The first dimension is left as None, so the static shape is only partially known.
tensor1 = tf.placeholder(dtype=tf.int32, shape=[None, 10])
tensor1_static_shape = tensor1.get_shape()
print(type(tensor1_static_shape))  # <class 'tensorflow.python.framework.tensor_shape.TensorShape'>
print(tensor1_static_shape.as_list())  # [None, 10]

# The size of a decoded JPEG is unknown until the graph is actually run.
img_raw = tf.gfile.FastGFile('1.jpg', 'rb').read()
img_tmp = tf.image.decode_jpeg(img_raw)
img_shape1 = img_tmp.get_shape()
print(img_shape1)  # (?, ?, ?)
img_shape2 = img_tmp.get_shape().as_list()
print(img_shape2)  # [None, None, None]

We can observe that, when the shape of a tensor cannot be fully inferred during graph construction, the unknown dimensions of its static shape are reported as None by as_list() and printed as ?. This is why tf.Tensor.get_shape() cannot always be relied on to get a tensor’s shape and then do further computation with it. Once the static shape is fully determined, however, tf.Tensor.get_shape() can safely be used to build the subsequent graph without running a session. (One obvious application is a classification task in which the batch size is known in advance.)
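One way to guard against a partially known static shape is to ask the TensorShape whether it is fully defined before using it. The following is only a sketch under the same tf.compat.v1 assumption as before; describe_static_shape is a hypothetical helper introduced for illustration, and the placeholder sizes are arbitrary.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: tf.compat.v1 is available (TF 1.14+ or 2.x)


def describe_static_shape(t):
    """Hypothetical helper: return (dims as a list, whether every dim is known)."""
    shape = t.get_shape()
    return shape.as_list(), shape.is_fully_defined()


with tf.Graph().as_default():
    known = tf1.placeholder(tf.float32, shape=[32, 10])      # batch size fixed in advance
    partial = tf1.placeholder(tf.float32, shape=[None, 10])  # batch size unknown

    known_info = describe_static_shape(known)
    partial_info = describe_static_shape(partial)

    # With a fully defined static shape, the dims are plain Python ints,
    # so they can drive graph construction without any session.
    batch, features = known.get_shape().as_list()
    flat = tf.reshape(known, [batch * features])
    flat_shape = flat.get_shape().as_list()

print(known_info)    # ([32, 10], True)
print(partial_info)  # ([None, 10], False)
print(flat_shape)    # [320]
```

Doing the same arithmetic on the partially defined placeholder would fail, because its first dimension is None rather than an integer.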

tf.shape()

First of all, keep in mind that tf.shape() is a TensorFlow operator: it receives a tensor as input and outputs another tensor. That is, you have to run it in a session in order to get the output tensor’s value. You have to use tf.shape() whenever a tensor’s shape cannot be completely inferred with tf.Tensor.get_shape().
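To make this concrete, here is a small sketch in which the batch dimension is unknown at graph construction time and only becomes available when data is fed in. It again assumes tf.compat.v1 is available; the input of 4 rows is made up for the example.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: tf.compat.v1 is available (TF 1.14+ or 2.x)

with tf.Graph().as_default():
    x = tf1.placeholder(tf.int32, shape=[None, 10])
    static_shape = x.get_shape().as_list()  # batch dimension is None here

    # tf.shape(x) is itself a tensor, so it can be used inside the graph.
    dynamic_shape_op = tf.shape(x)
    batch_size = dynamic_shape_op[0]
    flat = tf.reshape(x, tf.stack([batch_size * 10]))
    flat_len_op = tf.shape(flat)[0]

    with tf1.Session() as sess:
        feed = {x: np.zeros((4, 10), dtype=np.int32)}  # made-up input data
        dynamic_shape, flat_len = sess.run([dynamic_shape_op, flat_len_op],
                                           feed_dict=feed)

print(static_shape)            # [None, 10]
print(dynamic_shape.tolist())  # [4, 10]
print(int(flat_len))           # 40
```

The design point is that batch_size stays symbolic while the graph is built, so the same graph works for any batch size that is fed at run time.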