You might have heard of many different Deep Learning frameworks such as TensorFlow and PyTorch. They do everything for Deep Learning! You probably do NOT need any other framework! BUT…

Google Brain believes otherwise! In this tutorial, you will become familiar with Trax and the motivations behind its development.


In this tutorial, you will learn:

  • what Trax is,
  • why it was developed,
  • what its advantages are, and
  • what its high-level syntax looks like.

Here, you will NOT learn:

  • how Trax works under the hood, or
  • how to use it to actually train/test deep neural networks.

Why not other frameworks?

That’s a good question. Why not TensorFlow or PyTorch, since they can do almost anything for you? Well, there are two issues with TensorFlow and PyTorch:

  1. You often have to write verbose code for a simple task.
  2. Their APIs are relatively low-level, which adds complexity.

These issues become particularly problematic in complicated tasks. That is why Google Brain created Trax: to help its team write clean, concise code for in-depth research.

One may point out that Keras is very concise and is part of TensorFlow as well. So why not Keras?

That’s a great question. Well, there is no strong argument for using Trax instead of Keras/TensorFlow at the moment. However, Trax is actively maintained by the Google Brain team, and my personal belief is that it should be more convenient for the tasks the Google Brain team is working on!

Computational Advantage


Trax is built on the JAX library, which provides high-performance computing. According to Google, “JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs.” This can significantly improve the performance of Trax.

Google built JAX to generate high-performance accelerator code from Python and NumPy. It brings together Autograd and XLA for high-performance machine learning research. Autograd enables JAX to automatically differentiate native Python and NumPy functions. XLA (Accelerated Linear Algebra) is a compiler that optimizes TensorFlow computations, and JAX uses it to compile and run NumPy operations on GPUs and TPUs.
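To make the roles of Autograd and XLA more concrete, here is a minimal JAX sketch. It is independent of Trax and assumes JAX is installed (for example with pip install jax): jax.grad differentiates a plain NumPy-style Python function, and jax.jit compiles that function with XLA for whichever backend is available (CPU, GPU, or TPU).

```python
import jax
import jax.numpy as jnp


# A plain Python function written with NumPy-style operations.
def f(x):
    return x + x ** 2


# Automatic differentiation (the Autograd side): df/dx = 1 + 2x.
df = jax.grad(f)

# XLA compilation: the first call traces and compiles f; later calls reuse it.
f_fast = jax.jit(f)

print(df(3.0))                              # gradient of f at x = 3.0
print(f_fast(jnp.array([0.0, -1.0, 1.0])))  # f applied elementwise to an array
```

Since df/dx = 1 + 2x, the gradient at 3.0 is 7.0, and the compiled function returns the same values as the uncompiled one; only the execution path differs.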

Working with Trax

Here, you will learn how to get started with Trax and explore some of its basic functionality. Let’s get started.

Installing Trax

Installing Trax could not be easier! I suggest installing it in a Python virtual environment. For further details, you can refer to the following post:

Now, assuming you have created and activated the virtual environment, simply run the following on the command line:

pip install trax

You can do the following if you want to install a specific version:

pip install trax==x.x

where x.x is your desired version!

Trax Layers

First, let’s import the layers module to work with. As usual, we will import NumPy as well. You can browse the full list of layers in the official documentation (layers section).

import numpy as np
from trax import layers as ly

Here, ‘ly’ is simply a short alias for the trax.layers module, which holds the layer building blocks.

Activation Layers

Let’s define a Sigmoid activation function with Trax:

# Make a sigmoid activation layer
# Sigmoid is of type trax.layers.base.PureLayer
# Ref:
sigmoid = ly.activation_fns.Sigmoid()

# Some attributes
print("name :", sigmoid.name)
print("weights :", sigmoid.weights)
print("# of inputs :", sigmoid.n_in)
print("# of outputs :", sigmoid.n_out)

Above, we printed some of the attributes of the trax.layers.base.PureLayer object.
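For reference, the Sigmoid layer applies the logistic function σ(x) = 1 / (1 + e^(-x)) elementwise. The NumPy sketch below reproduces only that math and is not Trax’s implementation; the helper name sigmoid_reference is hypothetical.

```python
import numpy as np


def sigmoid_reference(x):
    """Hypothetical helper: elementwise logistic function 1 / (1 + exp(-x))."""
    x = np.asarray(x, dtype=float)
    return 1.0 / (1.0 + np.exp(-x))


values = sigmoid_reference([-2.0, 0.0, 2.0])
print(values)
```

Two quick sanity checks when comparing against a layer’s output: σ(0) = 0.5, and σ(-x) = 1 - σ(x), so the first and last values above sum to 1.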

Combinators

We use combinators to compose layers either (1) serially or (2) in parallel.

To compose layers serially, we use trax.layers.combinators.Serial. You can think of it as a simple neural network with multiple layers stacked together. An example would be as follows:

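# Composing layers serially
# (example sublayers, mirroring the MLP example in the Serial docstring)
stacked_layers = ly.Serial(
    ly.Dense(128),
    ly.Relu(),
    ly.Dense(10),
    ly.LogSoftmax()
)

print("name :", stacked_layers.name)
print("weights:", stacked_layers.weights)
print("sublayers :", stacked_layers.sublayers)
print("expected inputs :", stacked_layers.n_in)
print("promised outputs :", stacked_layers.n_out)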

You can learn more with the useful help function:

help(ly.Serial)

This prints the following output:

Help on class Serial in module trax.layers.combinators:

class Serial(trax.layers.base.Layer)
 |  Serial(*sublayers, name=None, sublayers_to_print=None)
 |  Combinator that applies layers serially (by function composition).
 |  This combinator is commonly used to construct deep networks, e.g., like this::
 |      mlp = tl.Serial(
 |        tl.Dense(128),
 |        tl.Relu(),
 |        tl.Dense(10),
 |        tl.LogSoftmax()
 |      )
 |  A Serial combinator uses stack semantics to manage data for its sublayers.
 |  Each sublayer sees only the inputs it needs and returns only the outputs it
 |  has generated. The sublayers interact via the data stack. For instance, a
 |  sublayer k, following sublayer j, gets called with the data stack in the
 |  state left after layer j has applied. The Serial combinator then:
 |    - takes n_in items off the top of the stack (n_in = k.n_in) and calls
 |      layer k, passing those items as arguments; and
 |    - takes layer k's n_out return values (n_out = k.n_out) and pushes
 |      them onto the data stack.
 |  A Serial instance with no sublayers acts as a special-case (but useful)
 |  1-input 1-output no-op.
 |  Method resolution order:
 |      Serial
 |      trax.layers.base.Layer
 |      builtins.object
 |  Methods defined here:
 |  __init__(self, *sublayers, name=None, sublayers_to_print=None)
 |      Creates a partially initialized, unconnected layer instance.
 |      Args:
 |        n_in: Number of inputs expected by this layer.
 |        n_out: Number of outputs promised by this layer.
 |        name: Class-like name for this layer; for use when printing this layer.
 |        sublayers_to_print: Sublayers to display when printing out this layer;
 |          By default (when None) we display all sublayers.
 |  forward(self, xs)
 |      Computes this layer's output as part of a forward pass through the model.
 |      Authors of new layer subclasses should override this method to define the
 |      forward computation that their layer performs. Use `self.weights` to access
 |      trainable weights of this layer. If you need to use local non-trainable
 |      state or randomness, use `self.rng` for the random seed (no need to set it)
 |      and use `self.state` for non-trainable state (and set it to the new value).
 |      Args:
 |        inputs: Zero or more input tensors, packaged as described in the `Layer`
 |            class docstring.
 |      Returns:
 |        Zero or more output tensors, packaged as described in the `Layer` class
 |        docstring.
 |  init_weights_and_state(self, input_signature)
 |      Initializes weights and state for inputs with the given signature.
 |      Authors of new layer subclasses should override this method if their layer
 |      uses trainable weights or non-trainable state. To initialize trainable
 |      weights, set `self.weights` and to initialize non-trainable state,
 |      set `self.state` to the intended value.
 |      Args:
 |        input_signature: A `ShapeDtype` instance (if this layer takes one input)
 |            or a list/tuple of `ShapeDtype` instances; signatures of inputs.
 |  ----------------------------------------------------------------------
 |  Methods inherited from trax.layers.base.Layer:
 |  __call__(self, x, weights=None, state=None, rng=None)
 |      Makes layers callable; for use in tests or interactive settings.
 |      This convenience method helps library users play with, test, or otherwise
 |      probe the behavior of layers outside of a full training environment. It
 |      presents the layer as callable function from inputs to outputs, with the
 |      option of manually specifying weights and non-parameter state per individual
 |      call. For convenience, weights and non-parameter state are cached per layer
 |      instance, starting from default values of `EMPTY_WEIGHTS` and `EMPTY_STATE`,
 |      and acquiring non-empty values either by initialization or from values
 |      explicitly provided via the weights and state keyword arguments.
 |      Args:
 |        x: Zero or more input tensors, packaged as described in the `Layer` class
 |            docstring.
 |        weights: Weights or `None`; if `None`, use self's cached weights value.
 |        state: State or `None`; if `None`, use self's cached state value.
 |        rng: Single-use random number generator (JAX PRNG key), or `None`;
 |            if `None`, use a default computed from an integer 0 seed.
 |      Returns:
 |        Zero or more output tensors, packaged as described in the `Layer` class
 |        docstring.
 |  __repr__(self)
 |      Return repr(self).
 |  backward(self, inputs, output, grad, weights, state, new_state, rng)
 |      Custom backward pass to propagate gradients in a custom way.
 |      Args:
 |        inputs: Input tensors; can be a (possibly nested) tuple.
 |        output: The result of running this layer on inputs.
 |        grad: Gradient signal computed based on subsequent layers; its structure
 |            and shape must match output.
 |        weights: This layer's weights.
 |        state: This layer's state prior to the current forward pass.
 |        new_state: This layer's state after the current forward pass.
 |        rng: Single-use random number generator (JAX PRNG key).
 |      Returns:
 |        The custom gradient signal for the input. Note that we need to return
 |        a gradient for each argument of forward, so it will usually be a tuple
 |        of signals: the gradient for inputs and weights.
 |  init(self, input_signature, rng=None, use_cache=False)
 |      Initializes weights/state of this layer and its sublayers recursively.
 |      Initialization creates layer weights and state, for layers that use them.
 |      It derives the necessary array shapes and data types from the layer's input
 |      signature, which is itself just shape and data type information.
 |      For layers without weights or state, this method safely does nothing.
 |      This method is designed to create weights/state only once for each layer
 |      instance, even if the same layer instance occurs in multiple places in the
 |      network. This enables weight sharing to be implemented as layer sharing.
 |      Args:
 |        input_signature: `ShapeDtype` instance (if this layer takes one input)
 |            or list/tuple of `ShapeDtype` instances.
 |        rng: Single-use random number generator (JAX PRNG key), or `None`;
 |            if `None`, use a default computed from an integer 0 seed.
 |        use_cache: If `True`, and if this layer instance has already been
 |            initialized elsewhere in the network, then return special marker
 |            values -- tuple `(GET_WEIGHTS_FROM_CACHE, GET_STATE_FROM_CACHE)`.
 |            Else return this layer's newly initialized weights and state.
 |      Returns:
 |        A `(weights, state)` tuple.
 |  init_from_file(self, file_name, weights_only=False, input_signature=None)
 |      Initializes this layer and its sublayers from a pickled checkpoint.
 |      In the common case (`weights_only=False`), the file must be a gziped pickled
 |      dictionary containing items with keys `'flat_weights', `'flat_state'` and
 |      `'input_signature'`, which are used to initialize this layer.
 |      If `input_signature` is specified, it's used instead of the one in the file.
 |      If `weights_only` is `True`, the dictionary does not need to have the
 |      `'flat_state'` item and the state it not restored either.
 |      Args:
 |        file_name: Name/path of the pickeled weights/state file.
 |        weights_only: If `True`, initialize only the layer's weights. Else
 |            initialize both weights and state.
 |        input_signature: Input signature to be used instead of the one from file.
 |  output_signature(self, input_signature)
 |      Returns output signature this layer would give for `input_signature`.
 |  pure_fn(self, x, weights, state, rng, use_cache=False)
 |      Applies this layer as a pure function with no optional args.
 |      This method exposes the layer's computation as a pure function. This is
 |      especially useful for JIT compilation. Do not override, use `forward`
 |      instead.
 |      Args:
 |        x: Zero or more input tensors, packaged as described in the `Layer` class
 |            docstring.
 |        weights: A tuple or list of trainable weights, with one element for this
 |            layer if this layer has no sublayers, or one for each sublayer if
 |            this layer has sublayers. If a layer (or sublayer) has no trainable
 |            weights, the corresponding weights element is an empty tuple.
 |        state: Layer-specific non-parameter state that can update between batches.
 |        rng: Single-use random number generator (JAX PRNG key).
 |        use_cache: if `True`, cache weights and state in the layer object; used
 |          to implement layer sharing in combinators.
 |      Returns:
 |        A tuple of `(tensors, state)`. The tensors match the number (`n_out`)
 |        promised by this layer, and are packaged as described in the `Layer`
 |        class docstring.
 |  weights_and_state_signature(self, input_signature)
 |      Return a pair containing the signatures of weights and state.
 |  ----------------------------------------------------------------------
 |  Data descriptors inherited from trax.layers.base.Layer:
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  has_backward
 |      Returns `True` if this layer provides its own custom backward pass code.
 |      A layer subclass that provides custom backward pass code (for custom
 |      gradients) must override this method to return `True`.
 |  n_in
 |      Returns how many tensors this layer expects as input.
 |  n_out
 |      Returns how many tensors this layer promises as output.
 |  name
 |      Returns the name of this layer.
 |  rng
 |      Returns a single-use random number generator without advancing it.
 |  state
 |      Returns a tuple containing this layer's state; may be empty.
 |      If the layer has sublayers, the state by convention will be
 |      a tuple of length `len(sublayers)` containing sublayer states.
 |      Note that in this case self._state only marks which ones are shared.
 |  sublayers
 |      Returns a tuple containing this layer's sublayers; may be empty.
 |  weights
 |      Returns this layer's weights.
 |      Depending on the layer, the weights can be in the form of:
 |        - an empty tuple
 |        - a tensor (ndarray)
 |        - a nested structure of tuples and tensors
 |      If the layer has sublayers, the weights by convention will be
 |      a tuple of length `len(sublayers)` containing the weights of sublayers.
 |      Note that in this case self._weights only marks which ones are shared.

As you can see, help gives you a great deal of information!
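The stack semantics described in the docstring can be illustrated with a toy, pure-Python model. This is a simplified sketch for intuition only, not how Trax implements Serial: the hypothetical ToyLayer class just records a function with its n_in and n_out, and the hypothetical toy_serial function takes n_in items off the top of the stack, calls the layer on them, and pushes its n_out results back.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyLayer:
    """Hypothetical stand-in for a layer: a function plus its input/output counts."""
    fn: Callable[..., tuple]
    n_in: int
    n_out: int


def toy_serial(layers: List[ToyLayer], inputs: list) -> list:
    """Apply layers one after another over a data stack (top of stack = index 0)."""
    stack = list(inputs)
    for layer in layers:
        # Take n_in items off the top of the stack and call the layer on them.
        args, stack = stack[:layer.n_in], stack[layer.n_in:]
        outputs = layer.fn(*args)
        if len(outputs) != layer.n_out:
            raise ValueError("layer returned the wrong number of outputs")
        # Push the layer's n_out results back onto the top of the stack.
        stack = list(outputs) + stack
    return stack


# Example layers: "add" consumes two items and produces one; "double" consumes one.
add = ToyLayer(lambda a, b: (a + b,), n_in=2, n_out=1)
double = ToyLayer(lambda a: (2 * a,), n_in=1, n_out=1)

result = toy_serial([add, double], [3, 4, 10])
print(result)
```

Here add consumes 3 and 4 and pushes 7, double then sees only 7 and pushes 14, and the untouched 10 stays on the stack, so the result is [14, 10]. Each layer sees only the inputs it needs, which is the property the docstring highlights.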

Of course, you can compose layers in parallel too, using trax.layers.combinators.Parallel. Each layer in a Parallel combinator is applied to its own consecutive span of the inputs, where the size of each span is determined by how many inputs that layer takes. The outputs of all the layers are then concatenated, in order, into a single flattened sequence of outputs.

For example, suppose we have two layers, A and B, with the following numbers of inputs and outputs:

  • A: 1 input and 1 output.
  • B: 2 inputs and 3 outputs.

Then the Parallel(A, B) object has:

  • inputs: X_1, X_2, and X_3 (3 in total),
  • outputs: A(X_1) and B(X_2, X_3), and
  • final output: the concatenation of the outputs of A(X_1) and B(X_2, X_3), i.e., 1 + 3 = 4 outputs in total.
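In the same spirit, here is a toy pure-Python sketch of the Parallel(A, B) example above. It is not Trax’s implementation; ToyLayer, toy_parallel, and the concrete functions used for A and B are hypothetical. A takes 1 input and returns 1 output, and B takes 2 inputs and returns 3 outputs.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class ToyLayer:
    """Hypothetical stand-in for a layer: a function plus its input/output counts."""
    fn: Callable[..., tuple]
    n_in: int
    n_out: int


def toy_parallel(layers: List[ToyLayer], inputs: list) -> list:
    """Feed each layer its own consecutive span of inputs; flatten all outputs."""
    expected = sum(layer.n_in for layer in layers)
    if len(inputs) != expected:
        raise ValueError(f"expected {expected} inputs, got {len(inputs)}")
    outputs, start = [], 0
    for layer in layers:
        span = inputs[start:start + layer.n_in]
        start += layer.n_in
        result = layer.fn(*span)
        if len(result) != layer.n_out:
            raise ValueError("layer returned the wrong number of outputs")
        outputs.extend(result)  # concatenate this layer's outputs, in order
    return outputs


A = ToyLayer(lambda x1: (x1 * 10,), n_in=1, n_out=1)
B = ToyLayer(lambda x2, x3: (x2 + x3, x2 - x3, x2 * x3), n_in=2, n_out=3)

out = toy_parallel([A, B], [1, 5, 2])
print(out)
```

With inputs X_1 = 1, X_2 = 5, X_3 = 2, A contributes 10 and B contributes 7, 3, and 10, so the combined output has 1 + 3 = 4 items.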

Another good example can be found in the official documentation.

Custom Layer

You can define a custom layer with trax.layers.base.Fn, as follows:

def CustomLayer():
    # Set a name for the layer
    layer_name = "custom_layer"

    # Custom function: x plus x squared (** is Python's power operator)
    def func(x):
        return x + x**2

    # Fn wraps a plain function as a layer; n_in is inferred from func's arguments
    return ly.base.Fn(layer_name, func)

# Create layer object
custom_layer = CustomLayer()

# Check properties
print("name :", custom_layer.name)
print("expected inputs :", custom_layer.n_in)
print("promised outputs :", custom_layer.n_out)

# Inputs
x = np.array([0, -1, 1])

# Outputs
print("outputs :", custom_layer(x))

That leads to the following output:

name : custom_layer
expected inputs : 1
promised outputs : 1
outputs : [0 0 2]
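A common pitfall when writing such functions: in Python and NumPy, ^ is bitwise XOR, not exponentiation, and it binds more loosely than +. So x + x^2 is parsed as (x + x) ^ 2, and squaring requires ** (or np.square). The NumPy snippet below, independent of Trax, compares the two on the same inputs:

```python
import numpy as np

x = np.array([0, -1, 1])

xor_version = x + x ^ 2     # parsed as (x + x) ^ 2, i.e. bitwise XOR with 2
power_version = x + x ** 2  # x plus x squared, the intended math

print("with ^  :", xor_version)
print("with ** :", power_version)
```

The XOR version yields [2, -4, 0], while the power version yields [0, 0, 2], which matches x + x² for x = 0, -1, and 1.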

Ok, that was a brief overview of Trax layers with some examples. Feel free to explore more!

Conclusion

In this tutorial, I talked about Trax, the new Deep Learning framework from Google Brain. I described some of the main advantages of Trax and its high-level structure. In future tutorials, I will go through how to train a neural network with Trax in a real-world application.

Stay tuned, and feel free to comment below if you have any questions or thoughts.
