You have probably heard of TFRecord, TensorFlow's preferred file format. TensorFlow by Google is everywhere these days; just look at how often it is searched for to get an idea of how demand for it has exploded. I used Google Trends to search for TensorFlow.

Since the release of TensorFlow 2.0, almost everything in TensorFlow has changed drastically. It has become the trend again after being challenged by PyTorch! Its tf.data API in particular is a real treat: it makes building efficient TensorFlow input pipelines easy.


Despite all these great features, one issue remains: working with large datasets that do not fit in memory!

To address this issue and several others, TensorFlow offers the TFRecord format, which brings two very important advantages:

  1. TFRecord stores serialized data in a binary format, which allows the data to be read efficiently. This can significantly speed up the input pipeline and, with it, training.
  2. It is optimized for TensorFlow. That is no surprise, since it is TensorFlow's recommended format.
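The binary layout behind "serialized data in binary format" is simple. As described in TensorFlow's TFRecord guide, each record is stored as a little-endian uint64 length, a masked CRC-32C of that length, the payload bytes, and a masked CRC-32C of the payload. The stdlib-only sketch below writes and reads that framing. The function names (crc32c, masked_crc, write_record, read_records) are hypothetical, and the bitwise CRC is slow, so in practice you would use tf.io.TFRecordWriter and tf.data.TFRecordDataset instead.

```python
import io
import struct

def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ (0x82F63B78 if crc & 1 else 0)
    return crc ^ 0xFFFFFFFF

def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF

def write_record(stream, payload: bytes) -> None:
    header = struct.pack("<Q", len(payload))               # uint64 length
    stream.write(header)
    stream.write(struct.pack("<I", masked_crc(header)))    # uint32 masked CRC of the length
    stream.write(payload)                                  # the serialized data itself
    stream.write(struct.pack("<I", masked_crc(payload)))   # uint32 masked CRC of the data

def read_records(stream):
    while True:
        header = stream.read(8)
        if not header:                                     # clean end of file
            return
        (length_crc,) = struct.unpack("<I", stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError("corrupted length field")
        (length,) = struct.unpack("<Q", header)
        payload = stream.read(length)
        (payload_crc,) = struct.unpack("<I", stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError("corrupted record data")
        yield payload

buf = io.BytesIO()
for payload in [b"first", b"second record"]:
    write_record(buf, payload)
print(len(buf.getvalue()))   # 50: two 16-byte frames plus 5 + 13 payload bytes
buf.seek(0)
records = list(read_records(buf))
print(records)               # [b'first', b'second record']
```

The two checksums let a reader detect a truncated or corrupted file instead of silently returning garbage.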

So why don't we use it more often? Because working with TFRecords is not very user-friendly!

In this tutorial, you will learn:

  1. How to write your data into the TFRecord file format.
  2. How to read a TFRecord file.

In general, these two steps are all you need to work with TFRecord. I have tried to keep things as simple as possible.

Writing Data to a TFRecord File

Let's start by writing to a TFRecord file. The process is simple:

  1. Read MNIST data and pre-process it.
  2. Write MNIST data to a TFRecord file.

NOTE: You may object: why write MNIST to a TFRecord file when MNIST is a small, ready-to-use dataset? The answer is simple: MNIST is used here for educational purposes only. The same approach works for any data!

Reading the Data

Let's first read MNIST. Of the many ways to do it, I find the following one the simplest:

# Loading necessary libraries
import tensorflow as tf
from tensorflow import keras
import numpy as np
# Load MNIST data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Preprocessing
x_train = x_train / 255.0
x_test = x_test / 255.0
# Track the data type
dataType = x_train.dtype
print(f"Data type: {dataType}")
labelType = y_test.dtype
print(f"Label type: {labelType}")


  • I used the tf.keras.datasets.mnist.load_data function to download and read the MNIST data.
  • I divided both the training and test images by 255.0, so the final pixel range becomes [0, 1].
  • I used dataType = x_train.dtype to record the data type. We need it later to read and decode the records from the TFRecord file.
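One detail worth seeing explicitly: dividing the uint8 pixels by 255.0 promotes them to float64, so dataType ends up as float64 and every stored image takes eight bytes per pixel. The NumPy sketch below uses a synthetic 28x28 array (not real MNIST pixels) to show the effect. Casting with .astype(np.float32) before writing would halve the record size, assuming float32 precision is enough for your model and that you then record float32 as the data type.

```python
import numpy as np

# Synthetic stand-in for one MNIST image: load_data() returns uint8 pixels.
raw = (np.arange(28 * 28) % 256).astype(np.uint8).reshape(28, 28)
scaled = raw / 255.0          # true division promotes uint8 to float64

print(raw.dtype, scaled.dtype)                     # uint8 float64
print(len(raw.tobytes()), len(scaled.tobytes()))   # 784 6272
print(scaled.min(), scaled.max())                  # 0.0 1.0
```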

Let’s visualize some samples:

# Visualization
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
# Collect the first 16 training images
im_list = []
n_samples_to_show = 16
for i in range(n_samples_to_show):
    im_list.append(x_train[i])
fig = plt.figure(figsize=(4., 4.))
# Ref:
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(4, 4),  # creates a 4x4 grid of axes
                 axes_pad=0.1,  # pad between axes in inches
                 )
# Show image grid
for ax, im in zip(grid, im_list):
    # Iterating over the grid returns the Axes.
    ax.imshow(im, 'gray')
plt.show()

Structure the Data with tf.Example

tf.Example is a Protocol Buffer (protobuf) message; protobuf is a mechanism for serializing structured data. Since a TFRecord file stores our data as a sequence of binary strings, we need to specify the structure of the data before storing it. Otherwise, how would we know the original shape and characteristics of the data when reading it back and reconstructing it from the TFRecord file?

TensorFlow provides two message types for this purpose:

  1. tf.train.Example
  2. tf.train.SequenceExample

We use tf.train.Example for our experiments. To use tf.train.Example, we must convert our data to compatible feature types with the tf.train.Feature protocol message. tf.train.Feature supports three types of features: tf.train.BytesList, tf.train.FloatList, and tf.train.Int64List. For further details, refer to the official documentation.

To make our values compatible, we use the following helper functions:

# Convert values to compatible tf.Example types.
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

Now we can set the feature dictionary and structure the data with tf.train.Example:

# Create the features dictionary.
def image_example(image, label, dimension):
    feature = {
        'dimension': _int64_feature(dimension),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(image.tobytes()),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))

There are three extremely important points in the above code:

  1. We store the dimension of the image in the structure because we need it when reading the image back, so that we can reshape the reconstructed image to its original shape.
  2. For the image, we first convert it to bytes with the .tobytes() method (.tostring() is a deprecated alias for .tobytes(), so prefer .tobytes()). Then, we feed it to the _bytes_feature function. Remember that even Python strings must be converted to bytes (for example, with .encode()) before being fed to tf.train.BytesList.
  3. Finally, note that we used tf.train.Features, which is different from tf.train.Feature (the former has an extra s at the end, in case you did not notice!!!). tf.train.Features is a wrapper around named features: its feature argument takes a dictionary whose keys are the feature names and whose values are tf.train.Feature objects.
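A quick way to check this structure is to serialize one example and parse it back. The sketch below assumes TensorFlow 2.x, repeats the two helpers it needs so that it runs on its own, and uses a tiny synthetic 4x4 image with a made-up label of 7 instead of MNIST.

```python
import numpy as np
import tensorflow as tf

# Same helpers as above, repeated so this snippet runs on its own.
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Tiny synthetic 4x4 float64 image and a made-up label (not MNIST data).
image = np.arange(16, dtype=np.float64).reshape(4, 4) / 15.0
example = tf.train.Example(features=tf.train.Features(feature={
    'dimension': _int64_feature(4),
    'label': _int64_feature(7),
    'image_raw': _bytes_feature(image.tobytes()),
}))

serialized = example.SerializeToString()             # plain Python bytes
restored = tf.train.Example.FromString(serialized)   # parse them back into a proto

feats = restored.features.feature
print(feats['label'].int64_list.value[0])            # 7
raw_bytes = feats['image_raw'].bytes_list.value[0]
decoded = np.frombuffer(raw_bytes, dtype=np.float64).reshape(4, 4)
print(np.array_equal(decoded, image))                # True
```

Note that the stored dimension and the known dtype are exactly what make the final reshape possible.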

Write Records to TFRecord File

With the structure defined and the helper functions in place, the final step is writing the data to the TFRecord file. We serialize each structured sample and store it in the file as follows:

record_file = 'mnistTrain.tfrecords'
n_samples = x_train.shape[0]
dimension = x_train.shape[1]
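with tf.io.TFRecordWriter(record_file) as writer:
    for i in range(n_samples):
        image = x_train[i]
        label = y_train[i]
        tf_example = image_example(image, label, dimension)
        # Serialize the proto to a binary string and append it to the file
        writer.write(tf_example.SerializeToString())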

The writer.write() method of tf.io.TFRecordWriter writes a string (one serialized record) to the specified file. The serialization is done with the SerializeToString() method. In fact, any protobuf message can be serialized to a binary string by calling its SerializeToString() method.

Reading from TFRecord

Now it's time to read from the TFRecord file, using what we know about the structure of the serialized samples.

Creating the Dataset

We use the TensorFlow tf.data API, which makes creating the input pipeline very handy! First, we read the TFRecord file and create a dataset with the tf.data.TFRecordDataset function:

# Create the dataset object from tfrecord file(s)
dataset = tf.data.TFRecordDataset(record_file, buffer_size=100)

With tf.data.TFRecordDataset, we can also read multiple TFRecord files by passing a list of file names. The buffer_size argument is useful when memory is limited: it sets the number of bytes in the read buffer. It is optional, though; if you do not set it, TensorFlow picks a reasonable default.
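Reading several files works the same way: pass a list of paths. The sketch below assumes TensorFlow 2.x and writes two small shard files with hypothetical names (shard-0.tfrecords, shard-1.tfrecords) into a temporary directory. Since TFRecordDataset yields raw bytes, any byte string can serve as a record here, not only a serialized tf.train.Example.

```python
import os
import tempfile
import tensorflow as tf

tmp_dir = tempfile.mkdtemp()
# Hypothetical shard names; any list of paths works.
paths = [os.path.join(tmp_dir, f'shard-{i}.tfrecords') for i in range(2)]
for shard, path in enumerate(paths):
    with tf.io.TFRecordWriter(path) as writer:
        for j in range(3):
            writer.write(f'shard {shard} record {j}'.encode())  # any bytes can be a record

# One dataset over both files; buffer_size (in bytes) is optional.
dataset = tf.data.TFRecordDataset(paths, buffer_size=100)
records = [record.numpy() for record in dataset]
print(len(records))   # 6
print(records[0])     # b'shard 0 record 0'
```

With the default num_parallel_reads, the files are read one after another in list order, which is why the first record comes from shard 0.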

Retrieving Records

Now that we have the dataset, we can loop through the dataset to extract the records. Take a look at the following loop:

for record in dataset:
    parsed_record = parse_record(record)
    decoded_record = decode_record(parsed_record)
    image, label = decoded_record
    print(image.shape, label.shape)
    break  # inspect only the first record

The above loop extracts a single sample, since we use break to stop after the first iteration. Let's look at the parse_record() function:

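# Parsing function
def parse_record(record):
    name_to_features = {
        'dimension': tf.io.FixedLenFeature([], tf.int64),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
    }
    return tf.io.parse_single_example(record, name_to_features)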

The above function:

  1. Operates on each record.
  2. Returns the features stored in the record, keyed by the feature names used when writing. Above, we have the 'dimension', 'label', and 'image_raw' features.
  3. Since our features have a fixed length, we use the tf.io.FixedLenFeature function. For variable-length features, you can use tf.io.VarLenFeature instead. The first argument is the shape of the feature; we set it to [] because each of our features holds a single (scalar) value.
  4. Finally, the tf.io.parse_single_example function parses the serialized record according to the predetermined structure.
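For contrast, here is a minimal sketch of tf.io.VarLenFeature, assuming TensorFlow 2.x and a hypothetical 'tokens' feature whose length differs between records (make_example is also a hypothetical helper). VarLenFeature returns a tf.SparseTensor, which tf.sparse.to_dense turns back into a regular tensor; a FixedLenFeature with shape [] would reject the first record, which holds three values.

```python
import tensorflow as tf

def make_example(tokens):
    """Serialize a tf.train.Example holding a variable-length int64 'tokens' feature."""
    feature = {'tokens': tf.train.Feature(int64_list=tf.train.Int64List(value=tokens))}
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

serialized_records = [make_example([1, 2, 3]), make_example([4])]
schema = {'tokens': tf.io.VarLenFeature(tf.int64)}

values = []
for record in serialized_records:
    parsed = tf.io.parse_single_example(tf.constant(record), schema)
    dense = tf.sparse.to_dense(parsed['tokens'])   # SparseTensor -> regular 1-D tensor
    values.append(dense.numpy().tolist())
print(values)   # [[1, 2, 3], [4]]
```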

One thing remains: the parsed images are still raw byte strings. How do we reconstruct each image in its original form? That is done by the following function:

def decode_record(record):
    # Reinterpret the raw bytes as a flat 1-D tensor of the original dtype
    image = tf.io.decode_raw(
        record['image_raw'], out_type=dataType, little_endian=True, fixed_length=None, name=None)
    label = record['label']
    dimension = record['dimension']
    image = tf.reshape(image, (dimension, dimension))
    return (image, label)

Above, we used tf.io.decode_raw to convert the image from a byte string back to its original form. Note:

  1. We passed dataType as the out_type argument of tf.io.decode_raw. This is very important: if you do NOT use the same data type the image was written with, the bytes are reinterpreted and you reconstruct something inconsistent!
  2. We used tf.reshape(image, (dimension, dimension)) because tf.io.decode_raw returns a flat tensor and has no idea about the shape. That is why the dimension had to be saved when we wrote the TFRecords.
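To see why the dtype must match, here is a NumPy-only sketch in which np.frombuffer stands in for tf.io.decode_raw: both reinterpret raw bytes as a flat array of the requested type. The image is synthetic, normalized the same way as x_train.

```python
import numpy as np

# Synthetic stand-in for a normalized image: float64, like x_train / 255.0
image = (np.arange(28 * 28) % 256).reshape(28, 28) / 255.0
raw = image.tobytes()                          # 784 values * 8 bytes = 6272 bytes

good = np.frombuffer(raw, dtype=image.dtype).reshape(28, 28)
print(np.array_equal(good, image))             # True

bad = np.frombuffer(raw, dtype=np.float32)     # wrong dtype: same bytes, new meaning
print(bad.size)                                # 1568, so reshape(28, 28) would fail
```

With the wrong dtype you do not even get the right number of values, and whatever values you do get are meaningless.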

Now we are going to repeat the above loop for visualization:

im_list = []
n_samples_to_show = 16
# take() stops the loop after the first n_samples_to_show records
for record in dataset.take(n_samples_to_show):
    parsed_record = parse_record(record)
    decoded_record = decode_record(parsed_record)
    image, label = decoded_record
    im_list.append(image.numpy())
# Visualization
fig = plt.figure(figsize=(4., 4.))
# Ref:
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(4, 4),  # creates a 4x4 grid of axes
                 axes_pad=0.1,  # pad between axes in inches
                 )
# Show image grid
for ax, im in zip(grid, im_list):
    # Iterating over the grid returns the Axes.
    ax.imshow(im, 'gray')
plt.show()

It should display a 4x4 grid of the first 16 reconstructed digits.

We have now read the images back and reconstructed them exactly as they were saved, so we are technically done here!
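For training, the per-record Python loop is usually replaced by Dataset.map, which lets tf.data parse and decode records in parallel. The sketch below merges parse_record and decode_record into one hypothetical parse_and_decode function and runs it on a temporary file (hypothetically named tiny.tfrecords) holding three synthetic 4x4 images rather than MNIST. It assumes TensorFlow 2.4 or later for tf.data.AUTOTUNE; older 2.x versions expose it as tf.data.experimental.AUTOTUNE.

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# Write three tiny float64 "images" (synthetic) with labels 0, 1, 2.
record_file = os.path.join(tempfile.mkdtemp(), 'tiny.tfrecords')
with tf.io.TFRecordWriter(record_file) as writer:
    for label in range(3):
        image = np.full((4, 4), label / 10.0)
        example = tf.train.Example(features=tf.train.Features(feature={
            'dimension': _int64_feature(4),
            'label': _int64_feature(label),
            'image_raw': _bytes_feature(image.tobytes()),
        }))
        writer.write(example.SerializeToString())

schema = {
    'dimension': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
    'image_raw': tf.io.FixedLenFeature([], tf.string),
}

def parse_and_decode(record):
    """Parse one serialized Example and rebuild its 2-D float64 image."""
    parsed = tf.io.parse_single_example(record, schema)
    image = tf.io.decode_raw(parsed['image_raw'], out_type=tf.float64)
    dim = parsed['dimension']
    return tf.reshape(image, tf.stack([dim, dim])), parsed['label']

dataset = (tf.data.TFRecordDataset(record_file)
           .map(parse_and_decode, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(2))

batches = [(imgs.shape.as_list(), lbls.numpy().tolist()) for imgs, lbls in dataset]
print(batches)   # [([2, 4, 4], [0, 1]), ([1, 4, 4], [2])]
```

By default, map keeps elements in order even with num_parallel_calls, so the labels come out as 0, 1, 2.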

You can check the following video for a code walk-through.


In this tutorial, you learned what the TFRecord file format is, what its advantages are, and how to work with it in TensorFlow. I provided a simple example using the MNIST data, and the same approach applies to most scenarios. It is usually better to convert your data to TFRecords first and then work with them. Once you learn how to do it, I doubt you will go back to other data formats!! Unless your data is very small!! Of course, the story does not end here. Feel free to explore further, and comment below if you think I missed anything, disagree with me, have any questions, or for any other reason I forgot to think of! Thank you for reading this far.
