Source code for geomstats.vectorization

"""Decorator to handle vectorization.

This abstracts the backend type.
"""

import math

import geomstats.backend as gs

# Input types that denote array-like points; values of these types are
# converted to their fully-vectorized form by the vectorize_* helpers.
POINT_TYPES = ["scalar", "vector", "matrix"]
# Placeholder type that adapt_types replaces with the value of 'point_type'.
FLEXIBLE_TYPE = "point"
# Input types whose values are passed through without any vectorization.
OTHER_TYPES = ["point_type", "else"]

# Number of array dimensions each point type is brought to by the vectorize_* helpers.
POINT_TYPES_TO_NDIMS = {"scalar": 2, "vector": 2, "matrix": 3}

# %-template of the ValueError message raised for an unrecognized input type.
ERROR_MSG = "Invalid type: %s."


def _get_max_ndim_point(*point):
    """Return the input that has the largest number of array dimensions.

    Parameters
    ----------
    point : array-like
        One or more arrays exposing an ``ndim`` attribute.

    Returns
    -------
    max_ndim_point : array-like
        Input with the largest ``ndim``. On ties, the earliest input wins.
    """
    best = point[0]
    for candidate in point[1:]:
        # Strict comparison: an equal ndim never displaces an earlier input.
        best = candidate if candidate.ndim > best.ndim else best
    return best


def get_n_points(point_ndim, *point):
    """Compute the number of points.

    The count is the product of the leading (batch) dimensions of the
    input that has the most array dimensions.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    point : array-like
        Point belonging to the space.

    Returns
    -------
    n_points : int
        Number of points (1 when there is no batch dimension).
    """
    shape = _get_max_ndim_point(*point).shape
    # Count the batch dims explicitly instead of slicing with
    # ``shape[:-point_ndim]``: for ``point_ndim == 0`` that slice is
    # ``shape[:0]``, i.e. always empty, which would report a single point.
    n_batch_dims = max(len(shape) - point_ndim, 0)
    return math.prod(shape[:n_batch_dims])
def check_is_batch(point_ndim, *point):
    """Tell whether at least one input carries batch dimensions.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    point : array-like
        Point belonging to the space.

    Returns
    -------
    is_batch : bool
        True if any input has more array dimensions than ``point_ndim``.
    """
    for candidate in point:
        if candidate.ndim > point_ndim:
            return True
    return False
def get_batch_shape(point_ndim, *point):
    """Get batch shape.

    ``None`` inputs are ignored. The batch shape is taken from the
    remaining input that has the most array dimensions.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    point : array-like or None
        Point belonging to the space.

    Returns
    -------
    batch_shape : tuple
        Returns the shape related with batch. () if only one point, or if
        every input is None.
    """
    point = [point_ for point_ in point if point_ is not None]
    if not point:
        return ()

    shape = _get_max_ndim_point(*point).shape
    # Count the batch dims explicitly instead of slicing with
    # ``shape[:-point_ndim]``: for ``point_ndim == 0`` that slice is
    # ``shape[:0]``, which would drop the batch shape entirely.
    n_batch_dims = max(len(shape) - point_ndim, 0)
    return shape[:n_batch_dims]
def repeat_point(point, n_reps=2, expand=False):
    """Stack copies of a point along a new leading axis.

    Parameters
    ----------
    point : array-like
        Point of a space.
    n_reps : int
        Number of times the point should be repeated.
    expand : bool
        If True, the leading axis is added even when ``n_reps == 1``.

    Returns
    -------
    rep_point : array-like
        A copy of ``point`` when ``n_reps == 1`` and ``expand`` is False;
        otherwise ``point`` repeated ``n_reps`` times along a new axis 0.
    """
    if n_reps == 1 and not expand:
        return gs.copy(point)

    with_leading_axis = gs.expand_dims(point, 0)
    return gs.repeat(with_leading_axis, n_reps, axis=0)
def _is_not_none(value):
    """Check that a value is not None.

    Returns True for any value other than None; used as a ``filter``
    predicate to drop missing points.
    """
    return value is not None
def repeat_out(point_ndim, out, *point, out_shape=()):
    """Repeat out shape after finding batch shape.

    Parameters
    ----------
    point_ndim : int
        Point number of array dimensions.
    out : array-like
        Output to be repeated
    point : array-like or None
        Point belonging to the space.
    out_shape : tuple
        Indicates out shape for no batch computations.

    Returns
    -------
    out : array-like
        If no batch, or if ``out`` already carries the batch shape, then
        input is returned. Otherwise it is broadcasted to
        ``batch_shape + out_shape``.
    """
    point = [point_ for point_ in point if point_ is not None]
    batch_shape = get_batch_shape(point_ndim, *point)

    if out_shape:
        out_batch_shape = out.shape[: -len(out_shape)]
    elif not batch_shape:
        # Scalar-per-point output and no batch among the points: nothing
        # to broadcast to, the input is returned untouched.
        return out
    else:
        # ``out.shape[:-0]`` is ``out.shape[:0]`` (always empty), so with an
        # empty ``out_shape`` the whole shape of ``out`` is its batch shape.
        # Comparing it directly avoids re-broadcasting an output that is
        # already batched (presumably a read-only view with NumPy -- confirm
        # for the active backend).
        out_batch_shape = out.shape

    if out_batch_shape != batch_shape:
        return gs.broadcast_to(out, batch_shape + out_shape)
    return out
def get_types(input_types, args, kwargs):
    """Extract the types of args, kwargs, optional kwargs and output.

    ``input_types`` is split positionally: the first ``len(args)`` entries
    describe ``args``, the next ``len(kwargs)`` entries describe ``kwargs``
    and whatever remains describes optional kwargs that were not passed.
    If the last remaining entry contains ``"output_"`` it is an output
    marker: it is removed from the optional kwargs types and only used to
    determine ``is_scal``.

    Parameters
    ----------
    input_types : list
        List of inputs' input_types, including for optional inputs.
    args : tuple
        Args of a function.
    kwargs : dict
        Kwargs of a function.

    Returns
    -------
    args_types : list
        Types of args.
    kwargs_types : list
        Types of kwargs.
    opt_kwargs_types : list
        Types of optional kwargs.
    is_scal : bool
        Boolean determining if the output is a scalar. True unless an
        output marker other than 'output_scalar' is present.
    """
    len_args = len(args)
    len_total = len_args + len(kwargs)

    args_types = input_types[:len_args]
    kwargs_types = input_types[len_args:len_total]

    opt_kwargs_types = []
    is_scal = True
    if len(input_types) > len_total:
        opt_kwargs_types = input_types[len_total:]
        last_input_type = input_types[-1]
        if "output_" in last_input_type:
            # Any output marker is dropped from the optional kwargs types,
            # including 'output_scalar', which is not a kwarg type either.
            is_scal = last_input_type == "output_scalar"
            opt_kwargs_types = input_types[len_total:-1]
    return (args_types, kwargs_types, opt_kwargs_types, is_scal)
def adapt_types(args_types, kwargs_types, opt_kwargs_types, args, kwargs):
    """Resolve flexible input types from the 'point_type' argument.

    Some functions take array-like arguments that are either 'vector' or
    'matrix' depending on the value of their 'point_type' argument. This
    function looks up that value and substitutes it for every flexible
    type found in the args and kwargs types.

    The 'point_type' value is searched, in order, in the positional args,
    in the kwargs, and finally in ``args[0].default_point_type`` when
    'point_type' is an optional kwarg that was not passed. In the latter
    case the value is also stored in ``kwargs`` (which is modified in
    place) and 'point_type' is appended to ``kwargs_types``.

    Parameters
    ----------
    args_types : list
        Types of args.
    kwargs_types : list
        Types of kwargs.
    opt_kwargs_types : list
        Types of optional kwargs.
    args : tuple
        Args of a function.
    kwargs : dict
        Kwargs of a function.

    Returns
    -------
    args_types : list
        Adapted types of args.
    kwargs_types : list
        Adapted types of kwargs.
    kwargs : dict
        Kwargs of the function, with 'point_type' added if it was taken
        from the default of ``args[0]``.
    """
    marker = "point_type"

    if marker in args_types:
        input_type = args[args_types.index(marker)]
    elif marker in kwargs_types:
        input_type = kwargs[marker]
    elif marker in opt_kwargs_types:
        # Not passed by the caller: fall back on the default of the object
        # given as first positional arg and make it an explicit kwarg.
        input_type = args[0].default_point_type
        kwargs[marker] = input_type
        kwargs_types.append(marker)
    else:
        # No 'point_type' anywhere: nothing to adapt.
        return args_types, kwargs_types, kwargs

    def _resolve(types):
        return [input_type if pt == FLEXIBLE_TYPE else pt for pt in types]

    return _resolve(args_types), _resolve(kwargs_types), kwargs
def get_initial_shapes(input_types, args):
    """Collect the shapes of the array-like inputs.

    For each input, store its shape if its type is a point type and its
    value is not None, and store None otherwise. Inputs of type 'scalar'
    are converted to an array before their shape is read.

    Parameters
    ----------
    input_types : list
        Point types corresponding to the args, or kwargs values.
    args : tuple or dict_values
        Args, or kwargs values, of a function.

    Returns
    -------
    in_shapes : list
        Shapes of array-like input args, or kwargs values; None for the
        other inputs.

    Raises
    ------
    ValueError
        If a non-None input has a type that is neither a point type nor
        one of the pass-through types.
    """
    in_shapes = []
    for value, kind in zip(args, input_types):
        if kind == "scalar":
            value = gs.array(value)

        is_missing = value is None
        if kind in POINT_TYPES and not is_missing:
            shape = gs.shape(value)
        elif is_missing or kind in OTHER_TYPES:
            shape = None
        else:
            raise ValueError(ERROR_MSG % kind)
        in_shapes.append(shape)
    return in_shapes
def vectorize_args(input_types, args):
    """Bring positional args to their fully-vectorized form.

    "Fully-vectorized" means that:
    - one scalar has shape [1, 1],
    - n scalars have shape [n, 1],
    - one d-D vector has shape [1, d],
    - n d-D vectors have shape [n, d],
    etc.

    Args that are None, or whose type is not a point type, are passed
    through unchanged.

    Parameters
    ----------
    input_types : list
        Point types corresponding to the args.
    args : tuple
        Args of a function.

    Returns
    -------
    vect_args : tuple
        Args of the function in their fully-vectorized form.

    Raises
    ------
    ValueError
        If a non-None arg has a type that is neither a point type nor one
        of the pass-through types.
    """
    vectorized = []
    for value, kind in zip(args, input_types):
        if kind == "scalar":
            # Scalars get a trailing singleton axis: [n,] -> [n, 1].
            as_1d = gs.to_ndarray(value, to_ndim=1)
            value = gs.to_ndarray(as_1d, to_ndim=2, axis=1)
        elif value is not None and kind in POINT_TYPES:
            value = gs.to_ndarray(value, to_ndim=POINT_TYPES_TO_NDIMS[kind])
        elif value is not None and kind not in OTHER_TYPES:
            raise ValueError(ERROR_MSG % kind)
        vectorized.append(value)
    return tuple(vectorized)
def vectorize_kwargs(input_types, kwargs):
    """Bring keyword args to their fully-vectorized form.

    "Fully-vectorized" means that:
    - one scalar has shape [1, 1],
    - n scalars have shape [n, 1],
    - one d-D vector has shape [1, d],
    - n d-D vectors have shape [n, d],
    etc.

    Types are matched to kwargs in the iteration order of ``kwargs``.
    Values that are None, or whose type is not a point type, are passed
    through unchanged.

    Parameters
    ----------
    input_types : list
        Point types corresponding to the kwargs.
    kwargs : dict
        Kwargs of a function.

    Returns
    -------
    vect_kwargs : dict
        Kwargs of the function in their fully-vectorized form.

    Raises
    ------
    ValueError
        If a non-None value has a type that is neither a point type nor
        one of the pass-through types.
    """
    vectorized = {}
    for (name, value), kind in zip(kwargs.items(), input_types):
        if kind == "scalar":
            # Scalars get a trailing singleton axis: [n,] -> [n, 1].
            as_1d = gs.to_ndarray(value, to_ndim=1)
            value = gs.to_ndarray(as_1d, to_ndim=2, axis=1)
        elif value is not None and kind in POINT_TYPES:
            value = gs.to_ndarray(value, to_ndim=POINT_TYPES_TO_NDIMS[kind])
        elif value is not None and kind not in OTHER_TYPES:
            raise ValueError(ERROR_MSG % kind)
        vectorized[name] = value
    return vectorized
def adapt_result(result, initial_shapes, args_kwargs_types, is_scal):
    """Squeeze the singleton dims of the output where appropriate.

    Dim 1 is handled first, then dim 0. Each one is squeezed only if the
    corresponding ``squeeze_output_dim_*`` check passes and the dimension
    actually has size 1. The decision depends on:
    - the type of the output: scalar vs else,
    - the initial shapes of args and kwargs provided by the user.

    Parameters
    ----------
    result : unspecified
        Output of the function.
    initial_shapes : list
        Shapes of args and kwargs provided by the user.
    args_kwargs_types : list
        Types of args and kwargs.
    is_scal : bool
        Boolean determining if the output 'result' is a scalar.

    Returns
    -------
    result : unspecified
        Output of the function, with adapted shape.
    """
    drop_dim_1 = squeeze_output_dim_1(
        result, initial_shapes, args_kwargs_types, is_scal
    )
    if drop_dim_1 and result.shape[1] == 1:
        result = gs.squeeze(result, axis=1)

    # Evaluated on the (possibly already squeezed) result.
    drop_dim_0 = squeeze_output_dim_0(result, initial_shapes, args_kwargs_types)
    if drop_dim_0 and result.shape[0] == 1:
        result = gs.squeeze(result, axis=0)

    return result
def squeeze_output_dim_0(result, in_shapes, input_types):
    """Decide whether dim 0 of the output should be squeezed.

    Dim 0 is squeezed only if no point-type input was given by the user
    with its fully-vectorized number of dimensions, i.e. if every such
    input has ndim strictly less than the ndim of its vectorized shape.
    Inputs whose shape is None are ignored. Tuple and list outputs are
    never squeezed.

    Parameters
    ----------
    result : unspecified
        Output of the function.
    in_shapes : list
        Initial shapes of input parameters, as entered by the user.
    input_types : list
        Associated list of input_type of input parameters.

    Returns
    -------
    squeeze : bool
        Boolean deciding whether to squeeze dim 0 of the output.

    Raises
    ------
    ValueError
        If an input has more dimensions than its fully-vectorized form.
    """
    if isinstance(result, (tuple, list)):
        return False

    for in_shape, input_type in zip(in_shapes, input_types):
        if input_type not in POINT_TYPES or in_shape is None:
            continue

        in_ndim = len(in_shape)
        vect_ndim = POINT_TYPES_TO_NDIMS[input_type]
        if in_ndim > vect_ndim:
            raise ValueError("Fully-vectorizing an input can only increase its ndim.")
        if in_ndim == vect_ndim:
            # The user already provided a batch axis: keep it in the output.
            return False
    return True
def squeeze_output_dim_1(result, in_shapes, input_types, is_scal=True):
    """Decide whether dim 1 of the output should be squeezed.

    This happens if the user represents scalars as array of shapes:
    [n_samples,] instead of [n_samples, 1].

    Dim 1 is squeezed only if the output is declared scalar, ``result``
    is a fully-vectorized scalar, and no input of type 'scalar' was given
    with 2 dimensions.

    Parameters
    ----------
    result : array-like
        Result output by the function, before reshaping.
    in_shapes : list
        Initial shapes of input parameters, as entered by the user.
    input_types : list
        Associated list of input_type of input parameters.
    is_scal : bool
        Boolean determining if the output 'result' is a scalar.

    Returns
    -------
    squeeze : bool
        Boolean deciding whether to squeeze dim 1 of the output.

    Raises
    ------
    ValueError
        If an input of type 'scalar' has more than 2 dimensions.
    """
    if not (is_scal and is_scalar(result)):
        return False

    scalar_ndims = (
        len(shape)
        for shape, input_type in zip(in_shapes, input_types)
        if input_type == "scalar"
    )
    for ndim in scalar_ndims:
        if ndim > 2:
            raise ValueError("The ndim of a scalar cannot be > 2.")
        if ndim == 2:
            # The user wrote scalars as [n_samples, 1]: keep dim 1.
            return False
    return True
def is_scalar(vect_array):
    """Tell whether a "fully-vectorized" array represents a scalar.

    Parameters
    ----------
    vect_array : array-like
        Array to be tested.

    Returns
    -------
    is_scalar : bool
        True if vect_array has exactly 2 dimensions and a singleton
        dimension 1. Always False for tuples and lists.
    """
    if isinstance(vect_array, (tuple, list)):
        return False
    return vect_array.ndim == 2 and vect_array.shape[1] == 1
def broadcast_to_multibatch(batch_shape_a, batch_shape_b, array_a, *array_b):
    """Broadcast arrays so that they share a combined batch shape.

    Gives to all arrays batch shape `batch_shape_b + batch_shape_a`.
    Does nothing if one of the batch shapes is empty.

    Parameters
    ----------
    batch_shape_a : tuple
        Batch shape of array_a.
    batch_shape_b : tuple
        Batch shape of array_b.
    array_a : array
    array_b : array
        One or more arrays sharing batch shape ``batch_shape_b``.

    Returns
    -------
    array_a : array
        ``array_a``, broadcasted if both batch shapes are non-empty.
    array_b : array or sequence of arrays
        A single array if exactly one ``array_b`` was given. Otherwise the
        input tuple when nothing is broadcasted, or a list of broadcasted
        arrays.
    """
    several_b = len(array_b) > 1

    def _pack(first, others):
        # A single array_b is returned bare rather than in a sequence.
        return (first, others) if several_b else (first, others[0])

    if not (batch_shape_a and batch_shape_b):
        return _pack(array_a, array_b)

    # batch_shape_b is prepended to array_a: (batch_b, batch_a, ...).
    broadcast_a = gs.broadcast_to(array_a, batch_shape_b + array_a.shape)

    # batch_shape_a is prepended to each array_b, then moved right after
    # batch_shape_b so the axes read (batch_b, batch_a, ...) as well.
    offset = len(batch_shape_b)
    source_axes = list(range(len(batch_shape_a)))
    target_axes = [axis + offset for axis in source_axes]
    broadcast_b = [
        gs.moveaxis(
            gs.broadcast_to(array, batch_shape_a + array.shape),
            source_axes,
            target_axes,
        )
        for array in array_b
    ]
    return _pack(broadcast_a, broadcast_b)