Source code for geomstats.geometry.stratified.vectorization

"""Decorator to handle vectorization."""

import functools

import geomstats.backend as gs
from geomstats.geometry.stratified.point_set import PointBatch


[docs] def broadcast_lists(*lists): """Broadcast lists. Similar behavior as ``gs.broadcast_arrays``, but for lists. Parameters ---------- *lists : list Returns ------- *broadcasted_lists : list Lists with broadcasted length. """ lens = [len(list_) for list_ in lists] n_max = max(lens) if n_max == 1: return lists out = [] for len_, list_ in zip(lens, lists): if len_ == 1: out.append([list_[0]] * n_max) elif len_ == n_max: out.append(list_) else: raise Exception(f"Cannot broadcast lens: {lens}") return out
def _manipulate_input(arg, *args): """Transform input into list if not one. Parameters ---------- arg : any Argument to manipulate Returns ------- arg : any or list[any] transformed : bool If point was transformed. """ if not isinstance(arg, (list, tuple)): return [arg], True return arg, False def _manipulate_output_iterable(out): return PointBatch(out) def _manipulate_output( out, to_list, manipulate_output_iterable=_manipulate_output_iterable ): is_array = gs.is_array(out) is_iterable = isinstance(out, (list, tuple)) if not (gs.is_array(out) or is_iterable): return out if to_list: if is_array: return gs.array(out[0]) if is_iterable: return out[0] if is_iterable: return manipulate_output_iterable(out) return out
[docs] def vectorize_point( *args_positions, manipulate_input=_manipulate_input, manipulate_output=_manipulate_output, ): """Check point type and transform in iterable if not the case. Parameters ---------- args_positions : tuple Position and corresponding argument name. A tuple for each position. Notes ----- Explicitly defining args_positions and args names ensures it works for all combinations of input calling. """ def _dec(func): @functools.wraps(func) def _wrapped(*args, **kwargs): to_list = True args = list(args) for pos, name in args_positions: if name in kwargs: kwargs[name], to_list_ = manipulate_input(kwargs[name], name) else: args[pos], to_list_ = manipulate_input(args[pos], name) to_list = to_list and to_list_ out = func(*args, **kwargs) return manipulate_output(out, to_list) return _wrapped return _dec