blob: b661a20430d3f0acc5267a4b671cd0954a0b8f2f [file] [log] [blame]
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST end-to-end training and inference example.
The source code here is from
https://www.tensorflow.org/xla/tutorials/jit_compile, where there is also a
walkthrough.
To execute in TFRT BEF, run with
`--config=cuda --test_env=XLA_FLAGS=--xla_gpu_bef_executable`
To execute in TFRT JitRt, run with
`--config=cuda --test_env=XLA_FLAGS=--xla_gpu_jitrt_executable`
To dump debug output (e.g., LMHLO MLIR, TFRT MLIR, TFRT BEF), run with
`--test_env=XLA_FLAGS="--xla_dump_to=/tmp/mnist"`.
"""
from absl import app
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct number labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000
def main(_):
# Loads MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(
TRAIN_BATCH_SIZE).repeat()
# Casting from raw data to the required datatypes.
def cast(images, labels):
images = tf.cast(
tf.reshape(images, [-1, IMAGE_SIZE]), tf.float32)
labels = tf.cast(labels, tf.int64)
return (images, labels)
layer = tf.keras.layers.Dense(NUM_CLASSES)
optimizer = tf.keras.optimizers.Adam()
@tf.function(jit_compile=True)
def train_mnist(images, labels):
images, labels = cast(images, labels)
with tf.GradientTape() as tape:
predicted_labels = layer(images)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=predicted_labels, labels=labels
))
layer_variables = layer.trainable_variables
grads = tape.gradient(loss, layer_variables)
optimizer.apply_gradients(zip(grads, layer_variables))
for images, labels in train_ds:
if optimizer.iterations > TRAIN_STEPS:
break
train_mnist(images, labels)
images, labels = cast(test[0], test[1])
predicted_labels = layer(images)
correct_prediction = tf.equal(tf.argmax(predicted_labels, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % accuracy)
if __name__ == "__main__":
app.run(main)