Source code for geomstats.vectorization
"""Decorator to handle vectorization.
This abstracts the backend type.
"""
import math
import geomstats.backend as gs
def _get_max_ndim_point(*point):
"""Identify point with higher dimension.
Parameters
----------
point : array-like
Returns
-------
max_ndim_point : array-like
Point with higher dimension.
"""
max_ndim_point = point[0]
for point_ in point[1:]:
if point_.ndim > max_ndim_point.ndim:
max_ndim_point = point_
return max_ndim_point
[docs]
def get_n_points(point_ndim, *point):
    """Compute the number of points.

    The number of points is the product of the batch dimensions of the
    input with the most array dimensions, i.e. of the leading dimensions
    that precede its trailing ``point_ndim`` point dimensions.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    point : array-like or None
        Point belonging to the space. ``None`` entries are ignored.

    Returns
    -------
    n_points : int
        Number of points. 1 if no input carries batch dimensions or if
        every input is None.
    """
    points = [point_ for point_ in point if point_ is not None]
    if not points:
        return 1
    # ``max`` returns the first array of maximal ndim, like
    # ``_get_max_ndim_point``.
    point_max_ndim = max(points, key=lambda point_: point_.ndim)
    # Slice with a non-negative stop: ``shape[:-0]`` is ``()``, which would
    # wrongly report a single point when point_ndim == 0.
    n_batch_dims = max(point_max_ndim.ndim - point_ndim, 0)
    return math.prod(point_max_ndim.shape[:n_batch_dims])
[docs]
def check_is_batch(point_ndim, *point):
    """Tell whether any input carries batch dimensions.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    point : array-like
        Point belonging to the space.

    Returns
    -------
    is_batch : bool
        True as soon as one input has more than ``point_ndim`` dimensions,
        False otherwise (including when no input is given).
    """
    for candidate in point:
        if candidate.ndim > point_ndim:
            return True
    return False
[docs]
def get_batch_shape(point_ndim, *point):
    """Get batch shape.

    The batch shape is taken from the input with the most array dimensions:
    its leading dimensions that precede the trailing ``point_ndim`` point
    dimensions.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    point : array-like or None
        Point belonging to the space. ``None`` entries are ignored.

    Returns
    -------
    batch_shape : tuple
        Returns the shape related with batch. () if only one point, or if
        every input is None.
    """
    points = [point_ for point_ in point if point_ is not None]
    if not points:
        return ()
    # ``max`` returns the first array of maximal ndim, like
    # ``_get_max_ndim_point``.
    point_max_ndim = max(points, key=lambda point_: point_.ndim)
    # Slice with a non-negative stop: ``shape[:-0]`` is ``()``, which would
    # hide the batch dimensions when point_ndim == 0.
    n_batch_dims = max(point_max_ndim.ndim - point_ndim, 0)
    return point_max_ndim.shape[:n_batch_dims]
[docs]
def repeat_point(point, n_reps=2, expand=False):
    """Repeat point.

    Stack ``n_reps`` copies of ``point`` along a new leading axis.

    Parameters
    ----------
    point : array-like
        Point of a space.
    n_reps : int
        Number of times the point should be repeated.
    expand : bool
        Repeat even if n_reps == 1.

    Returns
    -------
    rep_point : array-like
        Point stacked ``n_reps`` times along a new first axis. When
        ``n_reps == 1`` and ``expand`` is False, a plain copy of ``point``.
    """
    if expand or n_reps != 1:
        stacked_once = gs.expand_dims(point, 0)
        return gs.repeat(stacked_once, n_reps, axis=0)
    return gs.copy(point)
def _is_not_none(value):
"""Check if a value is None."""
return value is not None
[docs]
def repeat_out(point_ndim, out, *point, out_shape=()):
    """Repeat out shape after finding batch shape.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    out : array-like
        Output to be repeated.
    point : array-like or None
        Point belonging to the space. ``None`` entries are ignored.
    out_shape : tuple
        Indicates out shape for no batch computations.

    Returns
    -------
    out : array-like
        If no batch, or if ``out`` already has shape
        ``batch_shape + out_shape``, then input is returned. Otherwise it is
        broadcasted to ``batch_shape + out_shape``.
    """
    batch_shape = get_batch_shape(point_ndim, *point)
    target_shape = batch_shape + tuple(out_shape)

    # ``out`` already has the target shape: return it as is rather than a
    # broadcast view of it (such views are read-only with NumPy).
    if out.shape == target_shape:
        return out

    # ``shape[:-0]`` is ``()``, so an empty out_shape is handled explicitly;
    # this keeps the previous slicing behavior for every other case.
    n_out_dims = len(out_shape)
    out_batch_shape = out.shape[:-n_out_dims] if n_out_dims else ()
    if out_batch_shape != batch_shape:
        return gs.broadcast_to(out, target_shape)
    return out
[docs]
def repeat_out_multiple_ndim(
    out, point_ndim_1, points_1, point_ndim_2, points_2, out_ndim=0
):
    """Repeat out after finding batch shape.

    Differs from `repeat_out` by accepting two sets of point_ndim arrays.
    The batch shape is taken from `points_1`, or from `points_2` when
    `points_1` carries no batch dimensions.

    Parameters
    ----------
    out : array-like
        Output to be repeated.
    point_ndim_1 : int
        Point number of array dimensions.
    points_1 : tuple[array-like or None]
        Arrays of dimension point_ndim_1 or higher.
    point_ndim_2 : int
        Point number of array dimensions.
    points_2 : tuple[array-like or None]
        Arrays of dimension point_ndim_2 or higher.
    out_ndim : int
        Out number of array dimensions, i.e. number of trailing dimensions
        of `out` that describe a single output (0 for scalar outputs).

    Returns
    -------
    out : array-like
        If no batch, then input is returned. Otherwise it is broadcasted.
    """
    batch_shape = get_batch_shape(point_ndim_1, *points_1)
    if not batch_shape:
        batch_shape = get_batch_shape(point_ndim_2, *points_2)

    # Split out's shape at a non-negative index. With negative slicing,
    # ``shape[-0:]`` is the whole shape and ``shape[:-0]`` is ``()``, so an
    # already-batched out was broadcast a second time when out_ndim == 0.
    n_out_batch_dims = max(out.ndim - out_ndim, 0)
    out_batch_shape = out.shape[:n_out_batch_dims]
    out_shape = out.shape[n_out_batch_dims:]
    if out_batch_shape != batch_shape:
        return gs.broadcast_to(out, batch_shape + out_shape)
    return out
[docs]
def broadcast_to_multibatch(batch_shape_a, batch_shape_b, array_a, *array_b):
    """Broadcast to multibatch.

    Gives to both arrays batch shape `batch_shape_b + batch_shape_a`.
    Does nothing if one of the batch shapes is empty.

    Parameters
    ----------
    batch_shape_a : tuple
        Batch shape of array_a.
    batch_shape_b : tuple
        Batch shape of array_b.
    array_a : array
        Array whose leading dimensions are `batch_shape_a`.
    array_b : array
        One or more arrays whose leading dimensions are `batch_shape_b`.

    Returns
    -------
    array_a : array
        `array_a` with batch shape `batch_shape_b + batch_shape_a`.
    array_b : array or tuple[array]
        Broadcasted `array_b` with batch shape `batch_shape_b + batch_shape_a`.
        A single array if one `array_b` was given, otherwise a tuple of
        arrays (in both the broadcasting and the no-op cases).
    """
    multi_b = len(array_b) > 1
    if not batch_shape_a or not batch_shape_b:
        return (array_a, array_b) if multi_b else (array_a, array_b[0])

    # Prepending batch_shape_b directly yields batch_b + batch_a + rest.
    array_a_ = gs.broadcast_to(array_a, batch_shape_b + array_a.shape)

    # Prepending batch_shape_a gives batch_a + batch_b + rest; the batch_a
    # axes are then moved behind the batch_b axes.
    n_batch_b = len(batch_shape_b)
    indices_in = list(range(len(batch_shape_a)))
    indices_out = [index + n_batch_b for index in indices_in]
    array_b_ = tuple(
        gs.moveaxis(
            gs.broadcast_to(array, batch_shape_a + array.shape),
            indices_in,
            indices_out,
        )
        for array in array_b
    )
    return (array_a_, array_b_) if multi_b else (array_a_, array_b_[0])