from numbers import Integral
from typing import Optional, Tuple, Union
import numpy as np
from mygrad.nnet.layers.utils import sliding_window_view
from mygrad.operation_base import Operation
from mygrad.tensor_base import Tensor
from mygrad.typing import ArrayLike
__all__ = ["conv_nd"]
class ConvND(Operation):
def __call__(self, x, w, *, stride, padding=0, dilation=1):
self.variables = (x, w)
# x ... data: (N, C, X0, X1, ...)
# w ... filters: (F, C, W0, W1, ...)
x = x.data
w = w.data
assert x.ndim > 2
assert x.ndim == w.ndim
assert (
w.shape[1] == x.shape[1]
), "The channel-depth of the batch and filters must agree"
num_conv_channels = w.ndim - 2
x_shape = np.array(
x.shape[2:]
) # (X0, ...): shape of the channels being convolved over
w_shape = np.array(w.shape[2:]) # (W0, ...): shape of each conv filter
dilation = (
np.array((dilation,) * num_conv_channels)
if isinstance(dilation, Integral)
else np.array(dilation, dtype=int)
)
assert len(dilation) == num_conv_channels and all(
d >= 1 and isinstance(d, Integral) for d in dilation
)
padding = (
np.array((padding,) * num_conv_channels)
if isinstance(padding, Integral)
else np.array(padding, dtype=int)
)
assert len(padding) == num_conv_channels and all(
p >= 0 and isinstance(p, Integral) for p in padding
)
stride = (
np.array((stride,) * num_conv_channels)
if isinstance(stride, Integral)
else np.asarray(stride, dtype=int)
)
assert len(stride) == num_conv_channels and all(
s >= 1 and isinstance(s, Integral) for s in stride
)
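# output-grid size along each convolved dimension:
#   G = (X + 2*padding - ((W - 1)*dilation + 1)) / stride + 1
# (each G must come out to a positive integer for the conv to be valid)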
out_shape = (
x_shape + 2 * padding - ((w_shape - 1) * dilation + 1)
) / stride + 1
if not all(i.is_integer() and i > 0 for i in out_shape):
msg = "Stride and kernel dimensions are incompatible: \n"
msg += f"Input dimensions: {tuple(x_shape)}\n"
msg += f"Stride dimensions: {tuple(stride)}\n"
msg += f"Kernel dimensions: {tuple(w_shape)}\n"
msg += f"Padding dimensions: {tuple(padding)}\n"
msg += f"Dilation dimensions: {tuple(dilation)}\n"
raise ValueError(msg)
self.padding = padding
self.stride = stride
self.dilation = dilation
# symmetric 0-padding for X0, X1, ... dimensions
axis_pad = tuple((i, i) for i in (0, 0, *padding))
x = np.pad(x, axis_pad, mode="constant") if sum(padding) else x
# (G0, ...) is the tuple of grid-positions for placing each window (not including stride)
# (N, C, X0, ...) -> (G0, ..., N, C, W0, ...)
windowed_data = sliding_window_view(
x, window_shape=w_shape, step=self.stride, dilation=self.dilation
)
w_conv_channels = list(range(1, num_conv_channels + 2)) # C, W0, ...
window_conv_channels = [
i + 1 + num_conv_channels # C, W0, ...
for i in range(num_conv_channels + 1)
]
# (F, C, W0, ...) ⋆ (G0, ..., N, C, W0, ...) -> (F, G0, ..., N)
conv_out = np.tensordot(
w, windowed_data, axes=[w_conv_channels, window_conv_channels]
)
# (F, G0, ..., N) -> (N, F, G0, ...)
out = np.moveaxis(conv_out, source=-1, destination=0)
return out if out.flags["C_CONTIGUOUS"] else np.ascontiguousarray(out)
def backward_var(self, grad, index, **kwargs):
"""Computes dX, where X is the data batch
Parameters
----------
grad : numpy.ndarray, shape=(N, F, G0, ...)"""
x, w = (i.data for i in self.variables)
num_conv_channels = grad.ndim - 2
if index == 0: # backprop through x
x_shape = x.shape[:2] + tuple(
i + 2 * p for i, p in zip(x.shape[-num_conv_channels:], self.padding)
)
dx = np.zeros(x_shape, dtype=x.dtype) # (N, C, X0, ...)
# `gp` stores all of the various broadcast multiplications of each grad
# element against the conv filter.
# (N, F, G0, ...) -tdot- (F, C, W0, ...) --> (N, G0, ..., C, W0, ...)
gp = np.tensordot(grad, w, axes=[[1], [0]])
for ind in np.ndindex(grad.shape[-num_conv_channels:]):
# ind: (g0, ...) - grid-position of filter placement
slices = tuple(
slice(i * s, i * s + w_dim * d, d)
for i, w_dim, s, d in zip(ind, w.shape[2:], self.stride, self.dilation)
)
# Add (grad-element * filter) to each appropriate window position in `dx`
# dx[N, C, g0*s0 : g0*s0 + w0*d0 : d0, (...)] += gp[N, g0, (...), C, W0, (...)]
dx[(..., *slices)] += gp[(slice(None), *ind, ...)]
# remove padding from dx
if sum(self.padding):
no_pads = tuple(slice(p, -p if p else None) for p in self.padding)
dx = dx[(..., *no_pads)]
return dx
else: # backprop through w (the filter bank)
# symmetric 0-padding for the X0, X1, ... convolved dimensions
axis_pad = tuple((i, i) for i in (0, 0, *self.padding))
x = np.pad(x, axis_pad, mode="constant") if sum(self.padding) else x
# (G0, ...) is the tuple of grid-indices for placing each window (not including stride)
# (N, C, X0, ...) -> (G0, ..., N, C, W0, ...)
windowed_data = sliding_window_view(
x, window_shape=w.shape[2:], step=self.stride, dilation=self.dilation
)
# (N, F, G0, ...) -tdot- (G0, ..., N, C, W0, ...) --> (F, C, W0, ...)
grad_axes = list(range(2, num_conv_channels + 2)) + [0] # (G0, ..., N)
window_axes = list(range(num_conv_channels + 1)) # (G0, ..., N)
return np.tensordot(grad, windowed_data, axes=[grad_axes, window_axes])
def conv_nd(
x: ArrayLike,
filter_bank: ArrayLike,
*,
stride: Union[int, Tuple[int, ...]],
padding: Union[int, Tuple[int, ...]] = 0,
dilation: Union[int, Tuple[int, ...]] = 1,
constant: Optional[bool] = None,
) -> Tensor:
"""Use ``filter_bank`` (``w``) to perform strided N-dimensional neural network-style
convolutions (see Notes) over ``x``.::
f(x, w) -> x ⋆ w
shapes:
(N, C, X0, ...) ⋆ (F, C, W0, ...) -> (N, F, G0, ...)
``x`` represents a batch of data over which the filters
are convolved. Specifically, it must be a tensor of shape
:math:`(N, C, X_0, ...)`, where :math:`N` is the number of samples in the batch,
C is the channel-depth of each datum, and :math:`(X_0, ...)` are the
dimensions over which the filters are convolved. Accordingly,
each filter must have a channel depth of :math:`C`.
Thus convolving :math:`F` filters, each with a shape :math:`(C, W_0, ...)`,
over the data batch will produce a tensor of shape
:math:`(N, F, G_0, ...)`, where :math:`(G_0, ...)` is the shape of the grid
commensurate with the filter placements.
Parameters
----------
x : ArrayLike, shape=(N, C, X0, ...)
The data batch to be convolved over.
filter_bank : ArrayLike, shape=(F, C, W0, ...)
The filters used to perform the convolutions.
stride : Union[int, Tuple[int, ...]]
(keyword-only argument) The step-size with which each
filter is placed along each convolved dimension during the
convolution. The tuple indicates (stride-0, ...). If a
single integer is provided, this stride is used for all
convolved dimensions.
padding : Union[int, Tuple[int, ...]]
(keyword-only argument) The number of zeros to be padded
to both ends of each convolved dimension, respectively.
If a single integer is provided, this padding is used for
all of the convolved axes.
dilation : Union[int, Tuple[int, ...]], optional (default=1)
(keyword-only argument) The spacing used when placing kernel
elements along the data. E.g. for a 1D convolution the ith
placement of the kernel multiplied against the dilated-window:
``x[:, :, i*s:(i*s + w*d):d]``, where ``s`` is
the stride, ``w`` is the kernel-size, and ``d`` is the dilation factor.
If a single integer is provided, that dilation value is used for all
of the convolved axes.
constant : Optional[bool]
(keyword-only argument) If ``True``, the resulting Tensor is a constant.
Returns
-------
Tensor, shape=(N, F, G0, ...)
The result of each filter being convolved over each datum in
the batch.
Notes
-----
- The filters are *not* flipped by this operation, meaning that
an auto-correlation is being performed rather than a true convolution.
- Only 'valid' filter placements – where the filters overlap
completely with the (padded) data – are permitted.
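- The output-grid size along each convolved dimension is
``G = (X + 2 * padding - ((W - 1) * dilation + 1)) / stride + 1``;
a ``ValueError`` is raised if this is not a positive integer for every dimension.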
Examples
--------
Here we perform a 1D convolution of a constant-valued kernel, ``k``, with a
'square-wave' signal, ``x``, using stride-1. Note that, because we are constrained
to deep learning-style convolutions, we prepend the dimensions
:math:`(N=1, C=1)` to ``x`` and :math:`(F=1, C=1)` to ``k``. That is,
we are performing a convolution of one single-channeled signal with
one kernel.
See that this convolution produces the expected triangle-shaped
response. The shape of the resulting tensor is :math:`(N=1, F=1, G_0=12)`.
That is, the length-5 kernel can be placed in 12 valid positions, using a
stride of 1.
>>> import mygrad as mg
>>> from mygrad.nnet import conv_nd
>>> x = mg.zeros((1, 1, 16)) # a square-wave signal
>>> x[..., 5:11] = 1
>>> k = mg.ones((1, 1, 5)) # a constant-valued kernel
>>> conv_nd(x, k, stride=1) # performing a stride-1, 1D convolution
Tensor([[[0., 1., 2., 3., 4., 5., 5., 4., 3., 2., 1., 0.]]], dtype=float32)
Back-propagating through the (summed) convolution:
>>> conv_nd(x, k, stride=1).sum().backward() # sum to a scalar to perform back-prop
>>> x.grad # d(summed_conv)/dx
array([[[1., 2., 3., 4., 5., 5., 5., 5., 5., 5., 5., 5., 4., 3., 2., 1.]]],
dtype=float32)
>>> k.grad # d(summed_conv)/dk
array([[[6., 6., 6., 6., 6.]]])
.. plot::
>>> import mygrad as mg
>>> from mygrad.nnet import conv_nd
>>> import matplotlib.pyplot as plt
>>> kernel = mg.ones(5) # the kernel values, for plotting
>>> x = mg.zeros((1, 1, 16)) # a square-wave signal
>>> x[..., 5:11] = 1
>>> k = mg.ones((1, 1, 5)) # a constant-valued kernel
>>> y = conv_nd(x, k, stride=1) # performing a stride-1, 1D convolution
>>> plt.title("conv(f, g); stride: 1")
>>> y.backward()
>>> plt.plot(x.data[0,0], label="f", ls="--", lw=3, drawstyle='steps-pre')
>>> plt.plot(kernel, label="g", ls="--", lw=3, drawstyle='steps-pre')
>>> plt.plot(y.data[0,0], label="f * g")
>>> plt.plot(mg.arange(16.), x.grad[0, 0], label="d[sum(f * g)]/df")
>>> plt.legend()
>>> plt.grid()
>>> plt.show()
Let's apply an edge-detection kernel to each color channel of an RGB image.
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import matplotlib.image as mpimg
>>> from mygrad.nnet.layers import conv_nd
>>> # A shape-(H, W, 3) RGB image
>>> img = mpimg.imread('../_static/meerkat.png')
>>> # We'll treat this like a batch of three greyscale images
>>> # where each "image" is actually a color channel
>>> # shape-(H, W, 3) -> shape-(3, 1, H, W)
>>> x = img.transpose(2, 0, 1)[:, None, :, :]
>>> # edge detection kernel
>>> kernel = np.array([[-1, -1, -1],
... [-1, 8, -1],
... [-1, -1, -1]])
>>> # (Hf, Wf) --> (1, 1, Hf, Wf)
>>> kernel = kernel.reshape(1, 1, *kernel.shape)
>>> # conv: (3, 1, H, W) w/ (1, 1, Hf, Wf) --> (3, 1, H', W')
>>> # squeeze + transpose: (3, 1, H', W') --> (H', W', 3)
>>> processed = conv_nd(x, kernel, stride=(1, 1))
>>> processed = processed.data.squeeze().transpose(1, 2, 0)
>>> fig, ax = plt.subplots()
>>> ax.imshow(img)
>>> fig, ax = plt.subplots()
>>> ax.imshow(processed)
.. plot::
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> import matplotlib.image as mpimg
>>> from mygrad.nnet.layers import conv_nd
>>> img = mpimg.imread('../_static/meerkat.png')
>>> # edge detection
>>> kernel = np.array([[-1, -1, -1],
... [-1, 8, -1],
... [-1, -1, -1]])
>>> x = img.transpose(2,0,1)[:, None, :, :]
>>> # (Hf, Wf) --> (1, 1, Hf, Wf)
>>> kernel = kernel.reshape(1, 1, *kernel.shape)
>>> # conv: (C, 1, H, W) w/ (1, 1, Hf, Wf) --> (C, 1, H', W')
>>> # squeeze + transpose: (C, 1, H', W') --> (H', W', C)
>>> processed = conv_nd(x, kernel, stride=(1, 1)).data.squeeze().transpose(1, 2, 0)
>>> fig, ax = plt.subplots()
>>> ax.imshow(img)
>>> fig, ax = plt.subplots()
>>> ax.imshow(processed)
Now, let's demonstrate a more typical usage for ``conv_nd`` in the context of
neural networks. ``x`` will represent 10, 32x32 RGB images, and we will use
5 distinct 2x2 kernels to convolve over each of these images. Note that
each kernel must possess 3 channels - one for each RGB channel.
That is, we will be performing N x F channel-wise 2D convolutions. Supposing
that we don't want the kernel placements to overlap, we can use a stride of 2. In
total, this will produce a shape-:math:`(N=10, F=5, G_0=16, G_1=16)` tensor as a
result.
>>> import mygrad as mg
>>> x = mg.random.rand(10, 3, 32, 32)  # creating 10 random 32x32 RGB images
>>> k = mg.random.rand(5, 3, 2, 2)  # creating 5 random 3-channel 2x2 kernels
Given the shapes of ``x`` and ``k``, ``conv_nd`` automatically executes a 2D convolution:
>>> conv_nd(x, k, stride=2).shape
(10, 5, 16, 16)
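Zero-padding enlarges the output grid according to the formula in the Notes;
for instance, padding each 32x32 image by one pixel on each side yields a
17x17 grid for these stride-2, 2x2 kernels (a shapes-only illustration):
>>> conv_nd(x, k, stride=2, padding=1).shape
(10, 5, 17, 17)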
Extrapolating further, ``conv_nd`` is capable of performing ND convolutions!
Performing a convolution over a batch of single-channel, "spatial-3D" tensor data:
>>> # shape-(N=1, C=1, X=10, Y=12, Z=10)
>>> x = mg.random.rand(1, 1, 10, 12, 10)
>>> # shape-(F=2, C=1, Wx=3, Wy=1, Wz=2)
>>> k = mg.random.rand(2, 1, 3, 1, 2)
>>> conv_nd(x, k, stride=1).shape
(1, 2, 8, 12, 9)
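Dilation widens the effective span of each kernel dimension to
``(W - 1) * dilation + 1`` without adding any kernel elements. A shapes-only
sketch, convolving a length-2 kernel with ``dilation=3`` (an effective span of 4)
over a length-8 signal:
>>> x = mg.random.rand(1, 1, 8)
>>> k = mg.random.rand(1, 1, 2)
>>> conv_nd(x, k, stride=1, dilation=3).shape
(1, 1, 5)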
"""
if x.ndim < 3:
raise ValueError(
f"`x` must possess at least three " f"dimensions, got {x.ndim} dimensions"
)
if x.ndim != filter_bank.ndim:
raise ValueError(
f"`x` ({x.ndim}-dimensions) must have the same dimensionality as "
f"`filter_bank` ({filter_bank.ndim}-dimensions)"
)
if filter_bank.shape[1] != x.shape[1]:
raise ValueError(
f"`x.shape[1]` ({x.shape[1]}) must match `filter_bank.shape[1]` ({filter_bank.shape[1]})"
)
return Tensor._op(
ConvND,
x,
filter_bank,
op_kwargs={"stride": stride, "padding": padding, "dilation": dilation},
constant=constant,
)