Source code for geomstats.test.data
import random
[docs]
class TestData:
    """Base class for objects that generate parametrized test data.

    The class-level attributes are configuration values; none of the flags
    is read by the methods defined here (presumably they are consumed by the
    surrounding test machinery -- verify against the callers).
    """

    fail_for_autodiff_exceptions = True
    fail_for_not_implemented_errors = True

    trials = 3

    skip_vec = False
    # TODO: remove after refactoring stratified
    skip_all = False

    xfails = ()
    skips = ()
    tolerances = {}

    # Sizes are drawn once, when the class body is executed: a single random
    # value in {2, 3, 4}, preceded by 1 for every list except N_VEC_REPS.
    N_VEC_REPS = random.sample(range(2, 5), 1)
    N_SHAPE_POINTS = [1] + random.sample(range(2, 5), 1)
    N_RANDOM_POINTS = [1] + random.sample(range(2, 5), 1)
    N_TIME_POINTS = [1] + random.sample(range(2, 5), 1)

    def generate_tests(self, test_data, marks=()):
        """Wrap test data with corresponding marks.

        Dict data are updated in place (their ``"marks"`` entry is set or
        replaced); non-dict data are converted to a new list with the marks
        appended as the last element.

        Parameters
        ----------
        test_data : list or dict
            Iterable of test data, each one a dict or a sequence of values.
        marks : list
            pytest marks. A single mark (anything that is neither a list nor
            a tuple) is wrapped into a one-element list.

        Returns
        -------
        data: list or dict
            Tests.
        """
        if not isinstance(marks, (list, tuple)):
            marks = [marks]

        tests = []
        for test_datum in test_data:
            if isinstance(test_datum, dict):
                if "marks" not in test_datum:
                    # Slicing keeps the sequence type but copies lists, so
                    # data never share (or expose) the caller's marks list.
                    test_datum["marks"] = marks[:]
                else:
                    existing = test_datum["marks"]
                    if not isinstance(existing, (list, tuple)):
                        existing = [existing]
                    # Build a fresh list rather than extending in place:
                    # the existing marks may be a tuple, or a list that is
                    # shared with other data.
                    test_datum["marks"] = [*existing, *marks]
            else:
                test_datum = list(test_datum)
                test_datum.append(marks[:])

            tests.append(test_datum)

        return tests

    def generate_random_data(self, marks=()):
        """Generate one test datum per entry of ``N_RANDOM_POINTS``.

        Parameters
        ----------
        marks : list
            pytest marks.

        Returns
        -------
        data : list[dict]
            Dicts with key ``n_points`` (plus ``marks``).
        """
        data = [{"n_points": n_points} for n_points in self.N_RANDOM_POINTS]
        return self.generate_tests(data, marks=marks)

    def generate_vec_data(self, marks=()):
        """Generate one test datum per entry of ``N_VEC_REPS``.

        Parameters
        ----------
        marks : list
            pytest marks.

        Returns
        -------
        data : list[dict]
            Dicts with key ``n_reps`` (plus ``marks``).
        """
        data = [{"n_reps": n_reps} for n_reps in self.N_VEC_REPS]
        return self.generate_tests(data, marks=marks)

    def generate_shape_data(self, marks=()):
        """Generate one test datum per entry of ``N_SHAPE_POINTS``.

        Parameters
        ----------
        marks : list
            pytest marks.

        Returns
        -------
        data : list[dict]
            Dicts with key ``n_points`` (plus ``marks``).
        """
        data = [{"n_points": n_points} for n_points in self.N_SHAPE_POINTS]
        return self.generate_tests(data, marks=marks)

    def generate_vec_data_with_time(self, marks=()):
        """Generate test data for each pair from ``N_VEC_REPS`` x ``N_TIME_POINTS``.

        Parameters
        ----------
        marks : list
            pytest marks.

        Returns
        -------
        data : list[dict]
            Dicts with keys ``n_reps`` and ``n_times`` (plus ``marks``),
            ordered with ``n_times`` varying fastest.
        """
        data = [
            {"n_reps": n_reps, "n_times": n_times}
            for n_reps in self.N_VEC_REPS
            for n_times in self.N_TIME_POINTS
        ]
        return self.generate_tests(data, marks=marks)

    def generate_random_data_with_time(self, marks=()):
        """Generate test data for each pair from ``N_RANDOM_POINTS`` x ``N_TIME_POINTS``.

        Parameters
        ----------
        marks : list
            pytest marks.

        Returns
        -------
        data : list[dict]
            Dicts with keys ``n_points`` and ``n_times`` (plus ``marks``),
            ordered with ``n_times`` varying fastest.
        """
        data = [
            {"n_points": n_points, "n_times": n_times}
            for n_points in self.N_RANDOM_POINTS
            for n_times in self.N_TIME_POINTS
        ]
        return self.generate_tests(data, marks=marks)