
This layer shifts and scales inputs into a distribution centered around 0 with standard deviation 1. It does so by precomputing the mean and variance of the data and evaluating (input - mean) / sqrt(var) at runtime.

The mean and variance values for the layer must be either supplied on construction or learned via adapt(). adapt() will compute the mean and variance of the data and store them as the layer's weights. adapt() should be called before fit(), evaluate(), or predict().
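The arithmetic described above can be sketched in base R, without any Keras objects. This is only an illustration of the formula, not the layer's implementation; the variable names are made up for the sketch, and the use of population variance (divisor n) is an assumption inferred from the example outputs on this page.

```r
# Base R sketch of the arithmetic only; no Keras objects are involved.
adapt_data <- c(1, 2, 3, 4, 5)
input_data <- c(1, 2, 3)

# Statistics analogous to what adapt() stores.
# Assumption: population variance (divisor n), not var()'s n - 1.
data_mean <- mean(adapt_data)
data_var  <- mean((adapt_data - data_mean)^2)

# Runtime transformation: (input - mean) / sqrt(var)
normalized <- (input_data - data_mean) / sqrt(data_var)
print(normalized)
```

With these inputs the mean is 3 and the variance is 2, so the printed values agree with the output of the first example under Examples, up to float32 rounding.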

Usage

layer_normalization(
  object,
  axis = -1L,
  mean = NULL,
  variance = NULL,
  invert = FALSE,
  ...
)

Arguments

object

Object to compose the layer with. A tensor, array, or sequential model.

axis

Integer, list of integers, or NULL. The axis or axes that should have a separate mean and variance for each index in the shape. For example, if shape is (NULL, 5) and axis=1, the layer will track 5 separate mean and variance values for the last axis. If axis is set to NULL, the layer will normalize all elements in the input by a scalar mean and variance. When -1, the last axis of the input is assumed to be a feature dimension and is normalized per index. Note that in the specific case of batched scalar inputs where the only axis is the batch axis, the default will normalize each index in the batch separately. In this case, consider passing axis=NULL. Defaults to -1.

mean

The mean value(s) to use during normalization. The passed value(s) will be broadcast to the shape of the kept axes above; if the value(s) cannot be broadcast, an error will be raised when this layer's build() method is called.

variance

The variance value(s) to use during normalization. The passed value(s) will be broadcast to the shape of the kept axes above; if the value(s) cannot be broadcast, an error will be raised when this layer's build() method is called.

invert

If TRUE, this layer applies the inverse transformation to its inputs, turning a normalized input back into its original form.

...

For forward/backward compatibility.

Value

The return value depends on the value provided for the first argument. If object is:

  • a keras_model_sequential(), then the layer is added to the sequential model (which is modified in place). To enable piping, the sequential model is also returned, invisibly.

  • a keras_input(), then the output tensor from calling layer(input) is returned.

  • NULL or missing, then a Layer instance is returned.

Examples

Calculate a global mean and variance by analyzing the dataset in adapt().

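library(keras3)

# Data used to compute the statistics, and data to normalize
adapt_data <- op_array(c(1., 2., 3., 4., 5.), dtype = 'float32')
input_data <- op_array(c(1., 2., 3.), dtype = 'float32')

# axis = NULL: a single scalar mean and variance for all elements
layer <- layer_normalization(axis = NULL)
layer %>% adapt(adapt_data)
layer(input_data)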

## tf.Tensor([-1.4142135  -0.70710677  0.        ], shape=(3), dtype=float32)

Calculate a mean and variance for each index on the last axis.

adapt_data <- op_array(rbind(c(0., 7., 4.),
                       c(2., 9., 6.),
                       c(0., 7., 4.),
                       c(2., 9., 6.)), dtype='float32')
input_data <- op_array(matrix(c(0., 7., 4.), nrow = 1), dtype='float32')
layer <- layer_normalization(axis=-1)
layer %>% adapt(adapt_data)
layer(input_data)

## tf.Tensor([[-1. -1. -1.]], shape=(1, 3), dtype=float32)
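To see where the values in the per-index example above come from, the same statistics can be computed by hand in base R. This is a sketch of the arithmetic only, assuming population variance per column; col_mean and col_var are illustrative names, not part of the keras3 API.

```r
# Base R sketch: one mean and variance per column (the last axis).
adapt_data <- rbind(c(0, 7, 4),
                    c(2, 9, 6),
                    c(0, 7, 4),
                    c(2, 9, 6))
input_data <- matrix(c(0, 7, 4), nrow = 1)

# Per-column statistics (population variance, divisor n).
col_mean <- colMeans(adapt_data)                        # 1 8 5
col_var  <- colMeans(sweep(adapt_data, 2, col_mean)^2)  # 1 1 1

# Subtract each column's mean, then divide by its standard deviation.
centered   <- sweep(input_data, 2, col_mean, "-")
normalized <- sweep(centered, 2, sqrt(col_var), "/")
print(normalized)
```

Each column is handled independently, which is why the input row (0, 7, 4) maps to -1 in every position.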

Pass the mean and variance directly.

input_data <- op_array(rbind(1, 2, 3), dtype='float32')
layer <- layer_normalization(mean=3., variance=2.)
layer(input_data)

## tf.Tensor(
## [[-1.4142135 ]
##  [-0.70710677]
##  [ 0.        ]], shape=(3, 1), dtype=float32)

Use the layer to de-normalize inputs (after adapting the layer).

adapt_data <- op_array(rbind(c(0., 7., 4.),
                       c(2., 9., 6.),
                       c(0., 7., 4.),
                       c(2., 9., 6.)), dtype='float32')
input_data <- op_array(c(1., 2., 3.), dtype='float32')
layer <- layer_normalization(axis=-1, invert=TRUE)
layer %>% adapt(adapt_data)
layer(input_data)

## tf.Tensor([[ 2. 10.  8.]], shape=(1, 3), dtype=float32)
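The de-normalized values above are consistent with the algebraic inverse of the forward formula, input * sqrt(var) + mean. The base R sketch below checks that arithmetic and the round trip; it assumes population variance and uses illustrative variable names, and it does not reproduce the layer's internals.

```r
# Base R sketch of the inverse transformation and a round trip.
adapt_data <- rbind(c(0, 7, 4),
                    c(2, 9, 6),
                    c(0, 7, 4),
                    c(2, 9, 6))
input_data <- c(1, 2, 3)

# Per-column statistics (population variance, divisor n).
col_mean <- colMeans(adapt_data)                        # 1 8 5
col_var  <- colMeans(sweep(adapt_data, 2, col_mean)^2)  # 1 1 1

# Inverse of (input - mean) / sqrt(var)
denormalized <- input_data * sqrt(col_var) + col_mean   # 2 10 8

# Applying the forward formula again recovers the original input.
renormalized <- (denormalized - col_mean) / sqrt(col_var)
print(denormalized)
print(renormalized)
```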

See also

Other numerical features preprocessing layers:
layer_discretization()

Other preprocessing layers:
layer_category_encoding()
layer_center_crop()
layer_discretization()
layer_feature_space()
layer_hashed_crossing()
layer_hashing()
layer_integer_lookup()
layer_random_brightness()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_rotation()
layer_random_translation()
layer_random_zoom()
layer_rescaling()
layer_resizing()
layer_string_lookup()
layer_text_vectorization()

Other layers:
Layer()
layer_activation()
layer_activation_elu()
layer_activation_leaky_relu()
layer_activation_parametric_relu()
layer_activation_relu()
layer_activation_softmax()
layer_activity_regularization()
layer_add()
layer_additive_attention()
layer_alpha_dropout()
layer_attention()
layer_average()
layer_average_pooling_1d()
layer_average_pooling_2d()
layer_average_pooling_3d()
layer_batch_normalization()
layer_bidirectional()
layer_category_encoding()
layer_center_crop()
layer_concatenate()
layer_conv_1d()
layer_conv_1d_transpose()
layer_conv_2d()
layer_conv_2d_transpose()
layer_conv_3d()
layer_conv_3d_transpose()
layer_conv_lstm_1d()
layer_conv_lstm_2d()
layer_conv_lstm_3d()
layer_cropping_1d()
layer_cropping_2d()
layer_cropping_3d()
layer_dense()
layer_depthwise_conv_1d()
layer_depthwise_conv_2d()
layer_discretization()
layer_dot()
layer_dropout()
layer_einsum_dense()
layer_embedding()
layer_feature_space()
layer_flatten()
layer_gaussian_dropout()
layer_gaussian_noise()
layer_global_average_pooling_1d()
layer_global_average_pooling_2d()
layer_global_average_pooling_3d()
layer_global_max_pooling_1d()
layer_global_max_pooling_2d()
layer_global_max_pooling_3d()
layer_group_normalization()
layer_group_query_attention()
layer_gru()
layer_hashed_crossing()
layer_hashing()
layer_identity()
layer_integer_lookup()
layer_lambda()
layer_layer_normalization()
layer_lstm()
layer_masking()
layer_max_pooling_1d()
layer_max_pooling_2d()
layer_max_pooling_3d()
layer_maximum()
layer_minimum()
layer_multi_head_attention()
layer_multiply()
layer_permute()
layer_random_brightness()
layer_random_contrast()
layer_random_crop()
layer_random_flip()
layer_random_rotation()
layer_random_translation()
layer_random_zoom()
layer_repeat_vector()
layer_rescaling()
layer_reshape()
layer_resizing()
layer_rnn()
layer_separable_conv_1d()
layer_separable_conv_2d()
layer_simple_rnn()
layer_spatial_dropout_1d()
layer_spatial_dropout_2d()
layer_spatial_dropout_3d()
layer_spectral_normalization()
layer_string_lookup()
layer_subtract()
layer_text_vectorization()
layer_tfsm()
layer_time_distributed()
layer_torch_module_wrapper()
layer_unit_normalization()
layer_upsampling_1d()
layer_upsampling_2d()
layer_upsampling_3d()
layer_zero_padding_1d()
layer_zero_padding_2d()
layer_zero_padding_3d()
rnn_cell_gru()
rnn_cell_lstm()
rnn_cell_simple()
rnn_cells_stack()