# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reversible residual network compatible with eager execution.
Code for main model.
Reference [The Reversible Residual Network: Backpropagation
Without Storing Activations](https://arxiv.org/pdf/1707.04585.pdf)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib.eager.python.examples.revnet import blocks


class RevNet(tf.keras.Model):
  """RevNet that depends on all the blocks."""

  def __init__(self, config):
    """Initialize RevNet with building blocks.

    Args:
      config: tf.contrib.training.HParams object; specifies hyperparameters
    """
    super(RevNet, self).__init__()
    self.axis = 1 if config.data_format == "channels_first" else 3
    self.config = config

    self._init_block = blocks.InitBlock(config=self.config)
    self._final_block = blocks.FinalBlock(config=self.config)
    self._block_list = self._construct_intermediate_blocks()
    self._moving_average_variables = []

  def _construct_intermediate_blocks(self):
    # Precompute input shape after initial block
    stride = self.config.init_stride
    if self.config.init_max_pool:
      stride *= 2

    if self.config.data_format == "channels_first":
      w, h = self.config.input_shape[1], self.config.input_shape[2]
      input_shape = (self.config.init_filters, w // stride, h // stride)
    else:
      w, h = self.config.input_shape[0], self.config.input_shape[1]
      input_shape = (w // stride, h // stride, self.config.init_filters)

    # Aggregate intermediate blocks
    block_list = tf.contrib.checkpoint.List()
    for i in range(self.config.n_rev_blocks):
      # RevBlock configurations
      n_res = self.config.n_res[i]
      filters = self.config.filters[i]
      if filters % 2 != 0:
        raise ValueError("Number of output filters must be even to ensure "
                         "correct partitioning of channels")
      stride = self.config.strides[i]
      strides = (self.config.strides[i], self.config.strides[i])

      # Add block
      rev_block = blocks.RevBlock(
          n_res,
          filters,
          strides,
          input_shape,
          batch_norm_first=(i != 0),  # Only skip on first block
          data_format=self.config.data_format,
          bottleneck=self.config.bottleneck,
          fused=self.config.fused,
          dtype=self.config.dtype)
      block_list.append(rev_block)

      # Precompute input shape for the next block
      if self.config.data_format == "channels_first":
        w, h = input_shape[1], input_shape[2]
        input_shape = (filters, w // stride, h // stride)
      else:
        w, h = input_shape[0], input_shape[1]
        input_shape = (w // stride, h // stride, filters)

    return block_list

  def call(self, inputs, training=True):
    """Forward pass."""
    saved_hidden = None
    if training:
      saved_hidden = [inputs]

    h = self._init_block(inputs, training=training)
    if training:
      saved_hidden.append(h)

    for block in self._block_list:
      h = block(h, training=training)
      if training:
        saved_hidden.append(h)

    logits = self._final_block(h, training=training)

    return (logits, saved_hidden) if training else (logits, None)

  def compute_loss(self, logits, labels):
    """Compute cross entropy loss."""
    if self.config.dtype == tf.float32 or self.config.dtype == tf.float16:
      cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=labels)
    else:
      # `sparse_softmax_cross_entropy_with_logits` does not have a GPU kernel
      # for float64, int32 pairs
      labels = tf.one_hot(
          labels, depth=self.config.n_classes, axis=1, dtype=self.config.dtype)
      cross_ent = tf.nn.softmax_cross_entropy_with_logits(
          logits=logits, labels=labels)

    return tf.reduce_mean(cross_ent)

  def compute_gradients(self, saved_hidden, labels, training=True,
                        l2_reg=True):
    """Manually computes gradients.

    This method silently updates the running averages of batch normalization.

    Args:
      saved_hidden: List of hidden state Tensors saved by the forward pass
      labels: Integer class labels for classification
      training: Use the mini-batch stats in batch norm if set to True
      l2_reg: Apply l2 regularization

    Returns:
      A tuple with the first entry being a list of all gradients and the
      second being the loss
    """

    def _defunable_pop(l):
      """Functional style list pop that works with `tfe.defun`."""
      t, l = l[-1], l[:-1]
      return t, l

    # Backprop through last block
    x = saved_hidden[-1]
    with tf.GradientTape() as tape:
      tape.watch(x)
      logits = self._final_block(x, training=training)
      loss = self.compute_loss(logits, labels)
    grads_combined = tape.gradient(loss,
                                   [x] + self._final_block.trainable_variables)
    dy, final_grads = grads_combined[0], grads_combined[1:]
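    # `dy` is now the gradient of the loss w.r.t. the last saved boundary
    # activation and is threaded backwards through the reversible blocks.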

    # Backprop through intermediate blocks
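    # Each block reconstructs its own input from its output inside
    # `backward_grads`, so no interior activations had to be stored.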
    intermediate_grads = []
    for block in reversed(self._block_list):
      y, saved_hidden = _defunable_pop(saved_hidden)
      x = saved_hidden[-1]
      dy, grads = block.backward_grads(x, y, dy, training=training)
      intermediate_grads = grads + intermediate_grads

    # Backprop through first block
    _, saved_hidden = _defunable_pop(saved_hidden)
    x, saved_hidden = _defunable_pop(saved_hidden)
    assert not saved_hidden
    with tf.GradientTape() as tape:
      y = self._init_block(x, training=training)
    init_grads = tape.gradient(
        y, self._init_block.trainable_variables, output_gradients=dy)

    # Ordering matches up with `model.trainable_variables`
    grads_all = init_grads + final_grads + intermediate_grads
    if l2_reg:
      grads_all = self._apply_weight_decay(grads_all)

    return grads_all, loss

  def _apply_weight_decay(self, grads):
    """Update gradients to reflect weight decay."""
    return [
        g + self.config.weight_decay * v if v.name.endswith("kernel:0") else g
        for g, v in zip(grads, self.trainable_variables)
    ]

  def get_moving_stats(self):
    """Get moving averages of batch normalization."""
device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
with tf.device(device):
return [v.read_value() for v in self.moving_average_variables]
def restore_moving_stats(self, values):
"""Restore moving averages of batch normalization."""
device = "/gpu:0" if tf.test.is_gpu_available() else "/cpu:0"
with tf.device(device):
for var_, val in zip(self.moving_average_variables, values):
var_.assign(val)

  @property
  def moving_average_variables(self):
    """Get all variables that are batch norm moving averages."""

    def _is_moving_avg(v):
      n = v.name
      return n.endswith("moving_mean:0") or n.endswith("moving_variance:0")

    if not self._moving_average_variables:
      # Materialize as a list: a bare `filter` iterator would be exhausted
      # after one use and defeat the caching in this attribute.
      self._moving_average_variables = list(
          filter(_is_moving_avg, self.variables))

    return self._moving_average_variables
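

if __name__ == "__main__":
  # Minimal usage sketch, not a supported entry point: it assumes eager
  # execution and borrows `get_hparams_cifar_38()` from the sibling `config`
  # module; any HParams object exposing the fields read above would work.
  # A GPU may be required if the config selects "channels_first".
  from tensorflow.contrib.eager.python.examples.revnet import config as config_

  tf.enable_eager_execution()
  config = config_.get_hparams_cifar_38()
  model = RevNet(config)

  x = tf.random_normal((2,) + tuple(config.input_shape))
  y = tf.random_uniform((2,), maxval=config.n_classes, dtype=tf.int32)

  # Forward pass records the boundary activations that the manual backward
  # pass consumes.
  _, saved_hidden = model(x, training=True)
  grads, loss = model.compute_gradients(saved_hidden, y, training=True)

  optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  print("loss:", loss.numpy())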