Source code for geomstats.test.vectorization

import copy
import inspect
import itertools

from geomstats.vectorization import repeat_point

KNOWN_POINT_NAMES = (
    "point",
    "base_point",
    "point_a",
    "point_b",
    "fiber_point",
    "mat",
)
KNOWN_VECTOR_NAMES = (
    "vec",
    "vector",
    "tangent_vec",
    "tangent_vec_a",
    "tangent_vec_b",
    "tangent_vec_c",
    "tangent_vec_d",
    "cotangent_vec_a",
    "cotangent_vec_b",
    "direction",
)
KNOWN_ARGS = KNOWN_POINT_NAMES + KNOWN_VECTOR_NAMES


def _filter_combs(n_args, combs, vectorization_type):
    if vectorization_type == "sym" or n_args == 1:
        return combs

    repeat_indices = [int(val) for val in vectorization_type.split("-")[1:]]
    if len(repeat_indices) == n_args:
        return combs

    remove_indices = [val for val in range(n_args) if val not in repeat_indices]

    for comb in combs.copy():
        vals_remove = [comb[remove_index] for remove_index in remove_indices]
        if max(vals_remove) == 0 or min(vals_remove) == 1 and min(comb) == 1:
            continue

        combs.remove(comb)

    return combs


def _check_vectorization_type(vectorization_type, n_args):
    if vectorization_type in ("sym", "basic") or n_args == 1:
        return

    try:
        val = None
        for val in vectorization_type.split("-")[1:]:
            int(val)

        if val is None or int(val) > n_args:
            raise ValueError(
                "Unable to repeat unexisting args for vectorization type "
                f"`{vectorization_type}` and `n_args = {n_args}"
            )

    except ValueError:
        raise ValueError(
            f"Unable to understand vectorization type `{vectorization_type}`. "
            "Can handle `sym` and `repeat-(int)` format."
        )


def _generate_datum_vectorization_data(
    datum, comb_indices, arg_names, expected_name, n_reps=2
):
    if expected_name is not None:
        has_expected = True
        if isinstance(expected_name, str):
            expected_name = [expected_name]

        expected_combs = []
        for expected_name_ in expected_name:
            expected = datum.get(expected_name_)
            expected_combs.append([expected, repeat_point(expected, n_reps=n_reps)])
    else:
        has_expected = False

    args_combs = []
    for arg_name in arg_names:
        arg = datum.get(arg_name)
        args_combs.append([arg, repeat_point(arg, n_reps=n_reps, expand=True)])

    new_data = []
    for indices in comb_indices:
        new_datum = copy.copy(datum)

        if has_expected:
            rep = int(1 in indices)
            for expected_i, expected_name_ in enumerate(expected_name):
                new_datum[expected_name_] = expected_combs[expected_i][rep]

        for arg_i, (index, arg_name) in enumerate(zip(indices, arg_names)):
            new_datum[arg_name] = args_combs[arg_i][index]

        new_data.append(new_datum)

    return new_data


[docs] def generate_vectorization_data( data, arg_names, expected_name=None, n_reps=2, vectorization_type="sym", ): """Create new data with vectorized version of inputs. Parameters ---------- data : list of dict Data. Each to vectorize. arg_names: list Name of inputs to vectorize. expected_name: str or list of str Output name in case it needs to be repeated. n_reps: int Number of times the input points should be repeated. vectorization_type: str Possible values are 'sym', 'basic', or the format 'repeat-(int)' (e.g. "repeat-0-2"). 'sym': tests all repetition combinations. 'basic': tests only no repetition and repetition of all. 'repeat-(int)': tests repetition of provided indices. """ n_args = len(arg_names) _check_vectorization_type(vectorization_type, n_args) if vectorization_type == "basic": comb_indices = [tuple(i for _ in range(n_args)) for i in range(2)] else: comb_indices = list(itertools.product(*[range(2)] * n_args)) comb_indices = _filter_combs(n_args, comb_indices, vectorization_type) comb_indices.pop(0) new_data = [] for datum in data: new_data.extend( _generate_datum_vectorization_data( datum, comb_indices, arg_names, expected_name=expected_name, n_reps=n_reps, ) ) return new_data
def _generate_random_data(data_generator, arg_names): data = {} base_point = None for arg_name in arg_names: if arg_name in KNOWN_POINT_NAMES: base_point = data[arg_name] = data_generator.random_point() for arg_name in arg_names: if arg_name in KNOWN_VECTOR_NAMES: data[arg_name] = data_generator.random_tangent_vec(base_point) return data def _get_vectorization_type(test_case, arg_names): if test_case.tangent_to_multiple: return "sym" tangent_vec_type = "" for k, arg_name in enumerate(arg_names): if arg_name in KNOWN_VECTOR_NAMES: tangent_vec_type += f"-{k}" return "repeat" + tangent_vec_type if tangent_vec_type else "sym"
[docs] def test_vectorization(self, test_func, n_reps, atol): # TODO: move this to decorator? # TODO: accept kwargs? arg_names = list(inspect.signature(test_func).parameters.keys()) arg_names = list(filter(lambda x: x in KNOWN_ARGS, arg_names)) data = _generate_random_data(self.data_generator, arg_names) geometry = ( self.space.metric if hasattr(self, "is_metric") and self.is_metric else self.space ) data["expected"] = getattr(geometry, test_func.__name__[5:])(**data) data["atol"] = atol vec_data = generate_vectorization_data( data=[data], arg_names=arg_names, expected_name="expected", n_reps=n_reps, vectorization_type=_get_vectorization_type(self, arg_names), ) self._test_vectorization(vec_data, test_fnc_name=test_func.__name__)