This notebook contains all the sample code in chapter 19.


First, let's import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0.

# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)

# Scikit-Learn ≥0.20 is required
import sklearn
assert sklearn.__version__ >= "0.20"

    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
    !echo "deb stable tensorflow-model-server tensorflow-model-server-universal" > /etc/apt/sources.list.d/tensorflow-serving.list
    !curl | apt-key add -
    !apt update && apt-get install -y tensorflow-model-server
    !pip install -q -U tensorflow-serving-api
    IS_COLAB = True
except Exception:
    IS_COLAB = False

# TensorFlow ≥2.0 is required
import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"

if not tf.test.is_gpu_available():
    print("No GPU was detected. CNNs can be very slow without a GPU.")
    if IS_COLAB:
        print("Go to Runtime > Change runtime and select a GPU hardware accelerator.")

# Common imports
import numpy as np
import os

# to make this notebook's output stable across runs
np.random.seed(42)
tf.random.set_seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# Where to save the figures
PROJECT_ROOT_DIR = "."  # root directory for generated artifacts (was never defined)
CHAPTER_ID = "deploy"
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID)
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    """Save the current Matplotlib figure as IMAGES_PATH/<fig_id>.<fig_extension>.

    Args:
        fig_id: Base file name of the figure, without extension.
        tight_layout: If true, call ``plt.tight_layout()`` before saving.
        fig_extension: File extension; also passed to ``savefig`` as ``format``.
        resolution: Output resolution in dots per inch.
    """
    path = os.path.join(IMAGES_PATH, fig_id + "." + fig_extension)
    print("Saving figure", fig_id)
    if tight_layout:
        # The original branch had no body (a syntax error).
        plt.tight_layout()
    plt.savefig(path, format=fig_extension, dpi=resolution)

Deploying TensorFlow models to TensorFlow Serving (TFS)

We will use the REST API or the gRPC API.

Save/Load a SavedModel

(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.mnist.load_data()
X_train_full = X_train_full[..., np.newaxis].astype(np.float32) / 255.
X_test = X_test[..., np.newaxis].astype(np.float32) / 255.
X_valid, X_train = X_train_full[:5000], X_train_full[5000:]
y_valid, y_train = y_train_full[:5000], y_train_full[5000:]
X_new = X_test[:3]

# Simple MLP for 28x28x1 MNIST images: flatten, one hidden layer, softmax over 10 classes.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28, 1]),
    keras.layers.Dense(100, activation="relu"),
    keras.layers.Dense(10, activation="softmax")
])
# Labels are integer class ids, hence the sparse categorical cross-entropy.
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.SGD(learning_rate=1e-2),
              metrics=["accuracy"])
model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))
Train on 55000 samples, validate on 5000 samples
Epoch 1/10
55000/55000 [==============================] - 2s 40us/sample - loss: 0.7018 - accuracy: 0.8223 - val_loss: 0.3722 - val_accuracy: 0.9022
Epoch 2/10
55000/55000 [==============================] - 2s 36us/sample - loss: 0.3528 - accuracy: 0.9021 - val_loss: 0.3000 - val_accuracy: 0.9170
Epoch 3/10
55000/55000 [==============================] - 2s 36us/sample - loss: 0.3032 - accuracy: 0.9150 - val_loss: 0.2659 - val_accuracy: 0.9280
Epoch 4/10
55000/55000 [==============================] - 2s 37us/sample - loss: 0.2730 - accuracy: 0.9233 - val_loss: 0.2442 - val_accuracy: 0.9342
Epoch 5/10
55000/55000 [==============================] - 2s 37us/sample - loss: 0.2504 - accuracy: 0.9305 - val_loss: 0.2272 - val_accuracy: 0.9346
Epoch 6/10
55000/55000 [==============================] - 2s 37us/sample - loss: 0.2319 - accuracy: 0.9353 - val_loss: 0.2104 - val_accuracy: 0.9418
Epoch 7/10
55000/55000 [==============================] - 2s 37us/sample - loss: 0.2156 - accuracy: 0.9395 - val_loss: 0.1987 - val_accuracy: 0.9484
Epoch 8/10
55000/55000 [==============================] - 2s 36us/sample - loss: 0.2019 - accuracy: 0.9434 - val_loss: 0.1893 - val_accuracy: 0.9496
Epoch 9/10
55000/55000 [==============================] - 2s 41us/sample - loss: 0.1898 - accuracy: 0.9471 - val_loss: 0.1765 - val_accuracy: 0.9526
Epoch 10/10
55000/55000 [==============================] - 2s 39us/sample - loss: 0.1791 - accuracy: 0.9495 - val_loss: 0.1691 - val_accuracy: 0.9550
<tensorflow.python.keras.callbacks.History at 0x13d74aba8>
np.round(model.predict(X_new), 2)
array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.96, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.01, 0.  ]],
model_version = "0001"
model_name = "my_mnist_model"
model_path = os.path.join(model_name, model_version)
!rm -rf {model_name}, model_path)
# Print the exported model directory as an indented tree
# (4 spaces per path separator, directories end with "/").
for dirpath, _subdirs, filenames in os.walk(model_name):
    pad = "    " * dirpath.count(os.sep)
    print("{}{}/".format(pad, os.path.basename(dirpath)))
    child_pad = pad + "    "
    for fname in filenames:
        print("{}{}".format(child_pad, fname))
!saved_model_cli show --dir {model_path}
The given SavedModel contains the following tag-sets:
!saved_model_cli show --dir {model_path} --tag_set serve
The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "__saved_model_init_op"
SignatureDef key: "serving_default"
!saved_model_cli show --dir {model_path} --tag_set serve \
                      --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
  inputs['flatten_2_input'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 28, 28, 1)
      name: serving_default_flatten_2_input:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['dense_5'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 10)
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict
!saved_model_cli show --dir {model_path} --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

  The given SavedModel SignatureDef contains the following input(s):
    inputs['flatten_2_input'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 28, 28, 1)
        name: serving_default_flatten_2_input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['dense_5'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 10)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

Let's write the new instances to an npy file so we can pass them easily to our model:

np.save("my_mnist_tests.npy", X_new)
input_name = model.input_names[0]

And now let's use saved_model_cli to make predictions for the instances we just saved:

!saved_model_cli run --dir {model_path} --tag_set serve \
                     --signature_def serving_default    \
                     --inputs {input_name}=my_mnist_tests.npy
2019-06-10 10:56:43.396851: I tensorflow/core/platform/] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
WARNING: Logging before flag parsing goes to stderr.
W0610 10:56:43.397369 140735810999168] From /Users/ageron/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/tools/ load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.
W0610 10:56:43.421489 140735810999168] From /Users/ageron/miniconda3/envs/tf2/lib/python3.6/site-packages/tensorflow/python/training/ checkpoint_exists (from is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
Result for output key dense_5:
[[1.17575204e-04 1.13160660e-07 5.96997386e-04 2.08104262e-03
  2.57820852e-06 6.44166794e-05 2.77263990e-08 9.96703804e-01
  3.96052455e-05 3.93810158e-04]
 [1.22226949e-03 2.92685600e-05 9.86054957e-01 9.63000767e-03
  8.81790996e-08 2.88744748e-04 1.58111588e-03 1.12290488e-09
  1.19344448e-03 1.09315742e-07]
 [6.40679718e-05 9.63618696e-01 9.04400647e-03 2.98595289e-03
  5.95759891e-04 3.74212675e-03 2.50709383e-03 1.14931818e-02
  5.52693009e-03 4.22279176e-04]]
np.round([[1.1739199e-04, 1.1239604e-07, 6.0210604e-04, 2.0804715e-03, 2.5779348e-06,
           6.4079795e-05, 2.7411186e-08, 9.9669880e-01, 3.9654213e-05, 3.9471846e-04],
          [1.2294615e-03, 2.9207937e-05, 9.8599273e-01, 9.6755642e-03, 8.8930705e-08,
           2.9156188e-04, 1.5831805e-03, 1.1311053e-09, 1.1980456e-03, 1.1113169e-07],
          [6.4066830e-05, 9.6359509e-01, 9.0598064e-03, 2.9872139e-03, 5.9552520e-04,
           3.7478798e-03, 2.5074568e-03, 1.1462728e-02, 5.5553433e-03, 4.2495009e-04]], 2)
array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.96, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.01, 0.  ]])

TensorFlow Serving

Install Docker if you don't have it already. Then run:

docker pull tensorflow/serving

export ML_PATH=$HOME/ml # or wherever this project is
docker run -it --rm -p 8500:8500 -p 8501:8501 \
   -v "$ML_PATH/my_mnist_model:/models/my_mnist_model" \
   -e MODEL_NAME=my_mnist_model \
   tensorflow/serving

Once you are finished using it, press Ctrl-C to shut down the server.

Alternatively, if tensorflow_model_server is installed (e.g., if you are running this notebook in Colab), then the following 3 cells will start the server:

os.environ["MODEL_DIR"] = os.path.split(os.path.abspath(model_path))[0]
%%bash --bg
nohup tensorflow_model_server \
     --rest_api_port=8501 \
     --model_name=my_mnist_model \
     --model_base_path="${MODEL_DIR}" >server.log 2>&1
!tail server.log
2019-11-06 13:04:12.267136: I external/org_tensorflow/tensorflow/core/platform/] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2019-11-06 13:04:12.283035: I external/org_tensorflow/tensorflow/cc/saved_model/] Restoring SavedModel bundle.
2019-11-06 13:04:12.300096: I external/org_tensorflow/tensorflow/cc/saved_model/] Running initialization op on SavedModel bundle at path: /content/my_mnist_model/0002
2019-11-06 13:04:12.304438: I external/org_tensorflow/tensorflow/cc/saved_model/] SavedModel load for tags { serve }; Status: success. Took 39993 microseconds.
2019-11-06 13:04:12.304900: I tensorflow_serving/servables/tensorflow/] No warmup data file found at /content/my_mnist_model/0002/assets.extra/tf_serving_warmup_requests
2019-11-06 13:04:12.305057: I tensorflow_serving/core/] Successfully loaded servable version {name: my_mnist_model version: 2}
2019-11-06 13:04:12.306462: I tensorflow_serving/model_servers/] Running gRPC ModelServer at ...
[warn] getaddrinfo: address family for nodename not supported
2019-11-06 13:04:12.307179: I tensorflow_serving/model_servers/] Exporting HTTP/REST API at:localhost:8501 ...
[ : 238] NET_LOG: Entering the event loop ...
import json

# JSON request body for the REST predict endpoint: signature name plus the
# instances converted to nested Python lists.
input_data_json = json.dumps({
    "signature_name": "serving_default",
    "instances": X_new.tolist(),
})
repr(input_data_json)[:1500] + "..."
'\'{"signature_name": "serving_default", "instances": [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294117748737335, 0.7254902124404907, 0.6235294342041016, 0.5921568870544434, 0.23529411852359772, 0.1411764770746231, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8705882430076599, 0.9960784316062927, 0.9960784316062927, 0.9960784316062927, 0.9960784316062927, 0.9450980424880981, 0.7764706015586853, 0.7764706015586853, 0.7764706015586853, 0.776470...'

Now let's use TensorFlow Serving's REST API to make predictions:

import requests

SERVER_URL = 'http://localhost:8501/v1/models/my_mnist_model:predict'
# POST the JSON payload to TF Serving's REST predict endpoint.
response = requests.post(SERVER_URL, data=input_data_json)
response.raise_for_status() # raise an exception in case of error
response = response.json()
y_proba = np.array(response["predictions"])
np.round(y_proba, 2)
array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.96, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.01, 0.  ]])

Using the gRPC API

from tensorflow_serving.apis.predict_pb2 import PredictRequest

request = PredictRequest()
request.model_spec.name = model_name
request.model_spec.signature_name = "serving_default"
input_name = model.input_names[0]
# Attach the instances as a TensorProto under the model's input name;
# without this the request carries no inputs at all.
request.inputs[input_name].CopyFrom(tf.make_tensor_proto(X_new))
import grpc
from tensorflow_serving.apis import prediction_service_pb2_grpc

channel = grpc.insecure_channel('localhost:8500')
predict_service = prediction_service_pb2_grpc.PredictionServiceStub(channel)
response = predict_service.Predict(request, timeout=10.0)
outputs {
  key: "dense_4"
  value {
    dtype: DT_FLOAT
    tensor_shape {
      dim {
        size: 3
      dim {
        size: 10
    float_val: 2.0824443708988838e-05
    float_val: 1.4913139168015732e-08
    float_val: 0.0004813199338968843
    float_val: 0.001888890634290874
    float_val: 2.682592992186983e-07
    float_val: 8.666840585647151e-06
    float_val: 1.6853943241024183e-10
    float_val: 0.9975269436836243
    float_val: 3.833709342870861e-05
    float_val: 3.4738284739432856e-05
    float_val: 0.00017358684272039682
    float_val: 0.0002858016814570874
    float_val: 0.9816810488700867
    float_val: 0.0157401692122221
    float_val: 1.1949770339914068e-10
    float_val: 0.00023017563216853887
    float_val: 3.078056761296466e-05
    float_val: 5.393230750883049e-09
    float_val: 0.0018584482604637742
    float_val: 1.8884094288296183e-09
    float_val: 3.397366526769474e-05
    float_val: 0.9835277795791626
    float_val: 0.001533020636998117
    float_val: 0.0014515116345137358
    float_val: 0.00018795969663187861
    float_val: 0.0011680654715746641
    float_val: 0.0014667459763586521
    float_val: 0.006120447069406509
    float_val: 0.004315734840929508
    float_val: 0.00019466254161670804
model_spec {
  name: "my_mnist_model"
  version {
    value: 2
  signature_name: "serving_default"

Convert the response to a tensor:

output_name = model.output_names[0]
outputs_proto = response.outputs[output_name]
y_proba = tf.make_ndarray(outputs_proto)
array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.98, 0.02, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.98, 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ]],

Or to a NumPy array if your client does not include the TensorFlow library:

output_name = model.output_names[0]
outputs_proto = response.outputs[output_name]
shape = [dim.size for dim in outputs_proto.tensor_shape.dim]
y_proba = np.array(outputs_proto.float_val).reshape(shape)
array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.98, 0.02, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.98, 0.  , 0.  , 0.  , 0.  , 0.  , 0.01, 0.  , 0.  ]])

Deploying a new model version


# Second model version: two smaller hidden layers instead of one.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=[28, 28, 1]),
    keras.layers.Dense(50, activation="relu"),
    keras.layers.Dense(50, activation="relu"),
    keras.layers.Dense(10, activation="softmax")
])
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.SGD(learning_rate=1e-2),
              metrics=["accuracy"])
history = model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))
Train on 55000 samples, validate on 5000 samples
Epoch 1/10
55000/55000 [==============================] - 2s 39us/sample - loss: 0.7035 - accuracy: 0.8060 - val_loss: 0.3445 - val_accuracy: 0.9032
Epoch 2/10
55000/55000 [==============================] - 2s 35us/sample - loss: 0.3213 - accuracy: 0.9084 - val_loss: 0.2660 - val_accuracy: 0.9252
Epoch 3/10
55000/55000 [==============================] - 2s 37us/sample - loss: 0.2663 - accuracy: 0.9236 - val_loss: 0.2304 - val_accuracy: 0.9392
Epoch 4/10
55000/55000 [==============================] - 2s 35us/sample - loss: 0.2331 - accuracy: 0.9331 - val_loss: 0.2069 - val_accuracy: 0.9430
Epoch 5/10
55000/55000 [==============================] - 2s 35us/sample - loss: 0.2105 - accuracy: 0.9390 - val_loss: 0.1910 - val_accuracy: 0.9446
Epoch 6/10
55000/55000 [==============================] - 2s 35us/sample - loss: 0.1924 - accuracy: 0.9442 - val_loss: 0.1732 - val_accuracy: 0.9518
Epoch 7/10
55000/55000 [==============================] - 2s 37us/sample - loss: 0.1771 - accuracy: 0.9489 - val_loss: 0.1679 - val_accuracy: 0.9526
Epoch 8/10
55000/55000 [==============================] - 2s 37us/sample - loss: 0.1650 - accuracy: 0.9527 - val_loss: 0.1574 - val_accuracy: 0.9546
Epoch 9/10
55000/55000 [==============================] - 2s 35us/sample - loss: 0.1540 - accuracy: 0.9555 - val_loss: 0.1446 - val_accuracy: 0.9590
Epoch 10/10
55000/55000 [==============================] - 2s 35us/sample - loss: 0.1448 - accuracy: 0.9583 - val_loss: 0.1414 - val_accuracy: 0.9608
<tensorflow.python.keras.callbacks.History at 0x12f58f908>
model_version = "0002"
model_name = "my_mnist_model"
model_path = os.path.join(model_name, model_version)
# Export the new model under the next version number in the same base
# directory that TF Serving was started on.
tf.saved_model.save(model, model_path)
# Print the exported model directory as an indented tree
# (4 spaces per path separator, directories end with "/").
for dirpath, _subdirs, filenames in os.walk(model_name):
    pad = "    " * dirpath.count(os.sep)
    print("{}{}/".format(pad, os.path.basename(dirpath)))
    child_pad = pad + "    "
    for fname in filenames:
        print("{}{}".format(child_pad, fname))

Warning: You may need to wait a minute before the new model is loaded by TensorFlow Serving.

import requests

SERVER_URL = 'http://localhost:8501/v1/models/my_mnist_model:predict'
# POST the JSON payload to TF Serving's REST predict endpoint.
response = requests.post(SERVER_URL, data=input_data_json)
response.raise_for_status() # raise an exception in case of error
response = response.json()
y_proba = np.array(response["predictions"])
np.round(y_proba, 2)
array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.96, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.01, 0.  ]])

Deploy the model to Google Cloud AI Platform

Follow the instructions in the book to deploy the model to Google Cloud AI Platform, download the service account's private key and save it to the my_service_account_private_key.json file in the project directory. Also, update the project_id:

project_id = "onyx-smoke-242003"
import googleapiclient.discovery

# Authenticate with the service account key saved in the project directory.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "my_service_account_private_key.json"
model_id = "my_mnist_model"
model_path = "projects/{}/models/{}".format(project_id, model_id)
model_path += "/versions/v0001/" # if you want to run a specific version
# Client for the AI Platform "ml" v1 API, scoped to its projects resource.
ml_resource = googleapiclient.discovery.build("ml", "v1").projects()
def predict(X):
    """Send ``X`` to the deployed AI Platform model and return its predictions.

    Relies on the module-level ``ml_resource`` client, ``model_path`` and
    ``output_name``.

    Args:
        X: NumPy array of instances; serialized with ``X.tolist()``.

    Returns:
        A NumPy array built from ``pred[output_name]`` for every prediction
        in the response.

    Raises:
        RuntimeError: if the service response contains an ``"error"`` key.
    """
    payload = dict(signature_name="serving_default", instances=X.tolist())
    result = ml_resource.predict(name=model_path, body=payload).execute()
    if "error" in result:
        raise RuntimeError(result["error"])
    outputs = [prediction[output_name] for prediction in result["predictions"]]
    return np.array(outputs)
Y_probas = predict(X_new)
np.round(Y_probas, 2)
array([[0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 1.  , 0.  , 0.  ],
       [0.  , 0.  , 0.99, 0.01, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ],
       [0.  , 0.96, 0.01, 0.  , 0.  , 0.  , 0.  , 0.01, 0.01, 0.  ]])

Using GPUs

from tensorflow.python.client.device_lib import list_local_devices

devices = list_local_devices()
[name: "/device:CPU:0"
 device_type: "CPU"
 memory_limit: 268435456
 locality {
 incarnation: 11178133101787456811]

Distributed Training

def create_model():
    """Build an uncompiled CNN classifier for 28x28x1 images with 10 classes.

    The model is returned uncompiled so callers can compile it inside
    whichever ``tf.distribute`` strategy scope they use.
    """
    return keras.models.Sequential([
        keras.layers.Conv2D(filters=64, kernel_size=7, activation="relu",
                            padding="same", input_shape=[28, 28, 1]),
        keras.layers.MaxPooling2D(pool_size=2),
        keras.layers.Conv2D(filters=128, kernel_size=3, activation="relu",
                            padding="same"),
        keras.layers.Conv2D(filters=128, kernel_size=3, activation="relu",
                            padding="same"),
        keras.layers.MaxPooling2D(pool_size=2),
        # Conv feature maps are 4-D; flatten them before the dense head so the
        # output is (batch, 10) as sparse categorical cross-entropy expects.
        keras.layers.Flatten(),
        keras.layers.Dense(units=64, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(units=10, activation='softmax'),
    ])
batch_size = 100
model = create_model()
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.SGD(learning_rate=1e-2),
              metrics=["accuracy"])
model.fit(X_train, y_train, epochs=10,
          validation_data=(X_valid, y_valid), batch_size=batch_size)

distribution = tf.distribute.MirroredStrategy()

# Change the default all-reduce algorithm:
#distribution = tf.distribute.MirroredStrategy(
#    cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())

# Specify the list of GPUs to use:
#distribution = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])

# Use the central storage strategy instead:
#distribution = tf.distribute.experimental.CentralStorageStrategy()

#resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
#distribution = tf.distribute.experimental.TPUStrategy(resolver)

with distribution.scope():
    model = create_model()
    # Compile inside the strategy scope so the optimizer is also created
    # under the distribution strategy.
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=keras.optimizers.SGD(learning_rate=1e-2),
                  metrics=["accuracy"])
WARNING: Logging before flag parsing goes to stderr.
W0603 15:31:26.178871 140735810999168] There is non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
batch_size = 100 # must be divisible by the number of workers
model.fit(X_train, y_train, epochs=10,
          validation_data=(X_valid, y_valid), batch_size=batch_size)
array([[0.09101252, 0.07083996, 0.06410537, 0.11957529, 0.06693752,
        0.05124901, 0.04676544, 0.23180223, 0.13522181, 0.12249089],
       [0.08099081, 0.12387844, 0.14915964, 0.13171668, 0.05875394,
        0.08834281, 0.16267018, 0.06899565, 0.07834874, 0.05714307],
       [0.04303756, 0.2682051 , 0.0909673 , 0.11496522, 0.06084979,
        0.07125981, 0.08520001, 0.08517107, 0.09236596, 0.0879782 ]],

Custom training loop:


K = keras.backend

distribution = tf.distribute.MirroredStrategy()

with distribution.scope():
    model = create_model()
    optimizer = keras.optimizers.SGD()

with distribution.scope():
    # Infinite stream of global batches of (image, label) pairs.
    dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).repeat().batch(batch_size)
    input_iterator = distribution.make_dataset_iterator(dataset)
def train_step():
    """Run one training step on every replica of ``distribution``.

    Calls ``step_fn`` through ``distribution.experimental_run`` with the
    module-level ``input_iterator``, then combines the per-replica losses.

    Returns:
        The per-replica losses reduced with ``ReduceOp.SUM``. Each replica
        divides its summed loss by the global ``batch_size`` (the size passed to
        ``dataset.batch``), so this sum equals the mean loss over the batch.
    """
    def step_fn(inputs):
        """Compute the loss on this replica's inputs and apply its gradients."""
        X, y = inputs
        with tf.GradientTape() as tape:
            Y_proba = model(X)
            # Divide by the *global* batch size, not the per-replica one, so
            # that summing across replicas below yields the batch mean.
            loss = K.sum(keras.losses.sparse_categorical_crossentropy(y, Y_proba)) / batch_size

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    # NOTE(review): experimental_run / make_dataset_iterator are TF 2.0-era
    # tf.distribute APIs; confirm they exist in the TF version in use.
    per_replica_losses = distribution.experimental_run(step_fn, input_iterator)
    mean_loss = distribution.reduce(tf.distribute.ReduceOp.SUM,
                                    per_replica_losses, axis=None)
    return mean_loss

n_epochs = 10
with distribution.scope():
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        for iteration in range(len(X_train) // batch_size):
            print("\rLoss: {:.3f}".format(train_step().numpy()), end="")
batch_size = 100 # must be divisible by the number of workers
model.fit(X_train, y_train, epochs=10,
          validation_data=(X_valid, y_valid), batch_size=batch_size)

Training across multiple servers

A TensorFlow cluster is a group of TensorFlow processes running in parallel, usually on different machines, and talking to each other to complete some work, for example training or executing a neural network. Each TF process in the cluster is called a "task" (or a "TF server"). It has an IP address, a port, and a type (also called its role or its job). The type can be "worker", "chief", "ps" (parameter server) or "evaluator":

  • Each worker performs computations, usually on a machine with one or more GPUs.
  • The chief performs computations as well, but it also handles extra work such as writing TensorBoard logs or saving checkpoints. There is a single chief in a cluster. If no chief is specified, then the first worker is the chief.
  • A parameter server (ps) only keeps track of variable values, it is usually on a CPU-only machine.
  • The evaluator obviously takes care of evaluation. There is usually a single evaluator in a cluster.

The set of tasks that share the same type is often called a "job". For example, the "worker" job is the set of all workers.

To start a TensorFlow cluster, you must first specify it. This means defining all the tasks (IP address, TCP port, and type). For example, the following cluster specification defines a cluster with 3 tasks (2 workers and 1 parameter server). It's a dictionary with one key per job, and the values are lists of task addresses:

    "worker": ["", ""],
    "ps": [""]

Every task in the cluster may communicate with every other task in the cluster, so make sure to configure your firewall to authorize all communications between these machines on these ports (it's usually simpler if you use the same port on every machine).

When a task is started, it needs to be told which one it is: its type and index (the task index is also called the task id). A common way to specify everything at once (both the cluster spec and the current task's type and id) is to set the TF_CONFIG environment variable before starting the program. It must be a JSON-encoded dictionary containing a cluster specification (under the "cluster" key), and the type and index of the task to start (under the "task" key). For example, the following TF_CONFIG environment variable defines a simple cluster with 2 workers and 1 parameter server, and specifies that the task to start is the first worker:

import os
import json

os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["", ""],
        "ps": [""]
    "task": {"type": "worker", "index": 0}
TF_CONFIG='{"cluster": {"worker": ["", ""], "ps": [""]}, "task": {"type": "worker", "index": 0}}'

Some platforms (e.g., Google Cloud ML Engine) automatically set this environment variable for you.

Then you would write a short Python script to start a task. The same script can be used on every machine, since it will load the TF_CONFIG variable, which will tell it which task to start:

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
# Start the server for the task described by TF_CONFIG's "task" entry.
worker0 = tf.distribute.Server(resolver.cluster_spec(),
                               job_name=resolver.task_type,
                               task_index=resolver.task_id)

Another way to specify the cluster specification is directly in Python, rather than through an environment variable:

cluster_spec = tf.train.ClusterSpec({
    "worker": ["", ""],
    "ps": [""]

You can then start a server simply by passing it the cluster spec and indicating its type and index. Let's start the two remaining tasks (remember that in general you would only start a single task per machine; we are starting 3 tasks on the localhost just for the purpose of this code example):

#worker1 = tf.distribute.Server(cluster_spec, job_name="worker", task_index=1)
ps0 = tf.distribute.Server(cluster_spec, job_name="ps", task_index=0)
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["", ""],
        "ps": [""]
    "task": {"type": "worker", "index": 1}
'{"cluster": {"worker": ["", ""], "ps": [""]}, "task": {"type": "worker", "index": 1}}'
distribution = tf.distribute.experimental.MultiWorkerMirroredStrategy()


os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {
        "worker": ["", ""],
        "ps": [""]
    "task": {"type": "worker", "index": 1}

with distribution.scope():
    model = create_model()
import tensorflow as tf
from tensorflow import keras
import numpy as np

# At the beginning of the program (restart the kernel before running this cell)
distribution = tf.distribute.experimental.MultiWorkerMirroredStrategy()

(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.mnist.load_data()
# Add the channel axis and scale to [0, 1] as float32 (same as the first cell).
X_train_full = X_train_full[..., np.newaxis].astype(np.float32) / 255.
X_test = X_test[..., np.newaxis].astype(np.float32) / 255.
X_valid, X_train = X_train_full[:5000], X_train_full[5000:]
y_valid, y_train = y_train_full[:5000], y_train_full[5000:]
X_new = X_test[:3]

n_workers = 2
batch_size = 32 * n_workers
# X_train already has its channel axis; adding another np.newaxis here made
# the inputs 5-D, which does not match the model's [28, 28, 1] input shape.
dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).repeat().batch(batch_size)
def create_model():
    """Build an uncompiled CNN classifier for 28x28x1 images with 10 classes.

    The model is returned uncompiled so callers can compile it inside
    whichever ``tf.distribute`` strategy scope they use.
    """
    return keras.models.Sequential([
        keras.layers.Conv2D(filters=64, kernel_size=7, activation="relu",
                            padding="same", input_shape=[28, 28, 1]),
        keras.layers.MaxPooling2D(pool_size=2),
        keras.layers.Conv2D(filters=128, kernel_size=3, activation="relu",
                            padding="same"),
        keras.layers.Conv2D(filters=128, kernel_size=3, activation="relu",
                            padding="same"),
        keras.layers.MaxPooling2D(pool_size=2),
        # Conv feature maps are 4-D; flatten them before the dense head so the
        # output is (batch, 10) as sparse categorical cross-entropy expects.
        keras.layers.Flatten(),
        keras.layers.Dense(units=64, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(units=10, activation='softmax'),
    ])

with distribution.scope():
    model = create_model()
    # Compile inside the strategy scope so the optimizer is also created
    # under the distribution strategy.
    model.compile(loss="sparse_categorical_crossentropy",
                  optimizer=keras.optimizers.SGD(learning_rate=1e-2),
                  metrics=["accuracy"])

# The dataset repeats forever, so steps_per_epoch bounds each epoch.
model.fit(dataset, steps_per_epoch=len(X_train)//batch_size, epochs=10)
# Hyperparameter tuning

# Only talk to ps server
config_proto = tf.ConfigProto(device_filters=['/job:ps', '/job:worker/task:%d' % tf_config['task']['index']])
config = tf.estimator.RunConfig(session_config=config_proto)
# default since 1.10