# Source code for geomstats.vectorization

```"""Decorator to handle vectorization.

This abstracts the backend type.
"""

import math

import geomstats.backend as gs

def _get_max_ndim_point(*point):
"""Identify point with higher dimension.

Parameters
----------
point : array-like

Returns
-------
max_ndim_point : array-like
Point with higher dimension.
"""
max_ndim_point = point[0]
for point_ in point[1:]:
if point_.ndim > max_ndim_point.ndim:
max_ndim_point = point_

return max_ndim_point

[docs]
def get_n_points(point_ndim, *point):
    """Compute the number of points.

    The number of points is the product of the batch dimensions. These are
    the leading axes of the highest-dimensional input that precede its
    trailing ``point_ndim`` axes.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    point : array-like
        Point belonging to the space.

    Returns
    -------
    n_points : int
        Number of points (1 when no input carries a batch axis).
    """
    point_max_ndim = _get_max_ndim_point(*point)
    # Use an explicit end index rather than ``shape[:-point_ndim]``. The
    # latter becomes ``shape[:0] == ()`` when ``point_ndim == 0``, which
    # would report a single point for a batch of scalars.
    n_batch_axes = max(point_max_ndim.ndim - point_ndim, 0)
    return math.prod(point_max_ndim.shape[:n_batch_axes])

[docs]
def check_is_batch(point_ndim, *point):
    """Tell whether any input carries extra (batch) dimensions.

    Parameters
    ----------
    point_ndim : int
        Number of array dimensions of a single point.
    point : array-like
        Point(s) belonging to the space.

    Returns
    -------
    is_batch : bool
        True as soon as one input has more than ``point_ndim`` dimensions.
        False otherwise, including when no input is given.
    """
    for candidate in point:
        if candidate.ndim > point_ndim:
            return True
    return False

[docs]
def get_batch_shape(point_ndim, *point):
    """Get batch shape.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    point : array-like or None
        Point belonging to the space. ``None`` entries are ignored.

    Returns
    -------
    batch_shape : tuple
        Leading shape of the highest-dimensional input, excluding its
        trailing ``point_ndim`` axes. () if only one point or if every
        input is None.
    """
    points = [point_ for point_ in point if _is_not_none(point_)]
    if not points:
        return ()
    point_max_ndim = _get_max_ndim_point(*points)
    # Use an explicit end index rather than ``shape[:-point_ndim]``. The
    # latter collapses to ``()`` when ``point_ndim == 0`` and would hide a
    # batch of scalars.
    n_batch_axes = max(point_max_ndim.ndim - point_ndim, 0)
    return point_max_ndim.shape[:n_batch_axes]

[docs]
def repeat_point(point, n_reps=2, expand=False):
    """Stack copies of a point along a new leading axis.

    Parameters
    ----------
    point : array-like
        Point of a space.
    n_reps : int
        Number of copies to stack.
    expand : bool
        If True, add the leading axis even when ``n_reps == 1``.

    Returns
    -------
    rep_point : array-like
        Array of shape ``(n_reps,) + point.shape``. When ``n_reps == 1`` and
        ``expand`` is False, a plain copy of ``point`` is returned instead.
    """
    if expand or n_reps != 1:
        stacked = gs.expand_dims(point, 0)
        return gs.repeat(stacked, n_reps, axis=0)
    return gs.copy(point)

def _is_not_none(value):
"""Check if a value is None."""
return value is not None

[docs]
def repeat_out(point_ndim, out, *point, out_shape=()):
    """Repeat out shape after finding batch shape.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    out : array-like
        Output to be repeated.
    point : array-like or None
        Point belonging to the space. ``None`` entries are ignored.
    out_shape : tuple
        Indicates out shape for no batch computations.

    Returns
    -------
    out : array-like
        If no batch, then input is returned. Otherwise, if ``out`` does not
        already carry the batch shape, it is broadcast to
        ``batch_shape + out_shape``.
    """
    # get_batch_shape already skips None entries, so no pre-filtering here.
    batch_shape = get_batch_shape(point_ndim, *point)
    if not batch_shape:
        return out

    out_shape = tuple(out_shape)
    # Split with an explicit index. ``out.shape[:-len(out_shape)]`` is
    # ``()`` for the default ``out_shape=()`` instead of the whole shape.
    n_out_batch_axes = max(len(out.shape) - len(out_shape), 0)
    if out.shape[:n_out_batch_axes] == batch_shape:
        return out

    return gs.broadcast_to(out, tuple(batch_shape) + out_shape)

[docs]
def repeat_out_multiple_ndim(
    out, point_ndim_1, points_1, point_ndim_2, points_2, out_ndim=0
):
    """Repeat out after finding batch shape.

    Differs from `repeat_out` by accepting two sets of point_ndim arrays.
    The batch shape is taken from ``points_1``. Only if that is empty is it
    taken from ``points_2``.

    Parameters
    ----------
    out : array-like
        Output to be repeated.
    point_ndim_1 : int
        Point number of array dimensions.
    points_1 : tuple[array-like or None]
        Arrays of dimension point_ndim_1 or higher.
    point_ndim_2 : int
        Point number of array dimensions.
    points_2 : tuple[array-like or None]
        Arrays of dimension point_ndim_2 or higher.
    out_ndim : int
        Out number of array dimensions.

    Returns
    -------
    out : array-like
        If no batch, then input is returned. Otherwise, if ``out`` does not
        already carry the batch shape, it is broadcast to ``batch_shape``
        followed by its trailing ``out_ndim`` axes.
    """
    batch_shape = get_batch_shape(point_ndim_1, *points_1)
    if not batch_shape:
        batch_shape = get_batch_shape(point_ndim_2, *points_2)
    if not batch_shape:
        return out

    # Split ``out.shape`` into (batch axes, trailing out_ndim axes) with an
    # explicit index. Negative slicing misbehaves when ``out_ndim == 0``:
    # ``shape[-0:]`` is the whole shape and ``shape[:-0]`` is empty.
    n_out_batch_axes = max(len(out.shape) - out_ndim, 0)
    if out.shape[:n_out_batch_axes] == batch_shape:
        return out

    out_shape = tuple(out.shape[n_out_batch_axes:])
    return gs.broadcast_to(out, tuple(batch_shape) + out_shape)

[docs]

Gives to both arrays batch shape `batch_shape_b + batch_shape_a`.

Does nothing if one of the batch shapes is empty.

Parameters
----------
batch_shape_a : tuple
Batch shape of array_a.
batch_shape_b : tuple
Batch shape of array_b.
array_a : array
array_b : array
"""
multi_b = len(array_b) > 1

if not batch_shape_a or not batch_shape_b:
return (array_a, array_b) if multi_b else (array_a, array_b[0])

array_a_ = gs.broadcast_to(array_a, batch_shape_b + array_a.shape)

n_batch_b = len(batch_shape_b)
indices_in = list(range(len(batch_shape_a)))
indices_out = [index + n_batch_b for index in indices_in]

array_b_ = []
for array in array_b:
array_b_aux = gs.broadcast_to(array, batch_shape_a + array.shape)
array_b_.append(gs.moveaxis(array_b_aux, indices_in, indices_out))

return (array_a_, array_b_) if multi_b else (array_a_, array_b_[0])

```