
TFX Pipelines in Colab

Colab is a lightweight development environment that differs significantly from a production environment. In production, pipeline components such as data ingestion, transformation, and model training, along with run histories, may be spread across multiple distributed systems. For this tutorial, be aware that orchestration and metadata storage differ significantly: both are handled locally within Colab. See the TFX documentation to learn more about running TFX in Colab.


First, we install and import the necessary packages, set up paths, and download data.

Upgrade Pip

To avoid upgrading Pip on a local system, check to make sure that we're running in Colab. Local systems can, of course, be upgraded separately.

try:
  import colab
  !pip install --upgrade pip
except ImportError:
  pass
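The same guard can also be written in plain Python for scripts that run outside a notebook. This is a minimal sketch. It assumes that the presence of the `google.colab` module marks a Colab runtime; note that the cell above imports a module named `colab` instead. `running_in_colab` and `maybe_upgrade_pip` are hypothetical helper names.

```python
import subprocess
import sys
from typing import List


def running_in_colab() -> bool:
    """Return True if google.colab is importable (assumed marker of a Colab runtime)."""
    try:
        import google.colab  # noqa: F401
    except ImportError:
        return False
    return True


def maybe_upgrade_pip(dry_run: bool = False) -> List[str]:
    """Upgrade pip only inside Colab and return the command used (empty if skipped)."""
    if not running_in_colab():
        return []  # Local systems are left alone; upgrade them separately.
    cmd = [sys.executable, "-m", "pip", "install", "--upgrade", "pip"]
    if not dry_run:
        subprocess.check_call(cmd)
    return cmd


print(maybe_upgrade_pip(dry_run=True))
```

Passing `dry_run=True` only reports the command that would run. This makes it easy to check the guard locally without touching the installed pip.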

Install and import TFX

!pip install -q tfx

Did you restart the runtime?

If you are using Google Colab, the first time that you run the cell above, you must restart the runtime by clicking the “RESTART RUNTIME” button above or by using the “Runtime > Restart runtime …” menu. This is because of the way that Colab loads packages.

Import packages

import os
import tempfile
import urllib.request
import pandas as pd

import tensorflow_model_analysis as tfma
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

Check the TFX and MLMD versions.

from tfx import v1 as tfx
print('TFX version: {}'.format(tfx.__version__))
import ml_metadata as mlmd
print('MLMD version: {}'.format(mlmd.__version__))

TFX version: 1.13.0
MLMD version: 1.13.1

Download the dataset

In this colab, we use the Palmer Penguins dataset, which can be found on GitHub. We processed the dataset by leaving out any incomplete records, dropping the island and sex columns, and converting the labels to int32. The dataset contains 334 records of the body mass, the length and depth of penguins' culmens, and the length of their flippers. You use this data to classify penguins into one of three species.

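# URL of the processed Palmer Penguins CSV in the TFX GitHub repository.
DATA_PATH = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/penguin/data/labelled/penguins_processed.csv'
_data_root = tempfile.mkdtemp(prefix='tfx-data')
_data_filepath = os.path.join(_data_root, "penguins_processed.csv")
urllib.request.urlretrieve(DATA_PATH, _data_filepath)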

('/tmpfs/tmp/tfx-datap4i8w56n/penguins_processed.csv', <http.client.HTTPMessage at 0x7f3a76776370>)
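The preprocessing described above has already been applied to penguins_processed.csv, so the notebook does not repeat it. For illustration only, the sketch below applies the same three steps to a few rows in the raw Palmer Penguins column layout. Several details are assumptions and may differ from how the published file was produced: the helper name `preprocess_penguins`, the raw column names, the sample rows, and the alphabetical species-to-integer mapping. Plain Python integers stand in for int32 here.

```python
import csv
import io
from typing import Dict, List, Tuple

DROP_COLUMNS = {"island", "sex"}
MISSING_VALUES = {"", "NA"}


def preprocess_penguins(raw_csv: str) -> Tuple[List[Dict[str, object]], Dict[str, int]]:
    """Drop incomplete records and the island/sex columns; encode species as integers."""
    rows = list(csv.DictReader(io.StringIO(raw_csv)))
    # Leave out any record with a missing (or malformed) value in any column.
    complete = [
        r for r in rows
        if not any(not isinstance(v, str) or v.strip() in MISSING_VALUES for v in r.values())
    ]
    # Assumed label encoding: species names in alphabetical order -> 0, 1, 2.
    labels = {name: i for i, name in enumerate(sorted({r["species"] for r in complete}))}
    processed = []
    for r in complete:
        out: Dict[str, object] = {}
        for key, value in r.items():
            if key in DROP_COLUMNS:
                continue  # Drop the island and sex columns.
            out[key] = labels[value] if key == "species" else float(value)
        processed.append(out)
    return processed, labels


# Illustrative rows in the assumed raw column layout (the second row is incomplete).
RAW = """species,island,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g,sex
Adelie,Torgersen,39.1,18.7,181,3750,MALE
Adelie,Torgersen,NA,NA,NA,NA,NA
Gentoo,Biscoe,46.1,13.2,211,4500,FEMALE
Chinstrap,Dream,46.5,17.9,192,3500,FEMALE
"""

rows, labels = preprocess_penguins(RAW)
print(len(rows), labels)  # 3 {'Adelie': 0, 'Chinstrap': 1, 'Gentoo': 2}
```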

Create an InteractiveContext

To run TFX components interactively in this notebook, create an InteractiveContext. The InteractiveContext uses a temporary directory with an ephemeral MLMD database instance. Note that calls to InteractiveContext are no-ops outside the Colab environment.

In general, it is a good practice to group similar pipeline runs under a Context.

interactive_context = InteractiveContext()

WARNING:absl:InteractiveContext pipeline_root argument not provided: using temporary directory /tmpfs/tmp/tfx-interactive-2023-07-28T11_11_10.063419-p4royv0g as root for pipeline outputs.
WARNING:absl:InteractiveContext metadata_connection_config not provided: using SQLite ML Metadata database at /tmpfs/tmp/tfx-interactive-2023-07-28T11_11_10.063419-p4royv0g/metadata.sqlite.

Construct the TFX Pipeline

A TFX pipeline consists of several components that perform different aspects of the ML workflow. In this notebook, you create and run the ExampleGen, StatisticsGen, SchemaGen, and Trainer components, and use the Evaluator and Pusher components to evaluate and push the trained model.

Refer to the components tutorial for more information on TFX pipeline components.

Note: Constructing a TFX pipeline by setting up the individual components involves a lot of boilerplate code. For the purpose of this tutorial, it is alright if you do not fully understand every line of code in the pipeline setup.
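Each component consumes artifacts produced by earlier ones, and those dependencies fix the order in which the components can run. The sketch below records the dependencies under the usual wiring for this pipeline. The following wiring is an assumption for the sketch:

- StatisticsGen reads ExampleGen's examples.
- SchemaGen reads the statistics.
- Trainer reads the examples and schema.
- Evaluator reads the examples and the trained model.
- Pusher reads the model and the Evaluator's blessing.

It then derives a run order with the standard library's `graphlib` module (Python 3.9+). It illustrates the dependency structure only and is not how TFX orchestrators are implemented.

```python
from graphlib import TopologicalSorter
from typing import Dict, List, Set

# Each component maps to the upstream components whose outputs it consumes
# (the artifacts involved are noted in the comments).
PIPELINE_DEPS: Dict[str, Set[str]] = {
    "ExampleGen": set(),
    "StatisticsGen": {"ExampleGen"},         # examples
    "SchemaGen": {"StatisticsGen"},          # statistics
    "Trainer": {"ExampleGen", "SchemaGen"},  # examples, schema
    "Evaluator": {"ExampleGen", "Trainer"},  # examples, model
    "Pusher": {"Trainer", "Evaluator"},      # model, blessing
}


def execution_order(deps: Dict[str, Set[str]]) -> List[str]:
    """Return an order in which every component runs after all of its inputs."""
    return list(TopologicalSorter(deps).static_order())


print(execution_order(PIPELINE_DEPS))
# ['ExampleGen', 'StatisticsGen', 'SchemaGen', 'Trainer', 'Evaluator', 'Pusher']
```

Because each component here needs the output of the one before it, only one order is valid. It matches the order of the sections that follow.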

Instantiate and run the ExampleGen Component

example_gen = tfx.components.CsvExampleGen(input_base=_data_root)
interactive_context.run(example_gen)

WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.
…Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.

Instantiate and run the StatisticsGen Component

statistics_gen = tfx.components.StatisticsGen(
    examples=example_gen.outputs['examples'])
interactive_context.run(statistics_gen)

Instantiate and run the SchemaGen Component

infer_schema = tfx.components.SchemaGen(
    statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
interactive_context.run(infer_schema)

Instantiate and run the Trainer Component

# Define the module file for the Trainer component.
# The training code goes in a separate cell that starts with %%writefile.
trainer_module_file = 'penguin_trainer.py'

%%writefile {trainer_module_file}

# Define the training algorithm for the Trainer module file
import os
from typing import List, Text

import tensorflow as tf
from tensorflow import keras

from tfx import v1 as tfx
from tfx_bsl.public import tfxio

from tensorflow_metadata.proto.v0 import schema_pb2

# Features used for classification - culmen length and depth, flipper length,
# body mass, and species.

_LABEL_KEY = 'species'

_FEATURE_KEYS = [
    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'
]


def _input_fn(file_pattern: List[Text],
              data_accessor: tfx.components.DataAccessor,
              schema: schema_pb2.Schema, batch_size: int) -> tf.data.Dataset:
  # Batch the examples, split off the label, and repeat indefinitely.
  return data_accessor.tf_dataset_factory(
      file_pattern,
      tfxio.TensorFlowDatasetOptions(
          batch_size=batch_size, label_key=_LABEL_KEY), schema).repeat()


def _build_keras_model():
  # One scalar input per feature, two small hidden layers, three output logits.
  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]
  d = keras.layers.concatenate(inputs)
  d = keras.layers.Dense(8, activation='relu')(d)
  d = keras.layers.Dense(8, activation='relu')(d)
  outputs = keras.layers.Dense(3)(d)
  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(1e-2),
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])
  return model


def run_fn(fn_args: tfx.components.FnArgs):
  # Entry point called by the Trainer component.
  schema = schema_pb2.Schema()
  tfx.utils.parse_pbtxt_file(fn_args.schema_path, schema)
  train_dataset = _input_fn(
      fn_args.train_files, fn_args.data_accessor, schema, batch_size=10)
  eval_dataset = _input_fn(
      fn_args.eval_files, fn_args.data_accessor, schema, batch_size=10)
  model = _build_keras_model()
  model.fit(
      train_dataset,
      epochs=int(fn_args.train_steps / 20),
      steps_per_epoch=20,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)
  # Export a SavedModel for the Evaluator and Pusher.
  model.save(fn_args.serving_model_dir, save_format='tf')


Run the Trainer component.

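trainer = tfx.components.Trainer(
    module_file=os.path.abspath(trainer_module_file),
    examples=example_gen.outputs['examples'],
    schema=infer_schema.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=100),
    eval_args=tfx.proto.EvalArgs(num_steps=50))
interactive_context.run(trainer)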

Epoch 1/5
20/20 [==============================] - 3s 19ms/step - loss: 0.9458 - sparse_categorical_accuracy: 0.7500 - val_loss: 0.8589 - val_sparse_categorical_accuracy: 0.7800
Epoch 2/5
20/20 [==============================] - 0s 11ms/step - loss: 0.6942 - sparse_categorical_accuracy: 0.8000 - val_loss: 0.5478 - val_sparse_categorical_accuracy: 0.7800
Epoch 3/5
20/20 [==============================] - 0s 11ms/step - loss: 0.4146 - sparse_categorical_accuracy: 0.8100 - val_loss: 0.3478 - val_sparse_categorical_accuracy: 0.7800
Epoch 4/5
20/20 [==============================] - 0s 11ms/step - loss: 0.2747 - sparse_categorical_accuracy: 0.9350 - val_loss: 0.2253 - val_sparse_categorical_accuracy: 0.9600
Epoch 5/5
20/20 [==============================] - 0s 11ms/step - loss: 0.1738 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.1330 - val_sparse_categorical_accuracy: 0.9800
INFO:tensorflow:Assets written to: /tmpfs/tmp/tfx-interactive-2023-07-28T11_11_10.063419-p4royv0g/Trainer/model/4/Format-Serving/assets
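The shape of this log follows from `run_fn`. Training runs for `int(train_steps / 20)` epochs of 20 steps each (the 20/20 progress bars) on batches of 10 examples. The `.repeat()` in `_input_fn` lets the dataset be read more than once. The sketch below is a quick check of that arithmetic. It assumes `train_steps=100`, which is consistent with the 5 epochs shown, and uses a hypothetical `training_plan` helper.

```python
from typing import NamedTuple


class TrainingPlan(NamedTuple):
    epochs: int
    total_steps: int
    examples_seen: int


def training_plan(train_steps: int, steps_per_epoch: int = 20,
                  batch_size: int = 10) -> TrainingPlan:
    """Mirror run_fn's arithmetic: epochs = int(train_steps / steps_per_epoch)."""
    epochs = int(train_steps / steps_per_epoch)
    total_steps = epochs * steps_per_epoch
    # Each step consumes one batch; this assumes every batch is full.
    return TrainingPlan(epochs, total_steps, total_steps * batch_size)


print(training_plan(100))  # TrainingPlan(epochs=5, total_steps=100, examples_seen=1000)
```

Training therefore draws about 1,000 examples, several times the 334 records in the whole dataset, which is why the input function repeats. A `train_steps` value that is not a multiple of 20 is rounded down: `training_plan(110)` still yields 5 epochs.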

Evaluate and push the model

Use the Evaluator component to evaluate and ‘bless’ the model before using the Pusher component to push the model to a serving directory.

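_serving_model_dir = os.path.join(tempfile.mkdtemp(),
                                  'serving_model/penguins_classification')
eval_config = tfma.EvalConfig(
    model_specs=[tfma.ModelSpec(label_key='species', signature_name='serving_default')],
    metrics_specs=[tfma.MetricsSpec(metrics=[tfma.MetricConfig(
        class_name='SparseCategoricalAccuracy',
        threshold=tfma.MetricThreshold(value_threshold=tfma.GenericValueThreshold(
            lower_bound={'value': 0.6})))])],
    slicing_specs=[tfma.SlicingSpec()])
evaluator = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'], model=trainer.outputs['model'],
    eval_config=eval_config)
interactive_context.run(evaluator)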

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_model_analysis/writers/…: tf_record_iterator (from …) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and: ``
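In this configuration, blessing comes down to a value threshold. The model is blessed only if the thresholded metric is at least the `lower_bound` of 0.6, and Pusher pushes only a blessed model. The sketch below is a simplified, hypothetical illustration of that gate using plain dictionaries. It is not TFMA's implementation. It assumes that the thresholded metric is sparse categorical accuracy, treats the bound as inclusive, and uses made-up metric values.

```python
from typing import Dict, Optional, Tuple

# Hypothetical thresholds: metric name -> (lower_bound, upper_bound); None means unbounded.
Thresholds = Dict[str, Tuple[Optional[float], Optional[float]]]


def is_blessed(metrics: Dict[str, float], thresholds: Thresholds) -> bool:
    """Return True only if every thresholded metric is present and within its bounds."""
    for name, (lower, upper) in thresholds.items():
        value = metrics.get(name)
        if value is None:
            return False  # A missing metric cannot satisfy its threshold.
        if lower is not None and value < lower:
            return False
        if upper is not None and value > upper:
            return False
    return True


THRESHOLDS: Thresholds = {"sparse_categorical_accuracy": (0.6, None)}

for accuracy in (0.95, 0.55):
    blessed = is_blessed({"sparse_categorical_accuracy": accuracy}, THRESHOLDS)
    # Pusher-style gate: only a blessed model would be copied to the serving directory.
    print(accuracy, "push" if blessed else "skip")
```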

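pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=tfx.proto.PushDestination(
        filesystem=tfx.proto.PushDestination.Filesystem(
            base_directory=_serving_model_dir)))
interactive_context.run(pusher)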

Running the TFX pipeline populates the MLMD Database. In the next section, you use the MLMD API to query this database for metadata information.