Source code for mygrad.tensor_manip.tiling.funcs

from typing import Optional, Sequence, Union

from mygrad.tensor_base import Tensor, implements_numpy_override
from mygrad.typing import ArrayLike

from .ops import Repeat

__all__ = ["repeat"]


@implements_numpy_override()
def repeat(
    a: ArrayLike,
    repeats: Union[int, Sequence[int]],
    axis: Optional[int] = None,
    *,
    constant: Optional[bool] = None,
) -> Tensor:
    """Duplicate the elements of a tensor a given number of times.

    This docstring was adapted from ``numpy.repeat``

    Parameters
    ----------
    a : ArrayLike
        The tensor whose elements are to be repeated.

    repeats : Union[int, Sequence[int]]
        How many times each element is repeated. ``repeats`` is
        broadcasted to fit the shape of the given axis.

    axis : Optional[int]
        The axis along which values are repeated. When left unspecified,
        the flattened input is used and a flat tensor is returned.

    constant : Optional[bool]
        If ``True``, the resulting tensor is treated as a constant, and thus
        does not facilitate back propagation (i.e. ``constant.grad`` will
        always return ``None``).

        Defaults to ``False`` for float-type data.
        Defaults to ``True`` for integer-type data.

        Integer-type tensors must be constant.

    Returns
    -------
    repeated_tensor : Tensor
        A tensor whose shape matches that of `a`, except along the
        given axis.

    Examples
    --------
    >>> import mygrad as mg
    >>> mg.repeat(3, 4)
    Tensor([3, 3, 3, 3])
    >>> x = mg.Tensor([[1, 2], [3, 4]])
    >>> mg.repeat(x, 2)
    Tensor([1, 1, 2, 2, 3, 3, 4, 4])
    >>> mg.repeat(x, 3, axis=1)
    Tensor([[1, 1, 1, 2, 2, 2],
            [3, 3, 3, 4, 4, 4]])
    >>> mg.repeat(x, [1, 2], axis=0)
    Tensor([[1, 2],
            [3, 4],
            [3, 4]])
    """
    # The Repeat operation receives its non-tensor settings positionally,
    # in (repeats, axis) order.
    tiling_args = (repeats, axis)
    return Tensor._op(Repeat, a, op_args=tiling_args, constant=constant)