# Source code for geomstats.geometry.stratified.wald_space

r"""Classes for Wald space, its elements (class ``Wald``), and related helper classes.

Lead author: Jonas Lueg

References
----------
.. [Garba21] Garba, M. K., T. M. W. Nye, J. Lueg and S. F. Huckemann.
    "Information geometry for phylogenetic trees"
    Journal of Mathematical Biology, 82(3):19, February 2021a.
    https://doi.org/10.1007/s00285-021-01553-x.
.. [Lueg21] Lueg, J., M. K. Garba, T. M. W. Nye, S. F. Huckemann.
    "Wald Space for Phylogenetic Trees."
    Geometric Science of Information, Lecture Notes in Computer Science,
    pages 710–717, Cham, 2021.
    https://doi.org/10.1007/978-3-030-80209-7_76.
"""

import random
from abc import ABC

import geomstats.backend as gs
from geomstats.geometry.hermitian_matrices import powermh
from geomstats.geometry.matrices import Matrices
from geomstats.geometry.spd_matrices import SPDMatrices
from geomstats.geometry.stratified.point_set import (
    Point,
    PointBatch,
    PointSet,
    PointSetMetric,
)
from geomstats.geometry.stratified.trees import (
    ForestTopology,
    Split,
    delete_splits,
    generate_splits,
)
from geomstats.geometry.stratified.vectorization import (
    _manipulate_output,
    broadcast_lists,
    vectorize_point,
)
from geomstats.numerics.interpolation import LinearInterpolator1D
from geomstats.numerics.optimization import ScipyMinimize


def _manipulate_input_with_array(arg, name):
    """Turn a possibly single input into a batch, reporting whether it was wrapped.

    Used as the ``manipulate_input`` hook of ``vectorize_point`` for the
    projection solver.

    Parameters
    ----------
    arg : array-like, list, tuple or object
        Input to normalize.
    name : str
        Name of the argument. A single ``"ambient_point"`` is a matrix (2
        dimensions); any other single array input is a vector (1 dimension).

    Returns
    -------
    batched_arg : array-like, list or tuple
        The input with a leading batch dimension (arrays) or wrapped in a
        list (other non-sequence objects), or the input unchanged if it is
        already a batch.
    was_wrapped : bool
        True if ``arg`` was a single input that had to be wrapped.
    """
    if not gs.is_array(arg):
        if isinstance(arg, (list, tuple)):
            return arg, False
        return [arg], True

    single_point_ndim = 2 if name == "ambient_point" else 1
    if arg.ndim <= single_point_ndim:
        return gs.expand_dims(arg, axis=0), True
    return arg, False


def _manipulate_output_wald(out, to_list):
    """Post-process a vectorized output, packing batches into a ``WaldBatch``.

    Thin wrapper around ``_manipulate_output`` that passes ``WaldBatch`` as
    the ``manipulate_output_iterable`` constructor.

    Parameters
    ----------
    out : list
        Raw output of the vectorized function.
    to_list : bool
        Flag produced by the matching input manipulation.

    Returns
    -------
    out : object
        Whatever ``_manipulate_output`` returns for these arguments.
    """
    batch_constructor = WaldBatch
    return _manipulate_output(
        out, to_list, manipulate_output_iterable=batch_constructor
    )


def make_splits(n_labels):
    """Generate all possible splits of the labels ``0, ..., n_labels - 1``.

    The splits are built recursively. Every split of the first
    ``n_labels - 1`` labels is extended twice: once with the new label
    ``n_labels - 1`` added to ``part2``, then with it added to ``part1``.
    Finally, the split separating the new label from all others is yielded.

    Parameters
    ----------
    n_labels : int
        Number of labels; must be greater than 1.

    Yields
    ------
    split : Split
        A split of the labels.

    Raises
    ------
    ValueError
        If ``n_labels`` is smaller than 2. Since this is a generator, the
        error surfaces on the first iteration.
    """
    if n_labels <= 1:
        raise ValueError("`n_labels` must be greater than 1.")
    if n_labels == 2:
        yield Split(part1=[0], part2=[1])
        return

    new_label = n_labels - 1
    for smaller_split in make_splits(n_labels=new_label):
        yield Split(
            part1=smaller_split.part1,
            part2=smaller_split.part2.union((new_label,)),
        )
        yield Split(
            part1=smaller_split.part1.union((new_label,)),
            part2=smaller_split.part2,
        )
    yield Split(part1=list(range(new_label)), part2=[new_label])
def make_topologies(n_labels):
    """Generate all binary tree topologies on the labels ``0, ..., n_labels - 1``.

    For ``n_labels`` in ``(2, 3)`` a single topology is yielded, containing
    every split from :func:`make_splits`. For larger ``n_labels``, every
    topology on ``n_labels - 1`` labels is extended by attaching the new leaf
    ``n_labels - 1`` to each of its edges in turn. The number of topologies
    grows very quickly, so this is only practical for ``n_labels < 8``.

    Parameters
    ----------
    n_labels : int
        Number of labels; must be at least 2.

    Yields
    ------
    topology : ForestTopology
        Topology with a single component containing all labels.

    Raises
    ------
    ValueError
        If ``n_labels`` is smaller than 2. Since this is a generator, the
        error surfaces on the first iteration.
    """
    if n_labels <= 1:
        raise ValueError("The collection must have 2 elements or more.")
    if n_labels in (2, 3):
        yield ForestTopology(
            partition=(tuple(range(n_labels)),),
            split_sets=(list(make_splits(n_labels)),),
        )
        return

    pendant_split = Split(part1=[n_labels - 1], part2=list(range(n_labels - 1)))
    for st in make_topologies(n_labels - 1):
        for s in st.split_sets[0]:
            # Attach the new leaf to edge ``s``: ``s`` is replaced by two
            # splits, and every other split ``t`` receives the new label on
            # the side of ``t`` that contains ``s``.
            new_split_set = [pendant_split]
            a, b = set(s.part1), set(s.part2)
            for t in st.split_sets[0]:
                if t != s:
                    d = set(t.part2)
                    if a.issubset(d) or b.issubset(d):
                        new_split_set.append(
                            Split(part1=t.part1, part2=t.part2.union((n_labels - 1,)))
                        )
                    else:
                        new_split_set.append(
                            Split(part1=t.part2, part2=t.part1.union((n_labels - 1,)))
                        )
                else:
                    new_split_set.append(
                        Split(part1=s.part1, part2=s.part2.union((n_labels - 1,)))
                    )
                    new_split_set.append(
                        Split(part1=s.part2, part2=s.part1.union((n_labels - 1,)))
                    )
            yield ForestTopology(
                partition=(tuple(range(n_labels)),),
                split_sets=(new_split_set,),
            )
def _generate_partition(n_labels, p_new):
    r"""Generate a random partition of :math:`\{0,\dots,n-1\}`.

    Start with the single block ``[0]``, then process the labels
    ``u = 1, ..., n - 1`` in order. With probability ``p_new``, ``u`` is
    appended to an existing block chosen uniformly at random. Otherwise (with
    probability ``1 - p_new``), the new singleton block ``[u]`` is added.

    NOTE(review): randomness comes from both ``gs.random`` and Python's
    ``random`` module, so reproducing a draw requires seeding both.

    Parameters
    ----------
    n_labels : int
        Number of labels :math:`n`.
    p_new : float
        A float between 0 and 1: the probability that a label joins an
        existing block, i.e. that *no* new component is added.

    Returns
    -------
    partition : list[list[int]]
        A partition of the set :math:`\{0,\dots,n-1\}` into non-empty sets.
    """
    blocks = [[0]]
    for label in range(1, n_labels):
        joins_existing_block = gs.random.rand(1) < p_new
        if joins_existing_block:
            target = random.randint(0, len(blocks) - 1)
            blocks[target].append(label)
        else:
            blocks.append([label])
    return blocks
def generate_random_wald(n_labels, p_keep, p_new, btol=1e-8, check=True):
    """Generate a random instance of class ``Wald``.

    First a random partition of the labels is drawn. Then a split set is
    generated for every block and thinned with ``delete_splits``. Finally,
    uniform edge weights are drawn and clipped to ``[btol, 1 - btol]``.

    Parameters
    ----------
    n_labels : int
        The number of labels the wald is generated with respect to.
    p_keep : float
        Probability passed to ``delete_splits`` for keeping a split when
        thinning the split set of each component.
    p_new : float
        A float between 0 and 1, passed to the partition generator: the
        probability that no new component is added. ``1 - p_new`` is the
        probability that a new component is added.
    btol : float
        Tolerance for the boundary of the coordinates in each grove.
        Defaults to 1e-08.
    check : bool
        If True, a split is not deleted when removing it would stop the
        splits from separating all labels. If False, any split can be
        randomly deleted.

    Returns
    -------
    random_wald : Wald
        The randomly generated wald.
    """
    partition = _generate_partition(n_labels=n_labels, p_new=p_new)

    # Generate all split sets first, then thin them, in this order.
    full_split_sets = [generate_splits(labels=block) for block in partition]
    thinned_split_sets = []
    for block, splits in zip(partition, full_split_sets):
        thinned_split_sets.append(
            delete_splits(splits=splits, labels=block, p_keep=p_keep, check=check)
        )

    topology = ForestTopology(partition=partition, split_sets=thinned_split_sets)
    raw_weights = gs.random.uniform(size=(topology.n_splits,), low=0, high=1)
    clipped_weights = gs.minimum(gs.maximum(btol, raw_weights), 1 - btol)
    return Wald(topology=topology, weights=clipped_weights)
class Wald(Point):
    r"""A class for wälder, that are phylogenetic forests, elements of the Wald Space.

    A wald is essentially a phylogenetic forest with weights between zero and
    one on the edges. The forest structure is stored as a ``ForestTopology``.
    The edge weights are an array whose length equals the total number of
    splits in the structure. These elements are the points in Wald space and
    other phylogenetic forest spaces, like BHV space, although in that case
    the partition is just the whole set of labels.

    Parameters
    ----------
    topology : ForestTopology
        The structure of the forest.
    weights : array-like, shape=[n_splits]
        The edge weights: an array of floats between 0 and 1, with one entry
        per split/edge of ``topology``.

    Attributes
    ----------
    corr : array-like, shape=[n_labels, n_labels]
        Correlation matrix of the topology with edge weights, computed once
        at construction via ``topology.corr(weights)``.
    """

    def __init__(self, topology, weights):
        super().__init__()
        self.topology = topology
        self.weights = weights
        self.corr = self.topology.corr(weights)

    def __repr__(self):
        """Return the string representation of the wald.

        This representation is meant to contain all the information needed
        to reconstruct the wald: its topology and its weights as a tuple.

        Returns
        -------
        string_of_wald : str
            The string representation of the wald.
        """
        state = (self.topology, tuple(self.weights))
        return repr(state)

    def __str__(self):
        """Return a human-readable string representation of the wald.

        Unlike ``__repr__``, this representation is not required to contain
        all necessary information, only to be readable by a human.

        Returns
        -------
        string_of_wald : str
            The readable string representation of the wald.
        """
        return f"({self.topology!s};{self.weights!s})"

    def _equal_single(self, point, atol=gs.atol):
        """Check equality against another single wald.

        Two wälder are equal if their topologies are equal and all their
        weights differ by strictly less than ``atol``.

        Parameters
        ----------
        point : Wald
            Wald to compare against.
        atol : float
            Absolute tolerance on the weights.

        Returns
        -------
        is_equal : bool
        """
        if self.topology != point.topology:
            return False
        deviation = gs.abs(self.weights - point.weights)
        return gs.all(deviation < atol)

    @vectorize_point((1, "point"))
    def equal(self, point, atol=gs.atol):
        """Check equality against another point or a batch of points.

        Parameters
        ----------
        point : Wald or WaldBatch
            Point(s) to compare against.
        atol : float
            Absolute tolerance on the weights.

        Returns
        -------
        is_equal : array-like, shape=[...]
        """
        matches = [self._equal_single(other, atol) for other in point]
        return gs.array(matches)
class WaldBatch(PointBatch):
    """Wald batch: a collection of ``Wald`` points with batched accessors."""

    def _stacked(self, attribute):
        """Stack the given array attribute of every wald in the batch."""
        return gs.stack([getattr(wald, attribute) for wald in self])

    @property
    def topology(self):
        """Forest topology of each wald.

        Returns
        -------
        topology : list[ForestTopology]
        """
        return [wald.topology for wald in self]

    @property
    def weights(self):
        """Edge weights of each wald, stacked.

        Returns
        -------
        weights : array-like, shape=[n_points, n_splits]
        """
        return self._stacked("weights")

    @property
    def corr(self):
        """Correlation matrix of each wald, stacked.

        Returns
        -------
        corr : array-like, shape=[n_points, n_labels, n_labels]
        """
        return self._stacked("corr")
class WaldSpace(PointSet):
    r"""Class for the Wald space, a metric space for phylogenetic forests.

    Points in Wald space are instances of the class :class:`Wald`:
    phylogenetic forests with edge weights between 0 and 1. Wald space is a
    stratified space; each stratum is called a grove. The highest dimensional
    groves correspond to fully resolved or binary trees. The topology is
    obtained from embedding wälder into the ambient space of strictly
    positive :math:`n\times n` symmetric matrices, by default the class
    :class:`spd.SPDMatrices`.

    Parameters
    ----------
    n_labels : int
        Integer determining the number of labels in the forests, and thus the
        shape of the correlation matrices: n_labels x n_labels.
    ambient_space : Manifold, optional
        Ambient space the wälder are lifted into. Defaults to
        ``SPDMatrices(n=n_labels)``.
    equip : bool
        Whether to equip the space with its default metric. Defaults to True.

    Attributes
    ----------
    ambient_space : Manifold
        The ambient space, the positive definite n_labels x n_labels matrices
        that the WaldSpace is embedded into.
    """

    def __init__(self, n_labels, ambient_space=None, equip=True):
        super().__init__(equip)
        self.n_labels = n_labels
        if ambient_space is None:
            ambient_space = SPDMatrices(n=self.n_labels)
        self.ambient_space = ambient_space

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return WaldSpaceMetric

    def _belongs_single(self, point, atol=gs.atol):
        """Check if a single wald belongs to Wald space.

        A wald belongs if its lifted correlation matrix belongs to the
        ambient space (within ``atol``) and all weights lie strictly between
        0 and 1.

        Parameters
        ----------
        point : Wald
            The point to be checked.
        atol : float
            Absolute tolerance, forwarded to the ambient space membership
            test. Optional, default: backend atol.

        Returns
        -------
        belongs : bool
            Boolean denoting if point belongs to wald space.
        """
        # Forward ``atol`` so the caller's tolerance is actually honoured.
        if not self.ambient_space.belongs(self.lift(point), atol=atol):
            return False
        if gs.all(point.weights > 0) and gs.all(point.weights < 1):
            return True
        return False

    @vectorize_point((1, "point"))
    def belongs(self, point, atol=gs.atol):
        """Check if a point `wald` belongs to Wald space.

        From FUTURE PUBLICATION we know that the corresponding matrix of a
        wald is strictly positive definite if and only if the labels are
        separated by at least one edge, which is exactly when the wald is an
        element of the Wald space.

        Parameters
        ----------
        point : Wald or WaldBatch
            The point to be checked.
        atol : float
            Absolute tolerance.
            Optional, default: backend atol.

        Returns
        -------
        belongs : array-like, shape=[...]
            Boolean denoting if `point` belongs to Wald space.
        """
        return gs.array([self._belongs_single(point_, atol) for point_ in point])

    def random_point(self, n_samples=1, p_tree=0.9, p_keep=0.9, btol=1e-8):
        """Sample a random point in Wald space.

        Parameters
        ----------
        n_samples : int
            Number of samples. Defaults to 1.
        p_tree : float between 0 and 1
            The probability that the sampled point is a tree, and not a
            forest. If the probability is equal to 1, then the sampled point
            will be a tree. Defaults to 0.9.
        p_keep : float between 0 and 1
            The probability that a sampled edge is kept and not deleted
            randomly. To be precise, it is not exactly the probability, as
            some edges cannot be deleted since the requirement that two
            labels are separated by a split might be violated otherwise.
            Defaults to 0.9.
        btol : float
            Tolerance for the boundary of the coordinates in each grove.
            Defaults to 1e-08.

        Returns
        -------
        samples : Wald or WaldBatch
            Points sampled in Wald space.
        """
        # Each of the n_labels - 1 labels after the first joins an existing
        # component with probability p_new, so the whole sample is a single
        # tree with probability p_new ** (n_labels - 1) == p_tree.
        p_new = p_tree ** (1 / (self.n_labels - 1))
        forests = [
            generate_random_wald(self.n_labels, p_keep, p_new, btol, check=True)
            for _ in range(n_samples)
        ]
        if n_samples == 1:
            return forests[0]
        return WaldBatch(forests)

    def random_grove_point(self, topology, n_samples=1):
        """Sample a random point in a given grove of Wald space.

        Weights are drawn uniformly from ``gs.random.uniform``'s default range.

        Parameters
        ----------
        topology : ForestTopology
            Topology of the grove to sample in.
        n_samples : int
            Number of samples. Defaults to 1.

        Returns
        -------
        samples : Wald or WaldBatch
            Points sampled in Wald space.
        """
        n_splits = topology.n_splits
        weights = gs.random.uniform(size=(n_samples, n_splits))
        forests = [Wald(topology, weights_) for weights_ in weights]
        if n_samples == 1:
            return forests[0]
        return WaldBatch(forests)

    def lift(self, point):
        """Lift a point to the ambient space.

        Parameters
        ----------
        point : Wald or WaldBatch
            The point to be lifted.

        Returns
        -------
        lifted_point : array-like, shape=[..., n_labels, n_labels]
            Point in the ambient space (its correlation matrix).
        """
        return point.corr
class WaldSpaceMetric(PointSetMetric):
    """Wald space metric.

    Distances and geodesics are computed by a geodesic solver working with
    the ambient metric of the space. Projections onto groves are delegated to
    a projection solver.

    Parameters
    ----------
    space : WaldSpace
        Set to equip with metric.
    projection_solver : ProjectionSolver
        Numerical solver to solve projection problem. Defaults to
        ``LocalProjectionSolver(space)``.
    geodesic_solver : GeodesicSolver
        Numerical solver to solve geodesic problem. Defaults to
        ``SuccessiveProjectionGeodesicSolver(space)``.
    """

    def __init__(self, space, projection_solver=None, geodesic_solver=None):
        super().__init__(space)
        self.projection_solver = (
            LocalProjectionSolver(space)
            if projection_solver is None
            else projection_solver
        )
        self.geodesic_solver = (
            SuccessiveProjectionGeodesicSolver(space)
            if geodesic_solver is None
            else geodesic_solver
        )

    def dist(self, point_a, point_b):
        """Distance between two points in the WaldSpace.

        The distance is approximated by the length of the discrete geodesic:
        the sum of ambient distances between consecutive lifted points.

        Parameters
        ----------
        point_a : Wald or WaldBatch
            Point in the WaldSpace.
        point_b : Wald or WaldBatch
            Point in the WaldSpace.

        Returns
        -------
        distance : array-like, shape=[...]
            Distance.
        """
        discrete_path = self.discrete_geodesic(point_a, point_b)
        single_path = isinstance(discrete_path, WaldBatch)
        paths = [discrete_path] if single_path else discrete_path

        ambient_metric = self._space.ambient_space.metric
        dists = []
        for path in paths:
            corr = path.corr
            dists.append(gs.sum(ambient_metric.dist(corr[1:], corr[:-1]), axis=-1))

        return dists[0] if single_path else gs.stack(dists)

    def geodesic(self, initial_point, end_point):
        """Compute the geodesic in the WaldSpace.

        Parameters
        ----------
        initial_point : Wald or WaldBatch
            Point in the WaldSpace.
        end_point : Wald or WaldBatch
            Point in the WaldSpace.

        Returns
        -------
        path : callable
            Time parameterized geodesic curve.
        """
        solver = self.geodesic_solver
        return solver.geodesic(initial_point, end_point)

    def discrete_geodesic(self, initial_point, end_point):
        """Compute a discrete geodesic in the WaldSpace.

        Parameters
        ----------
        initial_point : Wald or WaldBatch
            Point in the WaldSpace.
        end_point : Wald or WaldBatch
            Point in the WaldSpace.

        Returns
        -------
        geod_points : WaldBatch or list[WaldBatch]
            Points of the discrete geodesic.
        """
        solver = self.geodesic_solver
        return solver.discrete_geodesic(initial_point, end_point)

    def projection(self, ambient_point, topology, initial_weights=None):
        """Project ambient point(s) into Wald space.

        Parameters
        ----------
        ambient_point : array-like, shape=[..., n_labels, n_labels]
            Ambient point to project.
        topology : ForestTopology or list[ForestTopology]
            Stratum topology.
        initial_weights : array-like, shape=[..., n_splits]
            Initial guess for weights.

        Returns
        -------
        projected : Wald or WaldBatch
            Projected point(s), as returned by the projection solver.
        """
        solver = self.projection_solver
        return solver.projection(
            ambient_point, topology, initial_weights=initial_weights
        )
def _squared_dist_and_grad_affine(space, topology, ambient_point):
    """Squared affine-invariant distance and its gradient wrt weights.

    With :math:`P` the ambient point and :math:`C(w)` the correlation matrix
    of the wald, the gradient entries are
    :math:`2\\,\\mathrm{tr}(\\log(P^{-1/2} C P^{-1/2}) P^{1/2} C^{-1}
    \\partial_i C P^{-1/2})`. See section 5.1 of [Garba21]_.

    Parameters
    ----------
    space : WaldSpace
    topology : ForestTopology
    ambient_point : array-like, shape=[n_labels, n_labels]
        Point wrt which the distance is measured.

    Returns
    -------
    value_and_grad : callable
        A callable that takes weights and outputs value and grad.
    """
    sqrt_ambient_point, inv_sqrt_ambient_point = powermh(
        ambient_point, [1.0 / 2, -1.0 / 2]
    )

    def _value_and_grad(weights):
        corr = topology.corr(weights)
        corr_grad = topology.corr_gradient(weights)
        value = space.ambient_space.metric.squared_dist(corr, ambient_point)

        whitened_corr = Matrices.mul(
            inv_sqrt_ambient_point, corr, inv_sqrt_ambient_point
        )
        log_whitened_corr = gs.linalg.logm(whitened_corr)
        inv_corr = gs.linalg.inv(corr)
        weights_grad = 2 * gs.trace(
            Matrices.mul(
                log_whitened_corr,
                sqrt_ambient_point,
                inv_corr,
                corr_grad,
                inv_sqrt_ambient_point,
            )
        )
        return value, weights_grad

    return _value_and_grad


def _squared_dist_and_grad_euclidean(space, topology, ambient_point):
    """Squared Euclidean (Frobenius) distance and its gradient wrt weights.

    The gradient entries are :math:`2 \\sum_{kl} (C - P)_{kl} (\\partial_i C)_{kl}`.

    Parameters
    ----------
    space : WaldSpace
    topology : ForestTopology
    ambient_point : array-like, shape=[n_labels, n_labels]
        Point wrt which the distance is measured.

    Returns
    -------
    value_and_grad : callable
        A callable that takes weights and outputs value and grad.
    """

    def _value_and_grad(weights):
        corr = topology.corr(weights)
        corr_grad = topology.corr_gradient(weights)
        value = space.ambient_space.metric.squared_dist(corr, ambient_point)
        residual = corr - ambient_point
        weights_grad = 2 * gs.sum(residual * corr_grad, axis=(-2, -1))
        return value, weights_grad

    return _value_and_grad


def _squared_dist_and_grad_autodiff(space, topology, ambient_point):
    """Squared ambient distance and its gradient wrt weights, via autodiff.

    Fallback for ambient metrics without a closed-form gradient.

    Parameters
    ----------
    space : WaldSpace
    topology : ForestTopology
    ambient_point : array-like, shape=[n_labels, n_labels]
        Point wrt which the distance is measured.

    Returns
    -------
    value_and_grad : callable
        A callable that takes weights and outputs value and grad.
    """

    def _value(weights):
        return space.ambient_space.metric.squared_dist(
            topology.corr(weights), ambient_point
        )

    return gs.autodiff.value_and_grad(_value)


# Closed-form value-and-gradient builders, keyed by ambient metric class name;
# metrics not listed here fall back to autodiff.
_AMBIENT_METRIC_TO_SQUARED_DIST_GRAD = {
    "SPDAffineMetric": _squared_dist_and_grad_affine,
    "SPDEuclideanMetric": _squared_dist_and_grad_euclidean,
}
class LocalProjectionSolver:
    """Local projection solver.

    Projects an ambient point onto a fixed grove. It minimizes the squared
    ambient distance over the grove's weights with a bounded L-BFGS-B
    optimizer, keeping weights in ``[btol, 1 - btol]``.

    Parameters
    ----------
    space : WaldSpace
        Space whose ambient metric defines the distance to minimize.
    btol : float
        Distance kept from the weight bounds 0 and 1 during optimization.
        Optimized weights at or beyond these bounds are snapped to exactly 0
        or 1. Defaults to 1e-10.
    """

    def __init__(self, space, btol=1e-10):
        self._space = space
        self.btol = btol
        self.optimizer = ScipyMinimize(
            method="L-BFGS-B",
            tol=gs.atol,
        )

    def _get_bounds(self, n_splits):
        """Box bounds ``[btol, 1 - btol]`` for each of the ``n_splits`` weights."""
        return [(self.btol, 1 - self.btol)] * n_splits

    def _projection_single(self, ambient_point, topology, initial_weights=None):
        """Project ambient point into wald space.

        Parameters
        ----------
        ambient_point : array-like, shape=[n_labels, n_labels]
            Ambient point to project.
        topology : ForestTopology
            Stratum topology.
        initial_weights : array-like, shape=[n_splits]
            Initial guess for weights. Defaults to 0.5 for every split.

        Returns
        -------
        projected : Wald
            Wald with the given topology and optimized weights.

        Raises
        ------
        ValueError
            If the optimizer does not report success.
        """
        if len(topology.partition) == topology.n_labels:
            # Every component is a single label, so there is nothing to
            # optimize. The solver has no ``n_labels`` attribute; size the
            # weights by the topology, one entry per split as ``Wald`` expects.
            return Wald(topology=topology, weights=gs.ones(topology.n_splits))

        value_and_grad = _AMBIENT_METRIC_TO_SQUARED_DIST_GRAD.get(
            self._space.ambient_space.metric.__class__.__name__,
            _squared_dist_and_grad_autodiff,
        )(
            space=self._space,
            ambient_point=ambient_point,
            topology=topology,
        )

        n_splits = topology.n_splits
        initial_weights = (
            gs.ones(n_splits) * 0.5 if initial_weights is None else initial_weights
        )

        self.optimizer.bounds = self._get_bounds(n_splits)
        res = self.optimizer.minimize(value_and_grad, initial_weights, fun_jac=True)
        if res.status != 0:
            raise ValueError("Projection failed!")

        # Weights that reached the optimization bounds are snapped to the
        # exact boundary values 0 and 1.
        weights = gs.array(
            [
                _x if self.btol < _x < 1 - self.btol else 0 if _x <= self.btol else 1
                for _x in res.x
            ]
        )
        return Wald(topology=topology, weights=weights)

    @vectorize_point(
        (1, "ambient_point"),
        (2, "topology"),
        (3, "initial_weights"),
        manipulate_input=_manipulate_input_with_array,
        manipulate_output=_manipulate_output_wald,
    )
    def projection(self, ambient_point, topology, initial_weights=None):
        """Project ambient point into wald space.

        Parameters
        ----------
        ambient_point : array-like, shape=[..., n_labels, n_labels]
            Ambient point to project.
        topology : ForestTopology or list[ForestTopology]
            Stratum topology.
        initial_weights : array-like, shape=[..., n_splits]
            Initial guess for weights.

        Returns
        -------
        projected : Wald or WaldBatch
            Projected point(s).
        """
        ambient_point, topology, initial_weights = broadcast_lists(
            ambient_point, topology, initial_weights
        )
        return [
            self._projection_single(ambient_point_, topology_, initial_weights_)
            for ambient_point_, topology_, initial_weights_ in zip(
                ambient_point, topology, initial_weights
            )
        ]
class BasicWaldGeodesicSolver(ABC):
    """Abstract class for wald geodesic solvers.

    Subclasses provide ``_discrete_geodesic_single(initial_point, end_point)``.
    ``discrete_geodesic`` calls it once per pair of points.

    Parameters
    ----------
    space : WaldSpace
    """

    def __init__(self, space):
        self._space = space

    @vectorize_point(
        (1, "initial_point"),
        (2, "end_point"),
        manipulate_output=lambda out, to_list: _manipulate_output(
            out, to_list, manipulate_output_iterable=lambda x: x
        ),
    )
    def discrete_geodesic(self, initial_point, end_point):
        """Compute a discrete geodesic in the WaldSpace.

        Parameters
        ----------
        initial_point : Wald or WaldBatch
            Point in the WaldSpace.
        end_point : Wald or WaldBatch
            Point in the WaldSpace.

        Returns
        -------
        geod_points : WaldBatch or list[WaldBatch]
            Points of the discrete geodesic(s).
        """
        initial_point, end_point = broadcast_lists(initial_point, end_point)
        return [
            self._discrete_geodesic_single(
                initial_point_,
                end_point_,
            )
            for initial_point_, end_point_ in zip(initial_point, end_point)
        ]

    def geodesic(self, initial_point, end_point):
        """Compute the geodesic in the WaldSpace.

        Parameters
        ----------
        initial_point : Wald or WaldBatch
            Point in the WaldSpace.
        end_point : Wald or WaldBatch
            Point in the WaldSpace.

        Returns
        -------
        path : callable
            Time parameterized geodesic curve. For several pairs of points,
            it returns a list with one ``WaldBatch`` per pair.
        """
        discrete_path = self.discrete_geodesic(initial_point, end_point)
        if isinstance(discrete_path, WaldBatch):
            return DiscreteWaldPath(self._space, discrete_path)

        # Build the interpolators once, rather than on every evaluation of the
        # returned curve (each one computes ambient distances along its path).
        paths = [
            DiscreteWaldPath(self._space, discrete_path_)
            for discrete_path_ in discrete_path
        ]

        def _vectorized_path(t):
            if len(paths) == 1:
                return paths[0](t)
            return [path(t) for path in paths]

        return _vectorized_path
class NaiveProjectionGeodesicSolver(BasicWaldGeodesicSolver):
    """Naive geodesic projection solver.

    Implementation of algorithm 1 from [Lueg21]_. The ambient geodesic
    between the lifted end points is sampled on a uniform grid, and each
    interior sample is projected onto the common grove.

    Parameters
    ----------
    space : WaldSpace
    n_grid : int
        Number of points of the discrete geodesic, end points included.
        Defaults to 10.
    """

    def __init__(self, space, n_grid=10):
        super().__init__(space)
        self.n_grid = n_grid

    def _discrete_geodesic_single(self, initial_point, end_point):
        """Compute a discrete geodesic in the WaldSpace.

        Parameters
        ----------
        initial_point : Wald
            Point in the WaldSpace.
        end_point : Wald
            Point in the WaldSpace.

        Returns
        -------
        geod_points : WaldBatch
            Points of the discrete geodesic, end points included.

        Raises
        ------
        ValueError
            If the two points do not share the same topology.
        """
        if initial_point.topology != end_point.topology:
            raise ValueError("Can only handle points in the same grove.")

        if self.n_grid <= 2:
            # No interior grid points: avoid projecting an empty batch.
            return WaldBatch([initial_point, end_point])

        topology = initial_point.topology
        initial_corr = self._space.lift(initial_point)
        end_corr = self._space.lift(end_point)
        ambient_geod_func = self._space.ambient_space.metric.geodesic(
            initial_corr, end_point=end_corr
        )
        time = gs.linspace(0, 1, self.n_grid)[1:-1]
        mid_ambient_point = ambient_geod_func(time)
        mid_points = self._space.metric.projection(mid_ambient_point, topology)
        # Unpack the interior points instead of relying on list + WaldBatch
        # concatenation.
        return WaldBatch([initial_point, *mid_points, end_point])
class SuccessiveProjectionGeodesicSolver(BasicWaldGeodesicSolver):
    """Successive projection geodesic projection solver.

    Implementation of algorithm 2 from [Lueg21]_. Points are added
    alternately from both ends. At each step, the ambient geodesic between
    the current innermost left and right points is evaluated one grid
    interval in from each side. The results are then projected onto the
    common grove.

    Parameters
    ----------
    space : WaldSpace
    n_grid : int
        Number of points of the discrete geodesic, end points included.
        Defaults to 10.
    """

    def __init__(self, space, n_grid=10):
        super().__init__(space)
        self.n_grid = n_grid

    def _discrete_geodesic_single(self, initial_point, end_point):
        """Compute a discrete geodesic in the WaldSpace.

        Parameters
        ----------
        initial_point : Wald
            Point in the WaldSpace.
        end_point : Wald
            Point in the WaldSpace.

        Returns
        -------
        geod_points : WaldBatch
            Points of the discrete geodesic, end points included.

        Raises
        ------
        ValueError
            If the two points do not share the same topology.
        """
        if initial_point.topology != end_point.topology:
            raise ValueError("Can only handle points in the same grove.")

        topology = initial_point.topology
        left_points = [initial_point]
        right_points = [end_point]
        for i in range(2, self.n_grid // 2 + 1):
            # left_points[-1] sits at grid index i - 2 and right_points[-1] at
            # index n_grid - i + 1, so n_grid - 2 * i + 3 grid intervals
            # separate them. Step exactly one interval in from each side
            # (in flat space this yields uniformly spaced points).
            time = 1 / (self.n_grid - 2 * i + 3)
            left_corr = self._space.lift(left_points[-1])
            right_corr = self._space.lift(right_points[-1])
            ambient_geod_func = self._space.ambient_space.metric.geodesic(
                left_corr, right_corr
            )
            left_points.append(
                self._space.metric.projection(
                    ambient_geod_func(time)[0],
                    topology,
                    initial_weights=left_points[-1].weights,
                )
            )
            right_points.append(
                self._space.metric.projection(
                    ambient_geod_func(1 - time)[0],
                    topology,
                    initial_weights=right_points[-1].weights,
                )
            )

        right_points.reverse()
        if self.n_grid % 2 == 0:
            return WaldBatch(left_points + right_points)

        # Odd grid: one midpoint remains between the two innermost points.
        left_corr = self._space.lift(left_points[-1])
        right_corr = self._space.lift(right_points[0])
        ambient_geod_func = self._space.ambient_space.metric.geodesic(
            left_corr, end_point=right_corr
        )
        # Index [0] as in the loop, so the projection yields a single Wald
        # rather than a batch that would have to be concatenated with lists.
        mid_point = self._space.metric.projection(ambient_geod_func(0.5)[0], topology)
        return WaldBatch(left_points + [mid_point] + right_points)
class DiscreteWaldPath:
    """Path interpolating a discrete sequence of wälder within one grove.

    The weights are interpolated over times given by the normalized
    cumulative ambient distances between consecutive lifted points. The path
    is thus parameterized proportionally to the length of the ambient
    polygonal curve through those points.

    Parameters
    ----------
    space : WaldSpace
    path : WaldBatch
        Wald collection with common topology.
    interpolator : Interpolator1D
        Interpolator of the weights. Defaults to a ``LinearInterpolator1D``
        built from ``path`` and ``interpolator_kwargs``.
    """

    def __init__(self, space, path, interpolator=None, **interpolator_kwargs):
        # NB: assumes common topology
        self._topology = path[0].topology
        if interpolator is None:
            # Times are only needed to build the default interpolator.
            times = self._get_times(space, path)
            interpolator = LinearInterpolator1D(
                times, path.weights, point_ndim=1, **interpolator_kwargs
            )
        self.interpolator = interpolator

    def _get_times(self, space, path):
        """Compute times for linear interpolation.

        Returns
        -------
        times : array-like, shape=[n_points]
            Cumulative ambient distances normalized to ``[0, 1]``. If the
            total length is not larger than ``gs.atol`` (e.g. all points
            coincide), uniformly spaced times are returned instead.
        """
        dists = space.ambient_space.metric.dist(path[1:].corr, path[:-1].corr)
        cum_dists = gs.cumsum(dists)
        total_length = cum_dists[-1]
        if total_length <= gs.atol:
            # Degenerate path: normalizing by a (near) zero length would give
            # NaN or meaningless times.
            return gs.linspace(0.0, 1.0, cum_dists.shape[0] + 1)
        return gs.concatenate([gs.array([0.0]), cum_dists / total_length])

    def __call__(self, t):
        """Interpolate path.

        Parameters
        ----------
        t : array-like, shape=[n_time]
            Interpolation time.

        Returns
        -------
        path : WaldBatch
        """
        weights = self.interpolator(t)
        return WaldBatch([Wald(self._topology, weights_) for weights_ in weights])