Yuhang He's Blog

Some birds are not meant to be caged, their feathers are just too bright.

The secret of tf.Print in Tensorflow

Does anyone get frustrated with Python's print() function when trying to print out a tensor value during TensorFlow graph construction? Often we want to inspect intermediate values for debugging purposes, but Python's print() merely prints side information about the tensor, such as its name and shape (so long as it can be accurately inferred from the constructed graph), which is not the information we actually want.

Neither TensorFlow nor print() is to blame. The reason for this awkwardness is simple: TensorFlow splits graph construction and computation apart, so Python's print() only shows a graph node's basic information, such as its shape (if available), name, and data type. Let's dive deeper with one simple example: $ e = d \times (a + b) $. The basic script looks like:

import tensorflow as tf
g = tf.Graph()
with g.as_default():
	a = tf.constant( value = 1.0, dtype = tf.float32, name = 'a' )
	b = tf.constant( value = 2.0, dtype = tf.float32, name = 'b' )
	d = tf.constant( value = 3.0, dtype = tf.float32, name = 'd' )
	c = tf.add( a, b )
	print( c )
	e = tf.multiply( c, d )
with tf.Session( graph = g ) as sess:
	init_op = tf.group( tf.global_variables_initializer(), tf.local_variables_initializer() )
	sess.run( init_op )
	e_val = sess.run( e )
	print( e_val )

The print( c ) statement outputs something like:

Tensor("add:0", shape=(), dtype=float32)

It works exactly as described above: no tensor value is output! Fortunately, TensorFlow already takes care of this issue with tf.Print(). First, let's take a look at its signature:

tf.Print(
	input_,
	data,
	message = None,
	first_n = None,
	summarize = None,
	name = None
)

In brief, it passes through the input_ tensor and, whenever the node executes, prints the tensors in data with message as a prefix. Before moving on to the final, correct snippet, I would like to highlight two rules TensorFlow always obeys:

  • tf.Print() is capable of printing out all the tensors it has access to. Here "all the tensors" means the tensors the data flows through before reaching the tf.Print() operator. In this example, tensors a and b can be accessed by tf.Print().
  • To calculate the requested value, TensorFlow runs, within a session, only the operators that value directly depends on. That is, any "dangling" or "irrelevant" operator will not be executed.

Given these two rules, two things become clear. First, tf.Print() can print any tensor the data flow has covered by the time the tf.Print() operator is reached. Second, instead of leaving tf.Print() as a "dangling" or "irrelevant" operator, we have to insert it at a point in the graph the data must flow through.
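The optional arguments in the constructor above are worth a quick look as well. Here is a minimal sketch, an assumption-laden illustration rather than part of the original example, of a tf.Print() placed on the data path with `summarize` raised above its default of three entries per tensor (the snippet uses the `tf.compat.v1` API so it also runs under TensorFlow 2; on TensorFlow 1.x, plain `import tensorflow as tf` behaves the same):

```python
import tensorflow.compat.v1 as tf  # plain `import tensorflow as tf` on TF 1.x

tf.disable_eager_execution()

g = tf.Graph()
with g.as_default():
	x = tf.constant( [ 1.0, 2.0, 3.0, 4.0, 5.0 ], name = 'x' )
	# summarize=5 prints all five entries instead of the default three;
	# first_n=2 would additionally limit logging to the first two executions.
	x_printed = tf.Print( x, [x], message = 'x = ', summarize = 5 )
	# The printed tensor feeds the downstream op, so the print node
	# lies on the data path and actually executes.
	y = tf.reduce_sum( x_printed )
with tf.Session( graph = g ) as sess:
	y_val = sess.run( y )
	print( y_val )  # 15.0
```

Note that the message goes to standard error, not standard output, which is another reason it is easy to miss in some logging setups.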

With the aforementioned discussion in mind, here is the correct code snippet:

import tensorflow as tf
g = tf.Graph()
with g.as_default():
	a = tf.constant( value = 1.0, dtype = tf.float32, name = 'a' )
	b = tf.constant( value = 2.0, dtype = tf.float32, name = 'b' )
	d = tf.constant( value = 3.0, dtype = tf.float32, name = 'd' )
	c = tf.add( a, b )
	#print( c )
	c_output = tf.Print( c, [ c, b, a ], "begin to print out c, b, a respectively" )
	e = tf.multiply( c_output, d )
with tf.Session( graph = g ) as sess:
	init_op = tf.group( tf.global_variables_initializer(), tf.local_variables_initializer() )
	sess.run( init_op )
	e_val = sess.run( e )
	print( e_val )

The final output of this snippet is:

begin to print out c, b, a respectively[3][2][1]
9.0

The following picture illustrates the graph with and without a "dangling" tf.Print() operator.

tf.Print() graph visualization

Believe it or not, if the graph is constructed with the tf.Print() operator left "dangling", tf.Print() won't print out anything! Thus, we should be careful when utilizing tf.Print() for either debugging or visualization purposes.
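The failure mode can be sketched directly. In the snippet below (again written against the `tf.compat.v1` API, an assumption about your installed version; the variable name c_printed is mine), the tf.Print() node is dangling because its output is never consumed, so running e never executes it and no message appears:

```python
import tensorflow.compat.v1 as tf  # plain `import tensorflow as tf` on TF 1.x

tf.disable_eager_execution()

g = tf.Graph()
with g.as_default():
	a = tf.constant( value = 1.0, dtype = tf.float32, name = 'a' )
	b = tf.constant( value = 2.0, dtype = tf.float32, name = 'b' )
	d = tf.constant( value = 3.0, dtype = tf.float32, name = 'd' )
	c = tf.add( a, b )
	# Dangling: the returned tensor is never used downstream, so this
	# node is not on the path that computes e and is never executed.
	c_printed = tf.Print( c, [ c, b, a ], "you will never see this: " )
	e = tf.multiply( c, d )  # uses c, not c_printed
with tf.Session( graph = g ) as sess:
	e_val = sess.run( e )
	print( e_val )  # 9.0, with no message from the dangling tf.Print
```

Swapping `tf.multiply( c, d )` for `tf.multiply( c_printed, d )` restores the message, exactly as in the correct snippet above.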