from typing import Optional

import numpy as np

from mygrad.operation_base import Operation
from mygrad.tensor_base import Tensor
from mygrad.typing import ArrayLike

__all__ = ["batchnorm"]


# TODO: Remove affine parameters from Operation
class BatchNorm(Operation):
"""
Attributes
----------
mean : numpy.ndarray
var : numpy.ndarray
Notes
-----
`mean` and `var` are bound as instance-attributes upon
calling the batch-norm instance.
"""
    def __call__(self, x, gamma, beta, *, eps):
        """
        y(x) = (x - E[x]) / sqrt(Var[x] + eps)
        batchnorm(x) = gamma * y(x) + beta

        Parameters
        ----------
        x : mygrad.Tensor
        gamma : Optional[mygrad.Tensor]
        beta : Optional[mygrad.Tensor]
        eps : Real
            A small non-negative number.

        Returns
        -------
        numpy.ndarray
        """
        # normalize over every axis other than the channel axis (axis-1)
        normed_dims = tuple(i for i in range(x.ndim) if i != 1)
        # broadcast-compatible shape: channel dim retained, all other dims collapsed to 1
        keepdims_shape = tuple(1 if n != 1 else d for n, d in enumerate(x.shape))

        self.variables = tuple(i for i in (x, gamma, beta))
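        # `gamma`/`beta` arrive as empty arrays when no affine transformation was
        # requested (see `batchnorm` below); translate that sentinel back to `None`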
        if gamma.size == 0:
            gamma = None

        if beta.size == 0:
            beta = None

        self.gamma = gamma
        self.beta = beta

        x = x.data
        self.x_norm = None  # required for backprop through gamma
        self.mean = x.mean(axis=normed_dims)
        self.var = x.var(axis=normed_dims)

        y = x - self.mean.reshape(keepdims_shape)
        self._std = np.sqrt(self.var + eps).reshape(keepdims_shape)  # sqrt(var + eps)
        y /= self._std
        self.x_norm = y

        # optional affine transformation
        if gamma is not None:
            gamma = gamma.data
            # must copy `y` to prevent mutation of `self.x_norm`
            y = y * gamma.reshape(keepdims_shape)

        if beta is not None:
            beta = beta.data
            y = y + beta.reshape(keepdims_shape)
        return y

    def backward_var(self, grad, index, **kwargs):
        x = self.variables[0].data

        if index == 0:  # backprop through x
            normed_dims = tuple(i for i in range(x.ndim) if i != 1)
            keepdims_shape = tuple(1 if n != 1 else d for n, d in enumerate(x.shape))
            N = x.size / x.shape[1]

            # all sums carried over non-channel dims
            # (1/sqrt(var + eps)) * [dL - dL.mean() - (1/N)*x_norm*(x_norm @ dL)]
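            # equivalently: grad_ = (grad - mean(grad) - x_norm * mean(x_norm * grad)) / std,
            # with means taken over the non-channel dims; the einsum below computes the
            # per-channel sum of x_norm * grad, and dividing by N turns that sum into a mean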
            grad_ = grad - np.mean(grad, axis=normed_dims, keepdims=True)
            rterm = self.x_norm * np.reshape(
                np.einsum(grad, range(x.ndim), self.x_norm, range(x.ndim), [1]),
                keepdims_shape,
            )
            rterm /= N
            grad_ -= rterm
            grad_ /= self._std

            if self.gamma is not None:  # backprop through optional affine transformation
                gamma = self.gamma.data
                grad_ *= gamma.reshape(keepdims_shape)
            return grad_

        elif index == 1 and self.gamma is not None:  # backprop through gamma
            return np.einsum(grad, range(x.ndim), self.x_norm, range(x.ndim), [1])

        elif (index == 1 and self.gamma is None) or index == 2:
            normed_dims = tuple(i for i in range(x.ndim) if i != 1)
            return grad.sum(axis=normed_dims)

        else:  # pragma: no cover
            raise IndexError
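

# A quick way to sanity-check the forward pass against plain NumPy (an informal
# sketch, not part of the library's API or test suite):
#
#     >>> x = np.random.rand(4, 3, 2)                # shape-(N, C, ...) data
#     >>> mean = x.mean(axis=(0, 2), keepdims=True)  # per-channel statistics
#     >>> var = x.var(axis=(0, 2), keepdims=True)
#     >>> out = batchnorm(x, eps=1e-5)
#     >>> np.allclose(out.data, (x - mean) / np.sqrt(var + 1e-5))
#     True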


def batchnorm(
    x: ArrayLike,
    *,
    gamma: Optional[ArrayLike] = None,
    beta: Optional[ArrayLike] = None,
    eps: float,
    constant: Optional[bool] = None,
) -> Tensor:
"""
Performs batch normalization on ``x``::
y(x) = (x - E[x]) / sqrt(Var[x] + eps)
batchnorm(x) = gamma * y(x) + beta
Where :math:`E[x]` and :math:`Var[x]` represent the mean and variance, respectively,
over axis-1 of ``x``. The subsequent affine transformation on ``y``
is optional.
Parameters
----------
x : array_like, shape=(N, C, ...)
The batch to be normalized within each entry of C
gamma : Optional[array_like], shape=(C,)
Optional per-channel scaling factors to be applied after the
normalization step.
beta : Optional[array_like], shape=(C,)
Optional per-channel scaling bias factors to be applied after the
normalization step.
eps : Real
A small non-negative number.
constant : bool, optional (default=False)
If True, the resulting Tensor is a constant.
Returns
-------
mygrad.Tensor
The batch-normalized data.
Examples
--------
>>> import mygrad as mg
>>> from mygrad.nnet import batchnorm
>>> x = mg.Tensor([1., 4., 1.]).reshape(3, 1)
>>> batchnorm(x, eps=0)
Tensor([[-0.70710678],
[ 1.41421356],
[-0.70710678]])
"""
    # pass gamma and beta as empty arrays if they are not supplied
    if gamma is None:
        gamma = np.array([])

    if beta is None:
        beta = np.array([])

    return Tensor._op(
        BatchNorm, x, gamma, beta, op_kwargs=dict(eps=eps), constant=constant
    )
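

# Minimal usage sketch (illustrative only; mirrors the docstring example above and
# uses the public import path shown there):
#
#     >>> import mygrad as mg
#     >>> from mygrad.nnet import batchnorm
#     >>> x = mg.Tensor([[1., 2.], [3., 4.], [5., 6.]])           # shape-(N=3, C=2) batch
#     >>> gamma, beta = mg.Tensor([2., 2.]), mg.Tensor([0., 1.])  # per-channel affine params
#     >>> out = batchnorm(x, gamma=gamma, beta=beta, eps=1e-5)
#     >>> out.backward()  # gradients propagate back to ``x``, ``gamma``, and ``beta``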