Source code for mygrad.nnet.activations.glu
from typing import Optional

from numpy import ndarray

from mygrad.math.arithmetic.funcs import multiply
from mygrad.tensor_base import Tensor
from mygrad.typing import ArrayLike

from .sigmoid import sigmoid

def glu(x: ArrayLike, axis: int = -1, *, constant: Optional[bool] = None) -> Tensor:
"""Returns the Gated Linear Unit A * σ(B), where A and B are split from `x`.
Parameters
----------
x : ArrayLike
The input.
axis : int, optional (default=-1)
The axis along which to split the input in half and apply the GLU.
constant : Optional[bool]
If ``True``, the returned tensor is a constant (it
does not back-propagate a gradient).
Returns
-------
mygrad.Tensor
The result of applying the Gated Linear Unit elementwise to the input.
Notes
-----
The Gated Linear Unit was proposed in the paper
"Language Modeling with Gated Convolutional Networks"
Yann Dauphin, Angela Fan, Michael Auli, David Grangier
available at https://arxiv.org/abs/1612.08083
The GLU operation splits the input `x` in half along `axis`, storing the first half in A and the
second in B. The return value is then A ⊙ σ(B), where ⊙ is elementwise multiplication and σ is
the sigmoid function.
    Examples
    --------
    >>> import mygrad as mg
    >>> from mygrad.nnet.activations import glu
    >>> x = mg.arange(-5., 5.)
    >>> x
    Tensor([-5., -4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.])
    >>> y = glu(x); y
    Tensor([-2.5       , -2.92423431, -2.64239123, -1.90514825, -0.98201379])
    >>> y.backward()
    >>> x.grad
    array([ 0.5       ,  0.73105858,  0.88079708,  0.95257413,  0.98201379,
           -1.25      , -0.78644773, -0.31498076, -0.09035332, -0.01766271])
    """
    if isinstance(axis, (ndarray, Tensor)):
        axis = axis.item()

    if not isinstance(axis, int):
        raise TypeError(
            f"`axis` must be an integer-valued scalar, got {axis} (type {type(axis)})"
        )

    # Build index tuples that slice out the first and second halves of `x`
    # along `axis`, leaving every other axis untouched
    first_idx = list(slice(None) for _ in x.shape)
    second_idx = list(slice(None) for _ in x.shape)
    first_idx[axis] = slice(0, x.shape[axis] // 2)
    second_idx[axis] = slice(x.shape[axis] // 2, None)

    first_half = x[tuple(first_idx)]
    second_half = x[tuple(second_idx)]

    # An axis of odd length cannot be split into two equal halves
    if first_half.shape != second_half.shape:
        raise ValueError(
            f"The shapes after splitting must be the same but got {first_half.shape} "
            f"and {second_half.shape}"
        )

    return multiply(first_half, sigmoid(second_half), constant=constant)
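

# A minimal usage sketch, not part of the library source: it illustrates the
# split-and-gate behavior described in the Notes section above. It assumes the
# public `mygrad` namespace (imported here as `mg`) provides `arange` and that
# Tensors expose their underlying NumPy array via `.data`, as in the docstring
# example.
if __name__ == "__main__":
    import numpy as np

    import mygrad as mg

    # Split a (4, 2) input along axis=0: A = x[:2], B = x[2:]
    x = mg.arange(8.0).reshape(4, 2)
    out = glu(x, axis=0)
    assert out.shape == (2, 2)

    # The result matches A * sigmoid(B) computed directly with NumPy
    a, b = x.data[:2], x.data[2:]
    assert np.allclose(out.data, a * (1 / (1 + np.exp(-b))))

    # An axis of odd length cannot be split into two equal halves
    try:
        glu(mg.arange(3.0))
    except ValueError:
        pass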