# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras preprocessing layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import keras_export
_COUNT_NAME = 'count'
_MEAN_NAME = 'mean'
_VARIANCE_NAME = 'variance'
# TODO(momernick): Find a good example of normalization?
@keras_export('keras.layers.experimental.preprocessing.Normalization', v1=[])
class Normalization(CombinerPreprocessingLayer):
"""Feature-wise normalization of the data.
  This layer will coerce its inputs into a distribution centered around
  0 with standard deviation 1. It accomplishes this by precomputing the mean
  and variance of the data, and computing `(input - mean) / sqrt(var)` at
  runtime.

  What happens in `adapt`: the mean and variance of the data are computed and
  stored as the layer's weights. `adapt` should be called before `fit`,
  `evaluate`, or `predict`.
Attributes:
    axis: Integer or tuple of integers, the axis or axes that should be
      normalized (typically the features axis). We will normalize each element
      in the specified axis. If set to `None`, the layer will perform scalar
      normalization (dividing the input by a single scalar value). 0 (the
      batch axis) is not allowed.
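
  Example (a minimal usage sketch; the data and numbers are illustrative):

  ```python
  import numpy as np

  data = np.array([[0.], [2.], [4.]], dtype='float32')
  layer = Normalization(axis=-1)
  layer.adapt(data)      # Learns mean = 2.0 and variance = 8/3 from the data.
  outputs = layer(data)  # (data - mean) / sqrt(variance) ~= [-1.22, 0., 1.22]
  ```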
"""
def __init__(self, axis=-1, dtype=None, **kwargs):
# This ensures that if the value of K.floatx() changes after file-loading
# time, the dtype value will change to reflect it.
dtype = dtype or K.floatx()
super(Normalization, self).__init__(
combiner=_NormalizingCombiner(axis), dtype=dtype, **kwargs)
if axis == 0:
raise ValueError('The argument \'axis\' may not be 0.')
self.axis = axis
def build(self, input_shape):
self._broadcast_shape = [1 for _ in range(len(input_shape))]
if isinstance(self.axis, (tuple, list)):
mean_and_var_shape = []
for i in self.axis:
mean_and_var_shape.append(input_shape[i])
self._broadcast_shape[i] = input_shape[i]
else:
if self.axis is None:
mean_and_var_shape = ()
else:
mean_and_var_shape = input_shape[self.axis]
self._broadcast_shape[self.axis] = input_shape[self.axis]
# count is not used in this class's call() method, but is used to re-create
# the accumulator during multiple calls to 'adapt'.
# TODO(omalleyt): should mean and variance be set to self.dtype?
self.mean = self._add_state_variable(
name=_MEAN_NAME,
shape=mean_and_var_shape,
dtype=K.floatx(),
initializer=init_ops.zeros_initializer)
self.variance = self._add_state_variable(
name=_VARIANCE_NAME,
shape=mean_and_var_shape,
dtype=K.floatx(),
initializer=init_ops.ones_initializer)
self.count = self._add_state_variable(
name=_COUNT_NAME,
shape=(),
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer)
super(Normalization, self).build(input_shape)
def call(self, inputs):
# If the inputs are not floats, cast them to floats. This avoids issues
# with int-float multiplication and division below.
if inputs.dtype != K.floatx():
inputs = math_ops.cast(inputs, K.floatx())
# We need to reshape the mean and variance data to ensure that Tensorflow
# broadcasts the data correctly.
mean = array_ops.reshape(self.mean, self._broadcast_shape)
variance = array_ops.reshape(self.variance, self._broadcast_shape)
return (inputs - mean) / math_ops.sqrt(variance)
def compute_output_shape(self, input_shape):
return input_shape
def compute_output_signature(self, input_spec):
return input_spec
def get_config(self):
config = {'axis': self.axis}
base_config = super(Normalization, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def set_weights(self, weights):
"""Override for set_weights to ensure we can set just mean/var weights."""
if len(weights) == 2:
weights.append(np.array(0))
super(Normalization, self).set_weights(weights)
class _NormalizingCombiner(Combiner):
"""Combiner for the Normalization preprocessing layer.
  This class encapsulates the computations for finding the mean and variance
  of a set of data in a stable and numerically correct way. Its associated
  accumulator is a list of the form `[count, mean, variance]`.
Attributes:
axis: The axis to compute mean and var over.
"""
COUNT_IDX = 0
MEAN_IDX = 1
VAR_IDX = 2
def __init__(self, axis):
self.axis = axis
def compute(self, values, accumulator=None):
"""Compute a step in this computation, returning a new accumulator."""
# This is the shape of all reduced axes (not specified in 'axis').
if self.axis is None:
reduction_counts = values.shape
else:
reduction_counts = np.delete(values.shape, self.axis)
# We get the number of elements that will be reduced by multiplying all
# values of 'shape' corresponding to the reduced axes.
count = np.prod(reduction_counts, dtype=np.int32)
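    # For example, for `values.shape == (batch_size, num_features)` and
    # `axis == -1`, `count` is `batch_size`: the number of elements reduced
    # over for each feature.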
# We want to reduce across dimensions except those specified in 'axis'
# when using np.mean or np.variance; create the tuple of axes to reduce
# over here.
if self.axis is None:
reduction_axes = None
else:
reduction_axes = tuple(np.delete(range(values.ndim), self.axis))
mean = np.mean(values, axis=reduction_axes, dtype=np.float64)
variance = np.var(values, axis=reduction_axes, dtype=np.float64)
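    # Continuing the example above, `reduction_axes == (0,)`, so `mean` and
    # `variance` each hold one value per feature.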
# Create an accumulator with our new data and either return it or combine
# it with the passed accumulator.
if accumulator is None:
return self._create_accumulator(count, mean, variance)
else:
return self.add_data_to_accumulator(count, mean, variance, accumulator)
def add_data_to_accumulator(self, count, mean, variance, accumulator):
"""Add new data to the totals in an accumulator."""
# Combine accumulators and return the result.
combined_count = count + accumulator[self.COUNT_IDX]
# To combine accumulator means, we weight each accumulator's mean by the
# number of elements that were accumulated, and then divide by the
# total number of elements.
combined_mean = (mean * count + accumulator[self.MEAN_IDX] *
accumulator[self.COUNT_IDX]) / combined_count
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
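    # Concretely, writing n_acc/n_new for the accumulated and incoming element
    # counts:
    #   combined_variance = (n_acc * (var_acc + (mean_acc - combined_mean)**2) +
    #                        n_new * (var_new + (mean_new - combined_mean)**2))
    #                       / (n_acc + n_new)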
accumulator_var_contribution = accumulator[self.COUNT_IDX] * (
accumulator[self.VAR_IDX] +
np.square(accumulator[self.MEAN_IDX] - combined_mean))
data_var_contribution = count * (variance + np.square(mean - combined_mean))
combined_variance = (accumulator_var_contribution +
data_var_contribution) / combined_count
accumulator[self.COUNT_IDX] = combined_count
accumulator[self.MEAN_IDX] = np.nan_to_num(combined_mean)
accumulator[self.VAR_IDX] = np.nan_to_num(combined_variance)
return accumulator
def merge(self, accumulators):
"""Merge several accumulators to a single accumulator."""
# Combine accumulators and return the result.
combined_count = np.sum(
[accumulator[self.COUNT_IDX] for accumulator in accumulators])
# To combine accumulator means, we weight each accumulator's mean by the
# number of elements that were accumulated, and then divide by the
# total number of elements.
combined_mean = np.add.reduce([
accumulator[self.MEAN_IDX] * accumulator[self.COUNT_IDX]
for accumulator in accumulators
]) / combined_count
# The variance is computed using the lack-of-fit sum of squares
# formula (see https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).
def variance_contribution(accumulator):
return accumulator[self.COUNT_IDX] * (
accumulator[self.VAR_IDX] +
np.square(accumulator[self.MEAN_IDX] - combined_mean))
combined_variance = np.add.reduce([
variance_contribution(accumulator) for accumulator in accumulators
]) / combined_count
return self._create_accumulator(combined_count, combined_mean,
combined_variance)
def extract(self, accumulator):
"""Convert an accumulator into a dict of output values."""
    return {
        _COUNT_NAME: accumulator[self.COUNT_IDX],
        _MEAN_NAME: accumulator[self.MEAN_IDX],
        _VARIANCE_NAME: accumulator[self.VAR_IDX]
    }
def restore(self, output):
"""Create an accumulator based on 'output'."""
# There is no special internal state here, so we just return the relevant
# internal value.
count = output[_COUNT_NAME]
mean = output[_MEAN_NAME]
var = output[_VARIANCE_NAME]
    if count == 0 and (np.any(mean != 0.0) or np.any(var != 0.0)):
raise RuntimeError(
'The mean and/or variance of a Normalization preprocessing layer '
"were set without also setting 'count'. If 'count' is not also set,"
" 'adapt' cannot be called unless the 'reset_state' arg is True.")
return self._create_accumulator(output[_COUNT_NAME], output[_MEAN_NAME],
output[_VARIANCE_NAME])
def serialize(self, accumulator):
"""Serialize an accumulator for a remote call."""
    output_dict = {
        _COUNT_NAME: accumulator[self.COUNT_IDX].tolist(),
        _MEAN_NAME: accumulator[self.MEAN_IDX].tolist(),
        _VARIANCE_NAME: accumulator[self.VAR_IDX].tolist()
    }
return compat.as_bytes(json.dumps(output_dict))
def deserialize(self, encoded_accumulator):
"""Deserialize an accumulator received from 'serialize()'."""
value_dict = json.loads(compat.as_text(encoded_accumulator))
return self._create_accumulator(
np.array(value_dict[_COUNT_NAME]), np.array(value_dict[_MEAN_NAME]),
np.array(value_dict[_VARIANCE_NAME]))
def _create_accumulator(self, count, mean, variance):
"""Convert any 'nan' values in the given accumulator to numeric values."""
return [count, mean, variance]