
Models

Create Models

keras_model_sequential()
Keras Model composed of a linear stack of layers
keras_model()
Keras Model (Functional API)
keras_input()
Create a Keras tensor (Functional API input).
clone_model()
Clone a Functional or Sequential Model instance.
Model()
Subclass the base Keras Model Class

Train Models

compile(<keras.src.models.model.Model>)
Configure a model for training.
fit(<keras.src.models.model.Model>)
Train a model for a fixed number of epochs (dataset iterations).
plot(<keras_training_history>)
Plot training history
predict(<keras.src.models.model.Model>)
Generates output predictions for the input samples.
evaluate(<keras.src.models.model.Model>)
Evaluate a Keras Model
train_on_batch()
Runs a single gradient update on a single batch of data.
predict_on_batch()
Returns predictions for a single batch of samples.
test_on_batch()
Test the model on a single batch of samples.
freeze_weights() unfreeze_weights()
Freeze and unfreeze weights

Inspect and Modify Models

summary(<keras.src.models.model.Model>) format(<keras.src.models.model.Model>) print(<keras.src.models.model.Model>)
Print a summary of a Keras Model
plot(<keras.src.models.model.Model>)
Plot a Keras model
get_config() from_config()
Layer/Model configuration
get_weights() set_weights()
Layer/Model weights as R arrays
get_layer()
Retrieves a layer based on either its name (unique) or index.
count_params()
Count the total number of scalars composing the weights.
pop_layer()
Remove the last layer in a Sequential model
quantize_weights()
Quantize the weights of a model.

Save and Load Models

save_model()
Saves a model as a .keras file.
load_model()
Loads a model saved via save_model().
save_model_weights()
Saves all layer weights to a .weights.h5 file.
load_model_weights()
Load weights from a file saved via save_model_weights().
save_model_config() load_model_config()
Save and load model configuration as JSON
export_savedmodel(<keras.src.models.model.Model>)
Create a TF SavedModel artifact for inference (e.g. via TF-Serving).
layer_tfsm()
Reload a Keras model/layer that was saved via export_savedmodel().
register_keras_serializable()
Registers a custom object with the Keras serialization framework.

Layers

Core Layers

layer_dense()
Just your regular densely-connected NN layer.
layer_einsum_dense()
A layer that uses einsum as the backing computation.
layer_embedding()
Turns positive integers (indexes) into dense vectors of fixed size.
layer_identity()
Identity layer.
layer_lambda()
Wraps arbitrary expressions as a Layer object.
layer_masking()
Masks a sequence by using a mask value to skip timesteps.
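As a rough illustration of what layer_dense() computes, here is a base-R sketch of the forward pass activation(x %*% kernel + bias) for a small batch. It does not call keras3. The helper dense_forward and the weights are hypothetical, and ReLU is an arbitrary choice of activation.

```r
# Hypothetical base-R sketch of a dense layer's forward pass (not the keras3 API).
# x: batch of inputs (rows = samples); kernel: input_dim x units; bias: length units.
dense_forward <- function(x, kernel, bias, activation = function(z) pmax(z, 0)) {
  z <- x %*% kernel              # linear projection
  z <- sweep(z, 2, bias, `+`)    # add the bias to every row
  activation(z)                  # element-wise activation (ReLU by default here)
}

x <- matrix(c(1, 2,
              -1, 0.5), nrow = 2, byrow = TRUE)
kernel <- matrix(c(1, -1,
                   0, 2), nrow = 2, byrow = TRUE)
bias <- c(0.5, -1)
out <- dense_forward(x, kernel, bias)
```

In the real layer the kernel and bias are trainable weights created from the units argument; here they are fixed by hand so the arithmetic is easy to follow.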

Reshaping Layers

layer_cropping_1d()
Cropping layer for 1D input (e.g. temporal sequence).
layer_cropping_2d()
Cropping layer for 2D input (e.g. picture).
layer_cropping_3d()
Cropping layer for 3D data (e.g. spatial or spatio-temporal).
layer_flatten()
Flattens the input. Does not affect the batch size.
layer_permute()
Permutes the dimensions of the input according to a given pattern.
layer_repeat_vector()
Repeats the input n times.
layer_reshape()
Layer that reshapes inputs into the given shape.
layer_upsampling_1d()
Upsampling layer for 1D inputs.
layer_upsampling_2d()
Upsampling layer for 2D inputs.
layer_upsampling_3d()
Upsampling layer for 3D inputs.
layer_zero_padding_1d()
Zero-padding layer for 1D input (e.g. temporal sequence).
layer_zero_padding_2d()
Zero-padding layer for 2D input (e.g. picture).
layer_zero_padding_3d()
Zero-padding layer for 3D data (spatial or spatio-temporal).
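Of the reshaping layers, layer_upsampling_1d() has especially simple semantics: each timestep is repeated size times along the time axis. The base-R sketch below shows this for a single (timesteps x features) sequence. upsample_1d is a hypothetical helper, not part of keras3, and it ignores the batch dimension.

```r
# Hypothetical base-R sketch: repeat each timestep (row) `size` times.
upsample_1d <- function(x, size = 2) {
  x[rep(seq_len(nrow(x)), each = size), , drop = FALSE]
}

x <- matrix(1:6, nrow = 3)   # 3 timesteps, 2 features
up <- upsample_1d(x, size = 2)
```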

Convolutional Layers

layer_conv_1d()
1D convolution layer (e.g. temporal convolution).
layer_conv_1d_transpose()
1D transposed convolution layer.
layer_conv_2d()
2D convolution layer.
layer_conv_2d_transpose()
2D transposed convolution layer.
layer_conv_3d()
3D convolution layer.
layer_conv_3d_transpose()
3D transposed convolution layer.
layer_depthwise_conv_1d()
1D depthwise convolution layer.
layer_depthwise_conv_2d()
2D depthwise convolution layer.
layer_separable_conv_1d()
1D separable convolution layer.
layer_separable_conv_2d()
2D separable convolution layer.
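Keras convolution layers compute a cross-correlation, meaning the kernel is slid over the input without being flipped. The base-R sketch below computes a single-channel 1D convolution with stride 1 and "valid" padding, which are the layer_conv_1d() defaults, but with one filter and no bias. conv1d_valid is a hypothetical helper.

```r
# Hypothetical base-R sketch of a 1D "valid" convolution (cross-correlation, stride 1).
conv1d_valid <- function(x, kernel) {
  k <- length(kernel)
  n_out <- length(x) - k + 1                 # "valid": no padding, output shrinks
  vapply(seq_len(n_out),
         function(i) sum(x[i:(i + k - 1)] * kernel),  # dot product with each window
         numeric(1))
}

y <- conv1d_valid(c(1, 2, 3, 4), c(1, 0, -1))
```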

Pooling Layers

layer_average_pooling_1d()
Average pooling for temporal data.
layer_average_pooling_2d()
Average pooling operation for 2D spatial data.
layer_average_pooling_3d()
Average pooling operation for 3D data (spatial or spatio-temporal).
layer_global_average_pooling_1d()
Global average pooling operation for temporal data.
layer_global_average_pooling_2d()
Global average pooling operation for 2D data.
layer_global_average_pooling_3d()
Global average pooling operation for 3D data.
layer_global_max_pooling_1d()
Global max pooling operation for temporal data.
layer_global_max_pooling_2d()
Global max pooling operation for 2D data.
layer_global_max_pooling_3d()
Global max pooling operation for 3D data.
layer_max_pooling_1d()
Max pooling operation for 1D temporal data.
layer_max_pooling_2d()
Max pooling operation for 2D spatial data.
layer_max_pooling_3d()
Max pooling operation for 3D data (spatial or spatio-temporal).
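With its defaults (strides equal to pool_size, "valid" padding), layer_max_pooling_1d() takes the maximum of each non-overlapping window and drops any trailing partial window. The base-R sketch below shows this for a single channel. max_pool_1d is a hypothetical helper, not a keras3 function.

```r
# Hypothetical base-R sketch of 1D max pooling with "valid" padding.
max_pool_1d <- function(x, pool_size = 2, strides = pool_size) {
  if (length(x) < pool_size) return(numeric(0))   # no complete window
  starts <- seq(1, length(x) - pool_size + 1, by = strides)
  vapply(starts, function(s) max(x[s:(s + pool_size - 1)]), numeric(1))
}

p <- max_pool_1d(c(1, 3, 2, 5, 4, 0, 7))   # the trailing 7 has no full window
```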

Activation Layers

layer_activation()
Applies an activation function to an output.
layer_activation_elu()
Applies an Exponential Linear Unit function to an output.
layer_activation_leaky_relu()
Leaky version of a Rectified Linear Unit activation layer.
layer_activation_parametric_relu()
Parametric Rectified Linear Unit activation layer.
layer_activation_relu()
Rectified Linear Unit activation function layer.
layer_activation_softmax()
Softmax activation layer.

Recurrent Layers

layer_bidirectional()
Bidirectional wrapper for RNNs.
layer_conv_lstm_1d()
1D Convolutional LSTM.
layer_conv_lstm_2d()
2D Convolutional LSTM.
layer_conv_lstm_3d()
3D Convolutional LSTM.
layer_gru()
Gated Recurrent Unit - Cho et al. 2014.
layer_lstm()
Long Short-Term Memory layer - Hochreiter 1997.
layer_rnn()
Base class for recurrent layers
layer_simple_rnn()
Fully-connected RNN where the output is to be fed back as the new input.
layer_time_distributed()
Wrapper that applies a layer to every temporal slice of an input.
rnn_cell_gru()
Cell class for the GRU layer.
rnn_cell_lstm()
Cell class for the LSTM layer.
rnn_cell_simple()
Cell class for SimpleRNN.
rnn_cells_stack()
Wrapper allowing a stack of RNN cells to behave as a single cell.
reset_state()
Reset the state for a model, layer or metric.

Attention Layers

layer_additive_attention()
Additive attention layer, a.k.a. Bahdanau-style attention.
layer_attention()
Dot-product attention layer, a.k.a. Luong-style attention.
layer_group_query_attention()
Grouped Query Attention layer.
layer_multi_head_attention()
Multi Head Attention layer.
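layer_attention() implements dot-product (Luong-style) attention. Scores are query-key dot products, a softmax over the keys turns them into weights, and the output is the weighted sum of the values. The base-R sketch below does this for one example. It omits batching, masking, and the optional learned scale (use_scale), and dot_product_attention is a hypothetical helper.

```r
# Hypothetical base-R sketch of unscaled dot-product attention for one example.
softmax_rows <- function(s) {
  e <- exp(s - apply(s, 1, max))   # subtract each row's max for numerical stability
  e / rowSums(e)                   # normalise so every row sums to 1
}

dot_product_attention <- function(query, key, value) {
  scores <- query %*% t(key)       # (Tq x Tv) query-key similarity scores
  weights <- softmax_rows(scores)  # attention distribution over the Tv keys
  weights %*% value                # (Tq x dim) weighted sum of the values
}

query <- matrix(c(1, 0), nrow = 1)
key   <- matrix(c(1, 0,
                  1, 0), nrow = 2, byrow = TRUE)  # identical keys give equal weights
value <- matrix(c(1, 2,
                  3, 4), nrow = 2, byrow = TRUE)
out <- dot_product_attention(query, key, value)
```

Because both keys match the query equally, the output is simply the mean of the two value rows.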

Normalization Layers

layer_batch_normalization()
Layer that normalizes its inputs.
layer_group_normalization()
Group normalization layer.
layer_layer_normalization()
Layer normalization layer (Ba et al., 2016).
layer_spectral_normalization()
Performs spectral normalization on the weights of a target layer.
layer_unit_normalization()
Unit normalization layer.
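Unlike layer_batch_normalization(), which normalizes across the batch, layer_layer_normalization() normalizes each sample across its own features and then applies a learned scale (gamma) and offset (beta). The base-R sketch below fixes gamma = 1 and beta = 0 and assumes epsilon = 1e-3. layer_norm is a hypothetical helper.

```r
# Hypothetical base-R sketch of layer normalization over the feature axis (rows = samples).
layer_norm <- function(x, epsilon = 1e-3, gamma = 1, beta = 0) {
  mu <- rowMeans(x)                     # per-sample mean
  v <- rowMeans((x - mu)^2)             # per-sample population (biased) variance
  (x - mu) / sqrt(v + epsilon) * gamma + beta
}

x <- matrix(c(1, 2, 3,
              10, 20, 30), nrow = 2, byrow = TRUE)
y <- layer_norm(x)
```

Both rows end up with mean zero and roughly unit variance even though the second row is ten times larger, which is the point of normalizing per sample.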

Regularization Layers

layer_activity_regularization()
Layer that applies an update to the cost function based on input activity.
layer_alpha_dropout()
Applies Alpha Dropout to the input.
layer_dropout()
Applies dropout to the input.
layer_gaussian_dropout()
Apply multiplicative 1-centered Gaussian noise.
layer_gaussian_noise()
Apply additive zero-centered Gaussian noise.
layer_spatial_dropout_1d()
Spatial 1D version of Dropout.
layer_spatial_dropout_2d()
Spatial 2D version of Dropout.
layer_spatial_dropout_3d()
Spatial 3D version of Dropout.
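layer_dropout() uses "inverted" dropout. During training, each input is zeroed with probability rate and the surviving inputs are scaled by 1 / (1 - rate), so the expected value is unchanged. At inference the layer passes inputs through unchanged. The base-R sketch below shows the training-time behaviour. dropout_train is a hypothetical helper, and its random mask does not reproduce the backend's RNG.

```r
# Hypothetical base-R sketch of training-time inverted dropout.
dropout_train <- function(x, rate = 0.5) {
  keep <- runif(length(x)) >= rate   # keep each element with probability 1 - rate
  ifelse(keep, x / (1 - rate), 0)    # scale survivors so E[output] equals x
}

set.seed(1)
x <- rep(1, 10000)
y <- dropout_train(x, rate = 0.2)
```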

Merging Layers

layer_add()
Performs elementwise addition operation.
layer_average()
Averages a list of inputs element-wise.
layer_concatenate()
Concatenates a list of inputs.
layer_dot()
Computes element-wise dot product of two tensors.
layer_maximum()
Computes element-wise maximum on a list of inputs.
layer_minimum()
Computes elementwise minimum on a list of inputs.
layer_multiply()
Performs elementwise multiplication.
layer_subtract()
Performs elementwise subtraction.

Preprocessing Layers

layer_category_encoding()
A preprocessing layer which encodes integer features.
layer_center_crop()
A preprocessing layer which crops images.
layer_discretization()
A preprocessing layer which buckets continuous features by ranges.
layer_feature_space() feature_cross() feature_custom() feature_float() feature_float_rescaled() feature_float_normalized() feature_float_discretized() feature_integer_categorical() feature_string_categorical() feature_string_hashed() feature_integer_hashed()
One-stop utility for preprocessing and encoding structured data.
layer_hashed_crossing()
A preprocessing layer which crosses features using the "hashing trick".
layer_hashing()
A preprocessing layer which hashes and bins categorical features.
layer_integer_lookup()
A preprocessing layer that maps integers to (possibly encoded) indices.
layer_mel_spectrogram()
A preprocessing layer to convert raw audio signals to Mel spectrograms.
layer_normalization()
A preprocessing layer that normalizes continuous features.
layer_random_brightness()
A preprocessing layer which randomly adjusts brightness during training.
layer_random_contrast()
A preprocessing layer which randomly adjusts contrast during training.
layer_random_crop()
A preprocessing layer which randomly crops images during training.
layer_random_flip()
A preprocessing layer which randomly flips images during training.
layer_random_rotation()
A preprocessing layer which randomly rotates images during training.
layer_random_translation()
A preprocessing layer which randomly translates images during training.
layer_random_zoom()
A preprocessing layer which randomly zooms images during training.
layer_rescaling()
A preprocessing layer which rescales input values to a new range.
layer_resizing()
A preprocessing layer which resizes images.
layer_string_lookup()
A preprocessing layer that maps strings to (possibly encoded) indices.
layer_text_vectorization() get_vocabulary() set_vocabulary()
A preprocessing layer which maps text features to integer sequences.
adapt()
Fits the state of the preprocessing layer to the data being passed
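When layer_discretization() is given bin_boundaries, it maps each value to the 0-based index of its bucket, and a value equal to a boundary goes into the upper bucket. Base R's findInterval() appears to follow the same convention. The sketch below uses it to show the bucketing without calling keras3. The boundaries and inputs are illustrative.

```r
# Base-R sketch of bucketing by fixed boundaries (not a call into keras3).
bin_boundaries <- c(0, 1, 2)
x <- c(-1.5, 1.0, 3.4, 0.5, 0.0, 3.0, 1.3)
buckets <- findInterval(x, bin_boundaries)  # 0-based bucket index; ties go to the upper bucket
```

With three boundaries there are four buckets (0 to 3). Values below 0 land in bucket 0 and values of 2 or more land in bucket 3.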

Compatibility Layers

layer_tfsm()
Reload a Keras model/layer that was saved via export_savedmodel().
layer_jax_model_wrapper()
Keras Layer that wraps a JAX model.
layer_flax_module_wrapper()
Keras Layer that wraps a Flax module.
layer_torch_module_wrapper()
Torch module wrapper layer.

Custom Layers

layer_lambda()
Wraps arbitrary expressions as a Layer object.
Layer()
Define a custom Layer class.

Layer Methods

get_config() from_config()
Layer/Model configuration
get_weights() set_weights()
Layer/Model weights as R arrays
count_params()
Count the total number of scalars composing the weights.
reset_state()
Reset the state for a model, layer or metric.

Callbacks

callback_model_checkpoint()
Callback to save the Keras model or model weights at some frequency.
callback_backup_and_restore()
Callback to back up and restore the training state.
callback_early_stopping()
Stop training when a monitored metric has stopped improving.
callback_terminate_on_nan()
Callback that terminates training when a NaN loss is encountered.
callback_learning_rate_scheduler()
Learning rate scheduler.
callback_reduce_lr_on_plateau()
Reduce learning rate when a metric has stopped improving.
callback_csv_logger()
Callback that streams epoch results to a CSV file.
callback_tensorboard()
Enable visualizations for TensorBoard.
callback_remote_monitor()
Callback used to stream events to a server.
callback_lambda()
Callback for creating simple, custom callbacks on-the-fly.
callback_swap_ema_weights()
Swaps model weights and EMA weights before and after evaluation.
Callback()
Define a custom Callback class
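The core of callback_early_stopping() is a patience counter. Whenever the monitored value improves on the best value seen so far by more than min_delta, the counter resets; otherwise it increments, and training stops once it reaches patience. The base-R sketch below replays that rule over a vector of validation losses. It assumes mode "min" and ignores baseline, restore_best_weights, and start_from_epoch. early_stop_epoch is a hypothetical helper.

```r
# Hypothetical base-R sketch of the early-stopping rule (mode = "min").
early_stop_epoch <- function(val_loss, patience = 0, min_delta = 0) {
  best <- Inf
  wait <- 0
  for (epoch in seq_along(val_loss)) {
    if (val_loss[epoch] < best - min_delta) {
      best <- val_loss[epoch]   # improvement: remember it and reset the counter
      wait <- 0
    } else {
      wait <- wait + 1          # no improvement this epoch
      if (wait >= patience) return(epoch)
    }
  }
  NA_integer_                   # the stopping rule never triggered
}

e <- early_stop_epoch(c(1.0, 0.8, 0.85, 0.9, 0.7), patience = 2)
```

Here the best loss is reached at epoch 2, and training stops at epoch 4 after two non-improving epochs. It therefore never sees the better loss at epoch 5.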

Operations

Functions that are safe to call with both symbolic and eager tensors.

Core Operations

op_cast()
Cast a tensor to the desired dtype.
op_cond()
Conditionally applies true_fn or false_fn.
op_convert_to_numpy()
Convert a tensor to a NumPy array.
op_convert_to_tensor()
Convert an array to a tensor.
op_custom_gradient()
Decorator to define a function with a custom gradient.
op_fori_loop()
For loop implementation.
op_is_tensor()
Check whether the given object is a tensor.
op_scatter()
Returns a tensor of the given shape in which the entries at indices are set to values.
op_scatter_update()
Update inputs via updates at scattered (sparse) indices.
op_shape()
Gets the shape of the tensor input.
op_slice()
Return a slice of an input tensor.
op_slice_update()
Update an input by slicing in a tensor of updated values.
op_stop_gradient()
Stops gradient computation.
op_unstack()
Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.
op_vectorized_map()
Parallel map of function f on the first axis of tensor(s) elements.
op_while_loop()
While loop implementation.

Math Operations

op_erf()
Computes the error function of x, element-wise.
op_erfinv()
Computes the inverse error function of x, element-wise.
op_extract_sequences()
Expands the last axis into sequences of length sequence_length.
op_fft()
Computes the Fast Fourier Transform along the last axis of the input.
op_fft2()
Computes the 2D Fast Fourier Transform along the last two axes of input.
op_in_top_k()
Checks if the targets are in the top-k predictions.
op_irfft()
Inverse real-valued Fast Fourier transform along the last axis.
op_istft()
Inverse Short-Time Fourier Transform along the last axis of the input.
op_logsumexp()
Computes the logarithm of sum of exponentials of elements in a tensor.
op_qr()
Computes the QR decomposition of a tensor.
op_rfft()
Real-valued Fast Fourier Transform along the last axis of the input.
op_rsqrt()
Computes reciprocal of square root of x element-wise.
op_segment_max()
Computes the max of segments in a tensor.
op_segment_sum()
Computes the sum of segments in a tensor.
op_solve()
Solves a linear system of equations given by a %*% x = b.
op_stft()
Short-Time Fourier Transform along the last axis of the input.
op_top_k()
Finds the top-k values and their indices in a tensor.
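op_segment_sum() adds together the entries of data that share the same segment ID. Segments that receive no entries produce 0. The base-R sketch below illustrates this with 1-based segment labels, an assumption made for R convenience. The Python backend uses 0-based IDs, and this sketch does not claim how keras3 translates them. segment_sum is a hypothetical helper.

```r
# Hypothetical base-R sketch of a segment sum with 1-based segment labels.
segment_sum <- function(data, segment_ids, num_segments = max(segment_ids)) {
  vapply(seq_len(num_segments),
         function(s) sum(data[segment_ids == s]),  # empty segments sum to 0
         numeric(1))
}

s <- segment_sum(c(1, 2, 3, 4, 5), c(1, 1, 2, 2, 4))
```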

General Tensor Operations

op_abs()
Compute the absolute value element-wise.
op_add()
Add arguments element-wise.
op_all()
Test whether all array elements along a given axis evaluate to TRUE.
op_any()
Test whether any array element along a given axis evaluates to TRUE.
op_append()
Append tensor x2 to the end of tensor x1.
op_arange()
Return evenly spaced values within a given interval.
op_arccos()
Trigonometric inverse cosine, element-wise.
op_arccosh()
Inverse hyperbolic cosine, element-wise.
op_arcsin()
Inverse sine, element-wise.
op_arcsinh()
Inverse hyperbolic sine, element-wise.
op_arctan()
Trigonometric inverse tangent, element-wise.
op_arctan2()
Element-wise arc tangent of x1/x2 choosing the quadrant correctly.
op_arctanh()
Inverse hyperbolic tangent, element-wise.
op_argmax()
Returns the indices of the maximum values along an axis.
op_argmin()
Returns the indices of the minimum values along an axis.
op_argsort()
Returns the indices that would sort a tensor.
op_array()
Create a tensor.
op_average()
Compute the weighted average along the specified axis.
op_bincount()
Count the number of occurrences of each value in a tensor of integers.
op_broadcast_to()
Broadcast a tensor to a new shape.
op_ceil()
Return the ceiling of the input, element-wise.
op_clip()
Clip (limit) the values in a tensor.
op_concatenate()
Join a sequence of tensors along an existing axis.
op_conj()
Returns the complex conjugate, element-wise.
op_copy()
Returns a copy of x.
op_correlate()
Compute the cross-correlation of two 1-dimensional tensors.
op_cos()
Cosine, element-wise.
op_cosh()
Hyperbolic cosine, element-wise.
op_count_nonzero()
Counts the number of non-zero values in x along the given axis.
op_cross()
Returns the cross product of two (arrays of) vectors.
op_ctc_decode()
Decodes the output of a CTC model.
op_cumprod()
Return the cumulative product of elements along a given axis.
op_cumsum()
Returns the cumulative sum of elements along a given axis.
op_diag()
Extract a diagonal or construct a diagonal array.
op_diagonal()
Return specified diagonals.
op_diff()
Calculate the n-th discrete difference along the given axis.
op_digitize()
Returns the indices of the bins to which each value in x belongs.
op_divide()
Divide arguments element-wise.
op_divide_no_nan()
Safe element-wise division which returns 0 where the denominator is 0.
op_dot()
Dot product of two tensors.
op_einsum()
Evaluates the Einstein summation convention on the operands.
op_empty()
Return a tensor of given shape and type filled with uninitialized data.
op_equal()
Returns (x1 == x2) element-wise.
op_exp()
Calculate the exponential of all elements in the input tensor.
op_expand_dims()
Expand the shape of a tensor.
op_expm1()
Calculate exp(x) - 1 for all elements in the tensor.
op_eye()
Return a 2-D tensor with ones on the diagonal and zeros elsewhere.
op_flip()
Reverse the order of elements in the tensor along the given axis.
op_floor()
Return the floor of the input, element-wise.
op_floor_divide()
Returns the largest integer smaller than or equal to the division of inputs.
op_full()
Return a new tensor of given shape and type, filled with fill_value.
op_full_like()
Return a full tensor with the same shape and type as the given tensor.
op_get_item()
Return x[key].
op_greater()
Return the truth value of x1 > x2 element-wise.
op_greater_equal()
Return the truth value of x1 >= x2 element-wise.
op_hstack()
Stack tensors in sequence horizontally (column wise).
op_identity()
Return the identity tensor.
op_imag()
Return the imaginary part of the complex argument.
op_isclose()
Return whether two tensors are element-wise almost equal.
op_isfinite()
Return whether a tensor is finite, element-wise.
op_isinf()
Test element-wise for positive or negative infinity.
op_isnan()
Test element-wise for NaN and return result as a boolean tensor.
op_less()
Return the truth value of x1 < x2 element-wise.
op_less_equal()
Return the truth value of x1 <= x2 element-wise.
op_linspace()
Return evenly spaced numbers over a specified interval.
op_log()
Natural logarithm, element-wise.
op_log10()
Return the base 10 logarithm of the input tensor, element-wise.
op_log1p()
Returns the natural logarithm of one plus x, element-wise.
op_log2()
Base-2 logarithm of x, element-wise.
op_logaddexp()
Logarithm of the sum of exponentiations of the inputs.
op_logical_and()
Computes the element-wise logical AND of the given input tensors.
op_logical_not()
Computes the element-wise NOT of the given input tensor.
op_logical_or()
Computes the element-wise logical OR of the given input tensors.
op_logical_xor()
Compute the truth value of x1 XOR x2, element-wise.
op_logspace()
Returns numbers spaced evenly on a log scale.
op_matmul()
Matrix product of two tensors.
op_max()
Return the maximum of a tensor or maximum along an axis.
op_maximum() op_pmax()
Element-wise maximum of x1 and x2.
op_mean()
Compute the arithmetic mean along the specified axes.
op_median()
Compute the median along the specified axis.
op_meshgrid()
Creates grids of coordinates from coordinate vectors.
op_min()
Return the minimum of a tensor or minimum along an axis.
op_minimum() op_pmin()
Element-wise minimum of x1 and x2.
op_mod()
Returns the element-wise remainder of division.
op_moveaxis()
Move axes of a tensor to new positions.
op_multiply()
Multiply arguments element-wise.
op_nan_to_num()
Replace NaN with zero and infinity with large finite numbers.
op_ndim()
Return the number of dimensions of a tensor.
op_negative()
Numerical negative, element-wise.
op_nonzero()
Return the indices of the elements that are non-zero.
op_not_equal()
Return (x1 != x2) element-wise.
op_ones()
Return a new tensor of given shape and type, filled with ones.
op_ones_like()
Return a tensor of ones with the same shape and type of x.
op_outer()
Compute the outer product of two vectors.
op_pad()
Pad a tensor.
op_power()
First tensor elements raised to powers from second tensor, element-wise.
op_prod()
Return the product of tensor elements over a given axis.
op_quantile()
Compute the q-th quantile(s) of the data along the specified axis.
op_ravel()
Return a contiguous flattened tensor.
op_real()
Return the real part of the complex argument.
op_reciprocal()
Return the reciprocal of the argument, element-wise.
op_repeat()
Repeat each element of a tensor after itself.
op_reshape()
Gives a new shape to a tensor without changing its data.
op_roll()
Roll tensor elements along a given axis.
op_round()
Evenly round to the given number of decimals.
op_select()
Return elements from choicelist, based on conditions in condlist.
op_sign()
Returns a tensor with the signs of the elements of x.
op_sin()
Trigonometric sine, element-wise.
op_sinh()
Hyperbolic sine, element-wise.
op_size()
Return the number of elements in a tensor.
op_sort()
Sorts the elements of x along a given axis in ascending order.
op_split()
Split a tensor into chunks.
op_sqrt()
Return the non-negative square root of a tensor, element-wise.
op_square()
Return the element-wise square of the input.
op_squeeze()
Remove axes of length one from x.
op_stack()
Join a sequence of tensors along a new axis.
op_std()
Compute the standard deviation along the specified axis.
op_subtract()
Subtract arguments element-wise.
op_sum()
Sum of a tensor over the given axes.
op_swapaxes()
Interchange two axes of a tensor.
op_take()
Take elements from a tensor along an axis.
op_take_along_axis()
Select values from x at the 1-D indices along the given axis.
op_tan()
Compute tangent, element-wise.
op_tanh()
Hyperbolic tangent, element-wise.
op_tensordot()
Compute the tensor dot product along specified axes.
op_tile()
Repeat x the number of times given by repeats.
op_trace()
Return the sum along diagonals of the tensor.
op_transpose()
Returns a tensor with axes transposed.
op_tri()
Return a tensor with ones at and below a diagonal and zeros elsewhere.
op_tril()
Return lower triangle of a tensor.
op_triu()
Return upper triangle of a tensor.
op_var()
Compute the variance along the specified axes.
op_vdot()
Return the dot product of two vectors.
op_vectorize()
Turn a function into a vectorized function.
op_vstack()
Stack tensors in sequence vertically (row wise).
op_where()
Return elements chosen from x1 or x2 depending on condition.
op_zeros()
Return a new tensor of given shape and type, filled with zeros.
op_zeros_like()
Return a tensor of zeros with the same shape and type as x.
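Among the operations above, op_divide_no_nan() differs from op_divide() only where the denominator is zero: it returns 0 there instead of Inf or NaN. The base-R sketch below reproduces that rule for equal-length vectors. divide_no_nan is a hypothetical helper, not a keras3 call.

```r
# Hypothetical base-R sketch of "safe" element-wise division.
divide_no_nan <- function(x, y) {
  out <- x / y
  out[y == 0] <- 0   # replace results where the denominator is zero (Inf or NaN) with 0
  out
}

d <- divide_no_nan(c(1, 2, 0), c(0, 4, 0))
```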

Neural Network Operations

op_average_pool()
Average pooling operation.
op_batch_normalization()
Normalizes x by mean and variance.
op_binary_crossentropy()
Computes binary cross-entropy loss between target and output tensor.
op_categorical_crossentropy()
Computes categorical cross-entropy loss between target and output tensor.
op_conv()
General N-D convolution.
op_conv_transpose()
General N-D convolution transpose.
op_ctc_loss()
CTC (Connectionist Temporal Classification) loss.
op_depthwise_conv()
General N-D depthwise convolution.
op_elu()
Exponential Linear Unit activation function.
op_gelu()
Gaussian Error Linear Unit (GELU) activation function.
op_hard_sigmoid()
Hard sigmoid activation function.
op_hard_silu() op_hard_swish()
Hard SiLU activation function, also known as Hard Swish.
op_leaky_relu()
Leaky version of a Rectified Linear Unit activation function.
op_log_sigmoid()
Logarithm of the sigmoid activation function.
op_log_softmax()
Log-softmax activation function.
op_max_pool()
Max pooling operation.
op_moments()
Calculates the mean and variance of x.
op_multi_hot()
Encodes integer labels as multi-hot vectors.
op_normalize()
Normalizes x over the specified axis.
op_one_hot()
Converts integer tensor x into a one-hot tensor.
op_psnr()
Peak Signal-to-Noise Ratio (PSNR) function.
op_relu()
Rectified linear unit activation function.
op_relu6()
Rectified linear unit activation function with upper bound of 6.
op_selu()
Scaled Exponential Linear Unit (SELU) activation function.
op_separable_conv()
General N-D separable convolution.
op_sigmoid()
Sigmoid activation function.
op_silu()
Sigmoid Linear Unit (SiLU) activation function, also known as Swish.
op_softmax()
Softmax activation function.
op_softplus()
Softplus activation function.
op_softsign()
Softsign activation function.
op_sparse_categorical_crossentropy()
Computes sparse categorical cross-entropy loss.
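op_softmax() and op_log_softmax() are usually computed through a log-sum-exp with the maximum subtracted first (compare op_logsumexp()), which keeps exp() from overflowing on large logits. The base-R sketch below shows that trick on a single vector, not along an arbitrary axis. The helper names are hypothetical.

```r
# Hypothetical base-R sketch of numerically stable log-softmax / softmax on a vector.
logsumexp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))   # shift by the max so exp() cannot overflow
}
log_softmax <- function(x) x - logsumexp(x)
softmax <- function(x) exp(log_softmax(x))

z <- c(1000, 1001, 1002)     # a naive exp(z) would overflow to Inf
p <- softmax(z)
```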

Linear Algebra Operations

op_cholesky()
Computes the Cholesky decomposition of a positive semi-definite matrix.
op_det()
Computes the determinant of a square tensor.
op_eig()
Computes the eigenvalues and eigenvectors of a square matrix.
op_eigh()
Computes the eigenvalues and eigenvectors of a complex Hermitian matrix.
op_inv()
Computes the inverse of a square tensor.
op_lu_factor()
Computes the lower-upper decomposition of a square matrix.
op_norm()
Matrix or vector norm.
op_slogdet()
Compute the sign and natural logarithm of the determinant of a matrix.
op_solve_triangular()
Solves a linear system of equations given by a %*% x = b.
op_svd()
Computes the singular value decomposition of a matrix.
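When cross-checking op_cholesky() against base R, note that base R's chol() returns the upper-triangular factor R with t(R) %*% R == A. As far as I know, the TensorFlow and JAX backends return the lower-triangular factor L with L %*% t(L) == A, so a transpose is assumed to be needed to compare them. The sketch below uses base R only.

```r
# Base-R sketch: relate chol()'s upper factor to the lower-triangular form.
A <- matrix(c(4, 2,
              2, 3), nrow = 2, byrow = TRUE)   # symmetric positive-definite
R <- chol(A)   # base R returns the UPPER factor: t(R) %*% R == A
L <- t(R)      # lower factor, the form with A == L %*% t(L)
```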

Image Operations

op_image_affine_transform()
Applies the given transform(s) to the image(s).
op_image_crop()
Crop images to a specified height and width.
op_image_extract_patches()
Extracts patches from the image(s).
op_image_map_coordinates()
Map the input array to new coordinates by interpolation.
op_image_pad()
Pad images with zeros to the specified height and width.
op_image_resize()
Resize images to size using the specified interpolation method.
op_image_rgb_to_grayscale()
Convert RGB images to grayscale.

Losses

loss_binary_crossentropy()
Computes the cross-entropy loss between true labels and predicted labels.
loss_binary_focal_crossentropy()
Computes focal cross-entropy loss between true labels and predictions.
loss_categorical_crossentropy()
Computes the crossentropy loss between the labels and predictions.
loss_categorical_focal_crossentropy()
Computes the alpha balanced focal crossentropy loss.
loss_categorical_hinge()
Computes the categorical hinge loss between y_true & y_pred.
loss_cosine_similarity()
Computes the cosine similarity between y_true & y_pred.
loss_ctc()
CTC (Connectionist Temporal Classification) loss.
loss_dice()
Computes the Dice loss value between y_true and y_pred.
loss_hinge()
Computes the hinge loss between y_true & y_pred.
loss_huber()
Computes the Huber loss between y_true & y_pred.
loss_kl_divergence()
Computes Kullback-Leibler divergence loss between y_true & y_pred.
loss_log_cosh()
Computes the logarithm of the hyperbolic cosine of the prediction error.
loss_mean_absolute_error()
Computes the mean of absolute difference between labels and predictions.
loss_mean_absolute_percentage_error()
Computes the mean absolute percentage error between y_true and y_pred.
loss_mean_squared_error()
Computes the mean of squares of errors between labels and predictions.
loss_mean_squared_logarithmic_error()
Computes the mean squared logarithmic error between y_true and y_pred.
loss_poisson()
Computes the Poisson loss between y_true & y_pred.
loss_sparse_categorical_crossentropy()
Computes the crossentropy loss between the labels and predictions.
loss_squared_hinge()
Computes the squared hinge loss between y_true & y_pred.
loss_tversky()
Computes the Tversky loss value between y_true and y_pred.
Loss()
Subclass the base Loss class
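loss_huber() is quadratic for errors up to delta and linear beyond it, which makes it less sensitive to outliers than mean squared error. The base-R sketch below computes it for one sample's vector of predictions, assuming the default delta = 1 and averaging over the elements. huber is a hypothetical helper, not the keras3 function.

```r
# Hypothetical base-R sketch of the Huber loss averaged over elements.
huber <- function(y_true, y_pred, delta = 1) {
  err <- abs(y_pred - y_true)
  quadratic <- pmin(err, delta)   # part of the error inside the quadratic zone
  linear <- err - quadratic       # remainder beyond delta, penalised linearly
  mean(0.5 * quadratic^2 + delta * linear)
}

h <- huber(c(0, 0), c(0.5, 3))   # (0.125 + 2.5) / 2
```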

Metrics

metric_auc()
Approximates the AUC (Area under the curve) of the ROC or PR curves.
metric_binary_accuracy()
Calculates how often predictions match binary labels.
metric_binary_crossentropy()
Computes the crossentropy metric between the labels and predictions.
metric_binary_focal_crossentropy()
Computes the binary focal crossentropy loss.
metric_binary_iou()
Computes the Intersection-Over-Union metric for class 0 and/or 1.
metric_categorical_accuracy()
Calculates how often predictions match one-hot labels.
metric_categorical_crossentropy()
Computes the crossentropy metric between the labels and predictions.
metric_categorical_focal_crossentropy()
Computes the categorical focal crossentropy loss.
metric_categorical_hinge()
Computes the categorical hinge metric between y_true and y_pred.
metric_cosine_similarity()
Computes the cosine similarity between the labels and predictions.
metric_f1_score()
Computes F-1 Score.
metric_false_negatives()
Calculates the number of false negatives.
metric_false_positives()
Calculates the number of false positives.
metric_fbeta_score()
Computes F-Beta score.
metric_hinge()
Computes the hinge metric between y_true and y_pred.
metric_huber()
Computes Huber loss value.
metric_iou()
Computes the Intersection-Over-Union metric for specific target classes.
metric_kl_divergence()
Computes Kullback-Leibler divergence metric between y_true and y_pred.
metric_log_cosh()
Logarithm of the hyperbolic cosine of the prediction error.
metric_log_cosh_error()
Computes the logarithm of the hyperbolic cosine of the prediction error.
metric_mean()
Compute the (weighted) mean of the given values.
metric_mean_absolute_error()
Computes the mean absolute error between the labels and predictions.
metric_mean_absolute_percentage_error()
Computes mean absolute percentage error between y_true and y_pred.
metric_mean_iou()
Computes the mean Intersection-Over-Union metric.
metric_mean_squared_error()
Computes the mean squared error between y_true and y_pred.
metric_mean_squared_logarithmic_error()
Computes mean squared logarithmic error between y_true and y_pred.
metric_mean_wrapper()
Wrap a stateless metric function with the Mean metric.
metric_one_hot_iou()
Computes the Intersection-Over-Union metric for one-hot encoded labels.
metric_one_hot_mean_iou()
Computes mean Intersection-Over-Union metric for one-hot encoded labels.
metric_poisson()
Computes the Poisson metric between y_true and y_pred.
metric_precision()
Computes the precision of the predictions with respect to the labels.
metric_precision_at_recall()
Computes best precision where recall is >= specified value.
metric_r2_score()
Computes R2 score.
metric_recall()
Computes the recall of the predictions with respect to the labels.
metric_recall_at_precision()
Computes best recall where precision is >= specified value.
metric_root_mean_squared_error()
Computes root mean squared error metric between y_true and y_pred.
metric_sensitivity_at_specificity()
Computes best sensitivity where specificity is >= specified value.
metric_sparse_categorical_accuracy()
Calculates how often predictions match integer labels.
metric_sparse_categorical_crossentropy()
Computes the crossentropy metric between the labels and predictions.
metric_sparse_top_k_categorical_accuracy()
Computes how often integer targets are in the top K predictions.
metric_specificity_at_sensitivity()
Computes best specificity where sensitivity is >= specified value.
metric_squared_hinge()
Computes the squared hinge metric between y_true and y_pred.
metric_sum()
Compute the (weighted) sum of the given values.
metric_top_k_categorical_accuracy()
Computes how often targets are in the top K predictions.
metric_true_negatives()
Calculates the number of true negatives.
metric_true_positives()
Calculates the number of true positives.
custom_metric()
Custom metric function
reset_state()
Reset the state for a model, layer or metric.
Metric()
Subclass the base Metric class
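The arithmetic behind metric_precision(), metric_recall(), and a binary F1 score can be written directly from true-positive, false-positive, and false-negative counts. The base-R sketch below thresholds probabilities at 0.5 (the assumed default, with predictions above the threshold counted as positive). It ignores sample weights, multi-threshold settings, and the averaging options of metric_f1_score(). binary_counts and scores are hypothetical helpers.

```r
# Hypothetical base-R sketch of precision, recall and F1 from thresholded probabilities.
binary_counts <- function(y_true, y_prob, threshold = 0.5) {
  y_hat <- as.integer(y_prob > threshold)   # predictions above the threshold count as positive
  c(tp = sum(y_hat == 1 & y_true == 1),
    fp = sum(y_hat == 1 & y_true == 0),
    fn = sum(y_hat == 0 & y_true == 1))
}

scores <- function(y_true, y_prob, threshold = 0.5) {
  n <- binary_counts(y_true, y_prob, threshold)
  precision <- n[["tp"]] / (n[["tp"]] + n[["fp"]])
  recall    <- n[["tp"]] / (n[["tp"]] + n[["fn"]])
  c(precision = precision, recall = recall,
    f1 = 2 * precision * recall / (precision + recall))
}

s <- scores(c(1, 1, 0, 0), c(0.9, 0.2, 0.8, 0.1))   # one TP, one FP, one FN
```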

Data Loading

Keras data loading utilities help you quickly go from raw data to a TF Dataset object that can be used to efficiently train a model. These loading utilities can be combined with preprocessing layers to further transform your input dataset before training.

image_dataset_from_directory()
Generates a tf.data.Dataset from image files in a directory.
text_dataset_from_directory()
Generates a tf.data.Dataset from text files in a directory.
audio_dataset_from_directory()
Generates a tf.data.Dataset from audio files in a directory.
timeseries_dataset_from_array()
Creates a dataset of sliding windows over a timeseries provided as array.
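The windowing performed by timeseries_dataset_from_array() can be pictured as taking every sequence_length-long slice of the series, starting every sequence_stride steps. The base-R sketch below builds those windows as matrix rows. It does not handle targets, sampling_rate, shuffling, or batching, and sliding_windows is a hypothetical helper rather than part of keras3.

```r
# Hypothetical base-R sketch of sliding windows over a 1D series.
sliding_windows <- function(x, sequence_length, sequence_stride = 1) {
  x <- as.numeric(x)
  starts <- seq(1, length(x) - sequence_length + 1, by = sequence_stride)
  t(vapply(starts, function(s) x[s:(s + sequence_length - 1)],
           numeric(sequence_length)))   # vapply gives one window per column; t() makes rows
}

w <- sliding_windows(1:5, sequence_length = 3)
```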

Preprocessing

Numerical Features Preprocessing Layers

layer_normalization()
A preprocessing layer that normalizes continuous features.
layer_discretization()
A preprocessing layer which buckets continuous features by ranges.
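As a rough base-R illustration of what these two layers compute once their statistics or boundaries are fixed: normalization subtracts a mean and divides by a standard deviation (assumed here to be the population variance, ignoring the small epsilon guard), and discretization maps each value to the index of its bucket. Variable names are illustrative, not keras3 arguments, except `bin_boundaries`, which mirrors the layer's argument of that name.

```r
x <- c(-1.5, 1.0, 3.4, 0.5)

# Normalization sketch: statistics computed directly instead of via adapt()
mu <- mean(x)
sigma <- sqrt(mean((x - mu)^2))  # population standard deviation
x_norm <- (x - mu) / sigma

# Discretization sketch: values below the first boundary go to bucket 0,
# values in [0, 1) to bucket 1, [1, 2) to bucket 2, and >= 2 to bucket 3
bin_boundaries <- c(0, 1, 2)
x_bucket <- findInterval(x, bin_boundaries)
print(x_bucket)  # 0 2 3 1
```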

Categorical Features Preprocessing Layers

layer_category_encoding()
A preprocessing layer which encodes integer features.
layer_hashing()
A preprocessing layer which hashes and bins categorical features.
layer_hashed_crossing()
A preprocessing layer which crosses features using the "hashing trick".
layer_string_lookup()
A preprocessing layer that maps strings to (possibly encoded) indices.
layer_integer_lookup()
A preprocessing layer that maps integers to (possibly encoded) indices.

Text Preprocessing Layers

layer_text_vectorization() get_vocabulary() set_vocabulary()
A preprocessing layer which maps text features to integer sequences.

Sequence Preprocessing

timeseries_dataset_from_array()
Creates a dataset of sliding windows over a timeseries provided as array.
pad_sequences()
Pads sequences to the same length.
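The default behaviour of pad_sequences() (pad and truncate at the start, fill value 0) can be sketched in base R. The helper `pad_pre()` is hypothetical and only covers those defaults; the real function also supports post-padding, post-truncation, and other fill values.

```r
# Hypothetical helper mimicking "pre" padding and "pre" truncation with value 0.
pad_pre <- function(seqs, maxlen = max(lengths(seqs)), value = 0) {
  out <- matrix(value, nrow = length(seqs), ncol = maxlen)
  for (i in seq_along(seqs)) {
    s <- seqs[[i]]
    if (length(s) > maxlen) s <- tail(s, maxlen)  # "pre" truncation keeps the end
    if (length(s) > 0) out[i, (maxlen - length(s) + 1):maxlen] <- s  # right-align
  }
  out
}

seqs <- list(c(1, 2), c(3, 4, 5, 6), 7)
m <- pad_pre(seqs, maxlen = 3)
print(m)  # rows: 0 1 2 / 4 5 6 / 0 0 7
```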

Image Preprocessing Layers

layer_resizing()
A preprocessing layer which resizes images.
layer_rescaling()
A preprocessing layer which rescales input values to a new range.
layer_center_crop()
A preprocessing layer which crops images.

Image Preprocessing

image_array_save()
Saves an image stored as an array to a path or file object.
image_dataset_from_directory()
Generates a tf.data.Dataset from image files in a directory.
image_from_array()
Converts a 3D array to a PIL Image instance.
image_load()
Loads an image into PIL format.
image_smart_resize()
Resize images to a target size without aspect ratio distortion.
image_to_array()
Converts a PIL Image instance to an array.
op_image_affine_transform()
Applies the given transform(s) to the image(s).
op_image_crop()
Crop images to a specified height and width.
op_image_extract_patches()
Extracts patches from the image(s).
op_image_map_coordinates()
Map the input array to new coordinates by interpolation.
op_image_pad()
Pad images with zeros to the specified height and width.
op_image_resize()
Resize images to size using the specified interpolation method.
op_image_rgb_to_grayscale()
Convert RGB images to grayscale.

Image Augmentation Layers

layer_random_crop()
A preprocessing layer which randomly crops images during training.
layer_random_flip()
A preprocessing layer which randomly flips images during training.
layer_random_translation()
A preprocessing layer which randomly translates images during training.
layer_random_rotation()
A preprocessing layer which randomly rotates images during training.
layer_random_zoom()
A preprocessing layer which randomly zooms images during training.
layer_random_contrast()
A preprocessing layer which randomly adjusts contrast during training.
layer_random_brightness()
A preprocessing layer which randomly adjusts brightness during training.

Application Preprocessing

application_preprocess_inputs() application_decode_predictions()
Preprocessing and postprocessing utilities

Optimizers

optimizer_adadelta()
Optimizer that implements the Adadelta algorithm.
optimizer_adafactor()
Optimizer that implements the Adafactor algorithm.
optimizer_adagrad()
Optimizer that implements the Adagrad algorithm.
optimizer_adam()
Optimizer that implements the Adam algorithm.
optimizer_adam_w()
Optimizer that implements the AdamW algorithm.
optimizer_adamax()
Optimizer that implements the Adamax algorithm.
optimizer_ftrl()
Optimizer that implements the FTRL algorithm.
optimizer_lion()
Optimizer that implements the Lion algorithm.
optimizer_loss_scale()
An optimizer that dynamically scales the loss to prevent underflow.
optimizer_nadam()
Optimizer that implements the Nadam algorithm.
optimizer_rmsprop()
Optimizer that implements the RMSprop algorithm.
optimizer_sgd()
Gradient descent (with momentum) optimizer.
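To show what "with momentum" means, here is a plain-R sketch of the non-Nesterov momentum update (velocity = momentum * velocity - learning_rate * gradient; w = w + velocity) minimizing f(w) = (w - 3)^2. The helper `sgd_momentum()` is hypothetical and is not how optimizer_sgd() is implemented internally.

```r
# Hypothetical helper: heavy-ball momentum on a single scalar parameter.
sgd_momentum <- function(w, grad_fn, learning_rate = 0.1, momentum = 0.9,
                         steps = 300) {
  velocity <- 0
  for (i in seq_len(steps)) {
    g <- grad_fn(w)
    velocity <- momentum * velocity - learning_rate * g  # accumulate past steps
    w <- w + velocity
  }
  w
}

# Gradient of (w - 3)^2 is 2 * (w - 3); the minimum is at w = 3
w_final <- sgd_momentum(w = 0, grad_fn = function(w) 2 * (w - 3))
print(w_final)  # close to 3
```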

Learning Rate Schedules

learning_rate_schedule_cosine_decay()
A LearningRateSchedule that uses a cosine decay schedule with optional warmup.
learning_rate_schedule_cosine_decay_restarts()
A LearningRateSchedule that uses a cosine decay schedule with restarts.
learning_rate_schedule_exponential_decay()
A LearningRateSchedule that uses an exponential decay schedule.
learning_rate_schedule_inverse_time_decay()
A LearningRateSchedule that uses an inverse time decay schedule.
learning_rate_schedule_piecewise_constant_decay()
A LearningRateSchedule that uses a piecewise constant decay schedule.
learning_rate_schedule_polynomial_decay()
A LearningRateSchedule that uses a polynomial decay schedule.
LearningRateSchedule()
Define a custom LearningRateSchedule class
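Two of the schedules above reduce to short closed-form expressions. The sketch below reproduces them in plain R (exponential decay with optional staircase, and cosine decay with warmup disabled); the helper names are hypothetical, and only the parameter names mirror the schedule constructors.

```r
# Hypothetical helpers (not keras3 functions). `step` is the iteration count.
exponential_decay <- function(step, initial_learning_rate, decay_steps,
                              decay_rate, staircase = FALSE) {
  p <- step / decay_steps
  if (staircase) p <- floor(p)  # decay in discrete intervals
  initial_learning_rate * decay_rate^p
}

cosine_decay <- function(step, initial_learning_rate, decay_steps, alpha = 0) {
  step <- pmin(step, decay_steps)  # hold the final value after decay_steps
  cosine <- 0.5 * (1 + cos(pi * step / decay_steps))
  # alpha is the minimum rate, as a fraction of initial_learning_rate
  initial_learning_rate * ((1 - alpha) * cosine + alpha)
}

exponential_decay(1000, 0.1, decay_steps = 1000, decay_rate = 0.96)  # 0.096
cosine_decay(500, 0.1, decay_steps = 1000)                           # 0.05
```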

Initializers

initializer_constant()
Initializer that generates tensors with constant values.
initializer_glorot_normal()
The Glorot normal initializer, also called Xavier normal initializer.
initializer_glorot_uniform()
The Glorot uniform initializer, also called Xavier uniform initializer.
initializer_he_normal()
He normal initializer.
initializer_he_uniform()
He uniform variance scaling initializer.
initializer_identity()
Initializer that generates the identity matrix.
initializer_lecun_normal()
LeCun normal initializer.
initializer_lecun_uniform()
LeCun uniform initializer.
initializer_ones()
Initializer that generates tensors initialized to 1.
initializer_orthogonal()
Initializer that generates an orthogonal matrix.
initializer_random_normal()
Random normal initializer.
initializer_random_uniform()
Random uniform initializer.
initializer_truncated_normal()
Initializer that generates a truncated normal distribution.
initializer_variance_scaling()
Initializer that adapts its scale to the shape of its input tensors.
initializer_zeros()
Initializer that generates tensors initialized to 0.
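As an example of how the variance-scaling family works, the Glorot uniform rule draws a dense kernel of shape fan_in x fan_out from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). The plain-R helper below is hypothetical and uses base R's RNG, not the keras3 initializer.

```r
# Hypothetical helper: Glorot (Xavier) uniform draw for a dense kernel.
glorot_uniform <- function(fan_in, fan_out) {
  limit <- sqrt(6 / (fan_in + fan_out))
  matrix(runif(fan_in * fan_out, min = -limit, max = limit),
         nrow = fan_in, ncol = fan_out)
}

set.seed(42)
w <- glorot_uniform(fan_in = 64, fan_out = 32)
range(w)  # within +/- sqrt(6 / 96) = +/- 0.25
```

He uniform follows the same pattern with limit = sqrt(6 / fan_in).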

Constraints

Constraint()
Define a custom Constraint class
constraint_maxnorm()
MaxNorm weight constraint.
constraint_minmaxnorm()
MinMaxNorm weight constraint.
constraint_nonneg()
Constrains the weights to be non-negative.
constraint_unitnorm()
Constrains the weights incident to each hidden unit to have unit norm.
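The max-norm idea can be sketched in plain R: each column of a dense kernel (the incoming weights of one unit) is rescaled so its L2 norm is at most `max_value`, and columns already under the limit are left essentially unchanged. The helper `max_norm()` and its small `epsilon` are assumptions for illustration.

```r
# Hypothetical helper: clip each column's L2 norm to at most max_value.
max_norm <- function(w, max_value = 2, epsilon = 1e-7) {
  norms <- sqrt(colSums(w^2))
  desired <- pmin(norms, max_value)
  sweep(w, 2, desired / (epsilon + norms), `*`)  # rescale column by column
}

w <- cbind(c(3, 4), c(0.6, 0.8))  # column norms 5 and 1
n <- sqrt(colSums(max_norm(w)^2))
round(n, 4)  # 2 1
```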

Regularizers

regularizer_l1()
A regularizer that applies an L1 regularization penalty.
regularizer_l1_l2()
A regularizer that applies both L1 and L2 regularization penalties.
regularizer_l2()
A regularizer that applies an L2 regularization penalty.
regularizer_orthogonal()
Regularizer that encourages input vectors to be orthogonal to each other.
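The L1 and L2 penalties added to the loss are simple sums over the weights: l1 * sum(|w|) and l2 * sum(w^2). The plain-R helper below is hypothetical and just makes the arithmetic concrete.

```r
# Hypothetical helper: combined L1/L2 penalty for a weight vector or matrix.
l1_l2_penalty <- function(w, l1 = 0, l2 = 0) {
  l1 * sum(abs(w)) + l2 * sum(w^2)
}

w <- c(0.5, -1, 2)
l1_l2_penalty(w, l1 = 0.01)             # 0.035  (0.01 * 3.5)
l1_l2_penalty(w, l2 = 0.01)             # 0.0525 (0.01 * 5.25)
l1_l2_penalty(w, l1 = 0.01, l2 = 0.01)  # 0.0875
```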

Activations

activation_elu()
Exponential Linear Unit.
activation_exponential()
Exponential activation function.
activation_gelu()
Gaussian error linear unit (GELU) activation function.
activation_hard_sigmoid()
Hard sigmoid activation function.
activation_hard_silu() activation_hard_swish()
Hard SiLU activation function, also known as Hard Swish.
activation_leaky_relu()
Leaky relu activation function.
activation_linear()
Linear activation function (pass-through).
activation_log_softmax()
Log-Softmax activation function.
activation_mish()
Mish activation function.
activation_relu()
Applies the rectified linear unit activation function.
activation_relu6()
Relu6 activation function.
activation_selu()
Scaled Exponential Linear Unit (SELU).
activation_sigmoid()
Sigmoid activation function.
activation_silu()
Swish (or Silu) activation function.
activation_softmax()
Softmax converts a vector of values to a probability distribution.
activation_softplus()
Softplus activation function.
activation_softsign()
Softsign activation function.
activation_tanh()
Hyperbolic tangent activation function.
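A few of the activation formulas above, written as plain R functions for reference. These are illustrative re-implementations, not the keras3 activation functions, which operate on backend tensors.

```r
# Illustrative base-R versions of selected activations.
relu    <- function(x) pmax(x, 0)
sigmoid <- function(x) 1 / (1 + exp(-x))
silu    <- function(x) x * sigmoid(x)  # also called swish
softmax <- function(x) {
  z <- exp(x - max(x))  # subtract the max for numerical stability
  z / sum(z)
}

x <- c(-2, 0, 3)
relu(x)          # 0 0 3
sigmoid(0)       # 0.5
sum(softmax(x))  # 1
```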

Random Tensor Generators

random_uniform()
Draw samples from a uniform distribution.
random_normal()
Draw random samples from a normal (Gaussian) distribution.
random_truncated_normal()
Draw samples from a truncated normal distribution.
random_gamma()
Draw random samples from the Gamma distribution.
random_categorical()
Draws samples from a categorical distribution.
random_integer()
Draw random integers from a uniform distribution.
random_dropout()
Randomly set some values in a tensor to 0.
random_shuffle()
Shuffle the elements of a tensor uniformly at random along an axis.
random_beta()
Draw samples from a Beta distribution.
random_binomial()
Draw samples from a Binomial distribution.
random_seed_generator()
Generates variable seeds upon each call to an RNG-using function.

Builtin small datasets

dataset_boston_housing()
Boston housing price regression dataset
dataset_cifar10()
CIFAR10 small image classification
dataset_cifar100()
CIFAR100 small image classification
dataset_fashion_mnist()
Fashion-MNIST database of fashion articles
dataset_imdb() dataset_imdb_word_index()
IMDB Movie reviews sentiment classification
dataset_mnist()
MNIST database of handwritten digits
dataset_reuters() dataset_reuters_word_index()
Reuters newswire topics classification

Configuration

config_backend()
Return the name of the current backend.
config_disable_interactive_logging()
Turn off interactive logging.
config_disable_traceback_filtering()
Turn off traceback filtering.
config_dtype_policy()
Returns the current default dtype policy object.
config_enable_interactive_logging()
Turn on interactive logging.
config_enable_traceback_filtering()
Turn on traceback filtering.
config_enable_unsafe_deserialization()
Disables safe mode globally, allowing deserialization of lambdas.
config_epsilon()
Return the value of the fuzz factor used in numeric expressions.
config_floatx()
Return the default float type, as a string.
config_image_data_format()
Return the default image data format convention.
config_is_interactive_logging_enabled()
Check if interactive logging is enabled.
config_is_traceback_filtering_enabled()
Check if traceback filtering is enabled.
config_set_backend()
Reload the backend (and the Keras package).
config_set_dtype_policy()
Sets the default dtype policy globally.
config_set_epsilon()
Set the value of the fuzz factor used in numeric expressions.
config_set_floatx()
Set the default float dtype.
config_set_image_data_format()
Set the value of the image data format convention.

Utils

install_keras()
Install Keras
use_backend()
Configure a Keras backend
shape() format(<keras_shape>) print(<keras_shape>) `[`(<keras_shape>) as.integer(<keras_shape>) as.list(<keras_shape>)
Tensor shape utility
set_random_seed()
Sets all random seeds (Python, NumPy, and backend framework, e.g. TF).
clear_session()
Resets all state generated by Keras.
get_source_inputs()
Returns the list of input tensors necessary to compute the given tensor.
keras
Main Keras module

Numerical Utils

normalize()
Normalizes an array.
to_categorical()
Converts a class vector (integers) to a binary class matrix.
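The two numerical utilities can be approximated in base R for intuition: to_categorical() one-hot encodes 0-based integer labels, and normalize() with its default order of 2 along the last axis rescales each row of a matrix to unit L2 norm. The helper `one_hot()` is hypothetical.

```r
# Hypothetical helper: one-hot encode 0-based integer class labels.
one_hot <- function(y, num_classes = max(y) + 1) {
  m <- matrix(0, nrow = length(y), ncol = num_classes)
  m[cbind(seq_along(y), y + 1)] <- 1  # shift 0-based labels to 1-based columns
  m
}

oh <- one_hot(c(0, 2, 1))
print(oh)  # rows: 1 0 0 / 0 0 1 / 0 1 0

# Row-wise L2 normalization of a matrix
x <- rbind(c(3, 4), c(1, 0))
xn <- x / sqrt(rowSums(x^2))
print(xn)  # rows: 0.6 0.8 / 1 0
```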

Data Utils

zip_lists()
Zip lists
get_file()
Downloads a file from a URL if it is not already in the cache.
split_dataset()
Splits a dataset into a left half and a right half (e.g. train / test).

Serialization Utils

register_keras_serializable()
Registers a custom object with the Keras serialization framework.
get_custom_objects() set_custom_objects()
Get/set the currently registered custom objects.
get_registered_name()
Returns the name registered to an object within the Keras framework.
get_registered_object()
Returns the class associated with name if it is registered with Keras.
serialize_keras_object()
Retrieve the full config by serializing the Keras object.
deserialize_keras_object()
Retrieve the object by deserializing the config dict.
with_custom_object_scope()
Provide a scope with mappings of names to custom objects
config_enable_unsafe_deserialization()
Disables safe mode globally, allowing deserialization of lambdas.

Base Keras Classes

Define custom object by subclassing base Keras classes.

Layer()
Define a custom Layer class.
Loss()
Subclass the base Loss class
Metric()
Subclass the base Metric class
Callback()
Define a custom Callback class
Constraint()
Define a custom Constraint class
Model()
Subclass the base Keras Model Class
LearningRateSchedule()
Define a custom LearningRateSchedule class
active_property()
Create an active property class method

Applications

Application utilities

application_preprocess_inputs() application_decode_predictions()
Preprocessing and postprocessing utilities

ConvNeXt Applications

application_convnext_base()
Instantiates the ConvNeXtBase architecture.
application_convnext_large()
Instantiates the ConvNeXtLarge architecture.
application_convnext_small()
Instantiates the ConvNeXtSmall architecture.
application_convnext_tiny()
Instantiates the ConvNeXtTiny architecture.
application_convnext_xlarge()
Instantiates the ConvNeXtXLarge architecture.

DenseNet Applications

application_densenet121()
Instantiates the DenseNet121 architecture.
application_densenet169()
Instantiates the DenseNet169 architecture.
application_densenet201()
Instantiates the DenseNet201 architecture.

EfficientNet Applications

application_efficientnet_b0()
Instantiates the EfficientNetB0 architecture.
application_efficientnet_b1()
Instantiates the EfficientNetB1 architecture.
application_efficientnet_b2()
Instantiates the EfficientNetB2 architecture.
application_efficientnet_b3()
Instantiates the EfficientNetB3 architecture.
application_efficientnet_b4()
Instantiates the EfficientNetB4 architecture.
application_efficientnet_b5()
Instantiates the EfficientNetB5 architecture.
application_efficientnet_b6()
Instantiates the EfficientNetB6 architecture.
application_efficientnet_b7()
Instantiates the EfficientNetB7 architecture.
application_efficientnet_v2b0()
Instantiates the EfficientNetV2B0 architecture.
application_efficientnet_v2b1()
Instantiates the EfficientNetV2B1 architecture.
application_efficientnet_v2b2()
Instantiates the EfficientNetV2B2 architecture.
application_efficientnet_v2b3()
Instantiates the EfficientNetV2B3 architecture.
application_efficientnet_v2l()
Instantiates the EfficientNetV2L architecture.
application_efficientnet_v2m()
Instantiates the EfficientNetV2M architecture.
application_efficientnet_v2s()
Instantiates the EfficientNetV2S architecture.

Inception Applications

application_inception_resnet_v2()
Instantiates the Inception-ResNet v2 architecture.
application_inception_v3()
Instantiates the Inception v3 architecture.

MobileNet Applications

application_mobilenet()
Instantiates the MobileNet architecture.
application_mobilenet_v2()
Instantiates the MobileNetV2 architecture.
application_mobilenet_v3_large()
Instantiates the MobileNetV3Large architecture.
application_mobilenet_v3_small()
Instantiates the MobileNetV3Small architecture.

NASNet Applications

application_nasnetlarge()
Instantiates a NASNet model in ImageNet mode.
application_nasnetmobile()
Instantiates a Mobile NASNet model in ImageNet mode.

ResNet Applications

application_resnet101()
Instantiates the ResNet101 architecture.
application_resnet101_v2()
Instantiates the ResNet101V2 architecture.
application_resnet152()
Instantiates the ResNet152 architecture.
application_resnet152_v2()
Instantiates the ResNet152V2 architecture.
application_resnet50()
Instantiates the ResNet50 architecture.
application_resnet50_v2()
Instantiates the ResNet50V2 architecture.

VGG Applications

application_vgg16()
Instantiates the VGG16 model.
application_vgg19()
Instantiates the VGG19 model.

Xception Applications

application_xception()
Instantiates the Xception architecture.