Source code for geomstats.geometry.stratified.bhv_space

"""Class for the BHV Tree Space.

Lead author: Jonas Lueg

References
----------
.. [BHV01] Billera, L. J., S. P. Holmes, K. Vogtmann.
    "Geometry of the Space of Phylogenetic Trees."
    Advances in Applied Mathematics,
    volume 27, issue 4, pages 733-767, 2001.
    https://doi.org/10.1006%2Faama.2001.0759
.. [OP11] Owen, M., J. S. Provan.
    "A Fast Algorithm for Computing Geodesic Distances in Tree Space."
    IEEE/ACM Transactions on Computational Biology and Bioinformatics,
    volume 8, issue 1, pages 2-13, 2011.
    https://doi.org/10.1109/TCBB.2010.3
"""

import itertools as it

import networkx as nx
import numpy as np

import geomstats.backend as gs
from geomstats.geometry.stratified.point_set import (
    Point,
    PointBatch,
    PointSet,
    PointSetMetric,
)
from geomstats.geometry.stratified.trees import (
    ForestTopology,
    Split,
    delete_splits,
    generate_splits,
)
from geomstats.geometry.stratified.vectorization import broadcast_lists, vectorize_point


[docs] def generate_random_tree(n_labels, p_keep=0.9, btol=1e-8): """Generate a random instance of ``Tree``. Parameters ---------- p_keep : float between 0 and 1 The probability that a sampled edge is kept and not deleted randomly. To be precise, it is not exactly the probability, as some edges cannot be deleted since the requirement that two labels are separated by a split might be violated otherwise. Defaults to 0.9 btol: float Tolerance for the boundary of the edge lengths. Defaults to 1e-08. """ labels = list(range(n_labels)) initial_splits = generate_splits(labels) splits = delete_splits(initial_splits, labels, p_keep, check=False) x = gs.random.uniform(size=(len(splits),), low=0, high=1) x = gs.minimum(gs.maximum(btol, x), 1 - btol) lengths = gs.maximum(btol, gs.abs(gs.log(1 - x))) return Tree(splits, lengths)
[docs] class TreeTopology(ForestTopology): r"""The topology of a tree, using a split-based representation. Parameters ---------- splits : list[Split] The structure of the tree in form of a set of splits of the set of labels. Attributes ---------- n_labels : int Number of labels, the set of labels is then :math:`\{0,\dots,n-1\}`. where : dict Give the index of a split in the flattened list of all splits. sep : list of int An increasing list of numbers between 0 and m, where m is the total number of splits in ``self.split_sets``, starting with 0, where each number indicates that a new connected component starts at that index. Useful for example for unraveling the tuple of all splits into ``self.split_sets``. paths : list of dict A list of dictionaries, each dictionary is for the respective connected component of the forest, and the items of each dictionary are for each pair of labels u, v, u < v in the respective component, a list of the splits on the unique path between the labels u and v. support : list of array-like For each split, give an :math:`n\times n` dimensional matrix, where the uv-th entry is ``True`` if the split separates the labels u and v, else ``False``. """ def __init__(self, splits): super().__init__( partition=(tuple(splits[0].part1.union(splits[0].part2)),), split_sets=(splits,), ) @property def splits(self): """Splits. Returns ------- splits : list[Split] The structure of the tree in form of a set of splits of the set of labels. """ return self.split_sets[0] @property def labels(self): """Node labels. Returns ------- labels : tuple[int] Node labels. """ return self.partition[0]
[docs] class Tree(Point): r"""A class for trees, that are phylogenetic trees, elements of the BHV space. A tree is essentially a phylogenetic tree with edges having length greater than zero. The representation of the tree is via splits, the edge lengths are stored in a vector. Parameters ---------- splits : list[Split] The structure of the tree in form of a set of splits of the set of labels. lengths : array-like, shape=[n_splits] The edge lengths of the splits, a vector containing positive numbers. Attributes ---------- topology : TreeTopology The topology of the tree. lengths : array-like, shape=[n_splits] The edge lengths of the splits, a vector containing positive numbers. """ def __init__(self, splits, lengths): self.topology = TreeTopology(splits=splits) self.lengths = gs.array( [ length for _, length in sorted( zip(splits, lengths), key=lambda x: self.topology.where.get(x[0]) ) ] ) def __repr__(self): """Return the string representation of the tree. This string representation requires that one can retrieve all necessary information from the string. Returns ------- string_of_tree : str Return the string representation of the tree. """ return repr((self.topology.splits, tuple(self.lengths))) def __str__(self): """Return the fancy printable string representation of the tree. This string representation does NOT require that one can retrieve all necessary information from the string, but this string representation is required to be readable by a human. Returns ------- string_of_tree : str Return the fancy readable string representation of the tree. """ return f"({self.topology};{str(self.lengths)})" def _equal_single(self, point, atol=gs.atol): """Check equality against another point. Parameters ---------- point : Tree Point to compare against point. atol : float Returns ------- is_equal : bool """ if self.topology != point.topology: return False return gs.all(gs.abs(self.lengths - point.lengths) < atol)
[docs] @vectorize_point((1, "point")) def equal(self, point, atol=gs.atol): """Check equality against another point. Parameters ---------- point : Tree or TreeBatch Point to compare against point. atol : float Returns ------- is_equal : array-like, shape=[...] """ return gs.array([self._equal_single(point_, atol) for point_ in point])
[docs] class TreeBatch(PointBatch): """Tree batch.""" @property def topology(self): """Tree topology. Returns ------- topology : list[TreeTopology] """ return [point.topology for point in self] @property def lengths(self): """Edge lengths. Returns ------- lengths : array-like, shape=[n_points, n_splits] """ return gs.array([point.lengths for point in self])
[docs] class TreeSpace(PointSet): """Class for the Tree space, a point set containing phylogenetic trees. A topological space. Points in Tree space are instances of the class :class:`Tree`: phylogenetic trees with edge lengths between 0 and infinity. For the space of trees see also [BHV01]_. Parameters ---------- n_labels : int The number of labels in the trees. """ def __init__(self, n_labels, equip=True): self.n_labels = n_labels super().__init__(equip)
[docs] @staticmethod def default_metric(): """Metric to equip the space with if equip is True.""" return BHVMetric
def _belongs_single(self, point, atol=gs.atol): """Check if a point belongs to Tree space. Parameters ---------- point : Tree The point to be checked. atol : float Absolute tolerance. Optional, default: backend atol. Returns ------- belongs : bool Boolean denoting if point belongs to Tree space. """ if point.topology.n_labels != self.n_labels: return False return gs.all(point.lengths > -atol)
[docs] @vectorize_point((1, "point")) def belongs(self, point, atol=gs.atol): """Check if a point belongs to Tree space. Parameters ---------- point : Tree or TreeBatch The point to be checked. atol : float Absolute tolerance. Optional, default: backend atol. Returns ------- belongs : array-like, shape=[...] Boolean denoting if point belongs to Tree space. """ return gs.array([self._belongs_single(point_, atol) for point_ in point])
[docs] def random_point(self, n_samples=1, p_keep=0.9, btol=1e-8): """Sample a random point in Tree space. Parameters ---------- n_samples : int Number of samples. Defaults to 1. p_keep : float between 0 and 1 The probability that a sampled edge is kept and not deleted randomly. To be precise, it is not exactly the probability, as some edges cannot be deleted since the requirement that two labels are separated by a split might be violated otherwise. Defaults to 0.9 btol: float Tolerance for the boundary of the edge lengths. Defaults to 1e-08. Returns ------- samples : Tree or TreeBatch Points sampled in Tree space. """ trees = [ generate_random_tree(self.n_labels, p_keep, btol) for _ in range(n_samples) ] if n_samples == 1: return trees[0] return TreeBatch(trees)
[docs] class BHVMetric(PointSetMetric): """BHV metric for Tree Space for phylogenetic trees. The BHV Tree Space as it is introduced in [BHV01]_, a metric space that is CAT(0), and there exist unique geodesics between each pair of points in the BHV Space. The polynomial time algorithm for computing the distance and geodesic between two points is implemented, following the definitions and results of [OP11]_. There, computing the geodesic between two trees is called the 'Geodesic Tree Path' problem, and that is why some methods below (not visible to the user though) start with the letters 'gtp'. Parameters ---------- total_space : TreeSpace Set with quotient structure. """ def __init__(self, space): super().__init__(space=space) self.geodesic_solver = GTPSolver(n_labels=space.n_labels)
[docs] def squared_dist(self, point_a, point_b): """Compute the squared distance between two points. Parameters ---------- point_a : Tree or TreeBatch A point in BHV Space. point_b : Tree or TreeBatch A point in BHV Space. Returns ------- squared_dist : array-like, shape=[...] The squared distance between the two points. """ return self.geodesic_solver.squared_dist(point_a, point_b)
[docs] def dist(self, point_a, point_b): """Compute the distance between two points. Parameters ---------- point_a : Tree or TreeBatch A point in BHV Space. point_b : Tree or TreeBatch A point in BHV Space. Returns ------- dist : array-like, shape=[...] The distance between the two points. """ return self.geodesic_solver.dist(point_a, point_b)
[docs] def geodesic(self, initial_point, end_point): """Compute the geodesic between two points. Parameters ---------- initial_point : Tree or TreeBatch A point in BHV Space. end_point : Tree or TreeBatch A point in BHV Space. Returns ------- geodesic : callable The geodesic between the two points. Takes parameter t, that is the time between 0 and 1 at which the corresponding point on the path is returned. """ return self.geodesic_solver.geodesic( initial_point=initial_point, end_point=end_point )
[docs] class GTPSolver: """'Geodesic Tree Path' problem solver [OP11]_. Essentially uses Theorem 2.4 from [OP11]_. Parameters ---------- tol : float Tolerance for the algorithm, in particular for the decision problem in the GTP algorithm in [OP11] to avoid unambiguity. """ def __init__(self, n_labels, tol=1e-8): self.n_labels = n_labels self.tol = tol def _squared_dist_single(self, point_a, point_b): """Compute the squared distance between two points. Parameters ---------- point_a : Tree A point in BHV Space. point_b : Tree A point in BHV Space. Returns ------- squared_dist : array-like, shape=[...] The squared distance between the two points. """ sp_a = { split: length for split, length in zip(point_a.topology.splits, point_a.lengths) } sp_b = { split: length for split, length in zip(point_b.topology.splits, point_b.lengths) } common_a, common_b, supports = self._trees_with_common_support( sp_a, sp_b, ) sq_dist_common = sum((common_a[s] - common_b[s]) ** 2 for s in common_a.keys()) sq_dist_parts = sum( ( gs.sqrt(sum(sp_a[s] ** 2 for s in a)) + gs.sqrt(sum(sp_b[s] ** 2 for s in b)) ) ** 2 for supp_a, supp_b in supports.values() for a, b in zip(supp_a, supp_b) ) return sq_dist_common + sq_dist_parts
[docs] @vectorize_point((1, "point_a"), (2, "point_b")) def squared_dist(self, point_a, point_b): """Compute the squared distance between two points. Parameters ---------- point_a : Tree or TreeBatch A point in BHV Space. point_b : Tree or TreeBatch A point in BHV Space. Returns ------- squared_dist : array-like, shape=[...] The squared distance between the two points. """ point_a, point_b = broadcast_lists(point_a, point_b) sq_dists = gs.array( [ self._squared_dist_single(point_a_, point_b_) for point_a_, point_b_ in zip(point_a, point_b) ] ) if len(sq_dists) == 1: return sq_dists[0] return sq_dists
[docs] def dist(self, point_a, point_b): """Compute the distance between two points. Parameters ---------- point_a : Tree or TreeBatch A point in BHV Space. point_b : Tree or TreeBatch A point in BHV Space. Returns ------- dist : array-like, shape=[...] The distance between the two points. """ return gs.sqrt(self.squared_dist(point_a, point_b))
def _geodesic_single(self, initial_point, end_point): """Compute the geodesic between two points. Parameters ---------- initial_point : Tree A point in BHV Space. end_point : Tree A point in BHV Space. Returns ------- geodesic : callable The geodesic between the two points. Takes parameter t, that is the time between 0 and 1 at which the corresponding point on the path is returned. """ sp_a = dict(zip(initial_point.topology.splits, initial_point.lengths)) sp_b = dict(zip(end_point.topology.splits, end_point.lengths)) common_a, common_b, supports = self._trees_with_common_support( sp_a, sp_b, ) ratios = { part: [ gs.sqrt(sum(sp_a[s] ** 2 for s in a) / sum(sp_b[s] ** 2 for s in b)) for a, b in zip(supp_a, supp_b) ] for part, (supp_a, supp_b) in supports.items() } def geodesic_t(t): if t == 0.0: return initial_point elif t == 1.0: return end_point t_ratio = t / (1 - t) splits_t = {s: (1 - t) * common_a[s] + t * common_b[s] for s in common_a} for part, (supp_a, supp_b) in supports.items(): index = gs.argmax([t_ratio <= _r for _r in ratios[part] + [np.inf]]) splits_t_a = { s: sp_a[s] * (1 - t - t / _r) for a_k, _r in zip(supp_a[index:], ratios[part][index:]) for s in a_k } splits_t_b = { s: sp_b[s] * (t - (1 - t) * _r) for b_k, _r in zip(supp_b[:index], ratios[part][:index]) for s in b_k } splits_t = {**splits_t, **splits_t_a, **splits_t_b} splits_lengths = [ (split, length) for split, length in splits_t.items() if length > self.tol ] tree_t = Tree( splits=[sl[0] for sl in splits_lengths], lengths=[sl[1] for sl in splits_lengths], ) return tree_t def geodesic_(t): if isinstance(t, (float, int)): t = gs.array([t]) return TreeBatch([geodesic_t(t_) for t_ in t]) return geodesic_
[docs] @vectorize_point((1, "initial_point"), (2, "end_point")) def geodesic(self, initial_point, end_point): """Compute the geodesic between two points. Parameters ---------- initial_point : Tree or TreeBatch A point in BHV Space. end_point : Tree or TreeBatch A point in BHV Space. Returns ------- geodesic : callable The geodesic between the two points. Takes parameter t, that is the time between 0 and 1 at which the corresponding point on the path is returned. """ initial_point, end_point = broadcast_lists(initial_point, end_point) def _vec(t, fncs): if len(fncs) == 1: return fncs[0](t) return [fnc(t) for fnc in fncs] fncs = [ self._geodesic_single(initial_point_, end_point_) for initial_point_, end_point_ in zip(initial_point, end_point) ] return lambda t: _vec(t, fncs=fncs)
def _trees_with_common_support(self, splits_a, splits_b): """Compute the support that corresponds to a geodesic for common split sets. We refer to the splits of the tree corresponding to splits_a as A, and B analogously. This method divides the split sets into smaller split sets that have distinct support and then use the method ``gtp_trees_with_distinct_support``. For each of these smaller subsets, return the support in a dictionary. The common splits are returned separately as well as the respective edge lengths. A split in A that is not in B but compatible with all splits of B is added to the common splits of B with length zero, and vice versa for splits in B. Parameters ---------- splits_a : dict of Split, float The splits in A and their respective lengths. splits_b : dict of Split, float The splits in B and their respective lengths. Returns ------- common_a : dict Containing the splits of A that are also in B, as well as the splits of B that are compatible with all splits in A, given edge length zero. common_b : dict Containing the splits of B that are also in A, as well as the splits of A that are compatible with all splits in B, given edge length zero. supports: dict Containing for each subtree the respective support. """ pendants = { Split(part1=[i], part2=[j for j in range(self.n_labels) if j != i]) for i in range(self.n_labels) } sp_a, sp_b = set(splits_a.keys()), set(splits_b.keys()) common = sp_a & sp_b only_a = sp_a - common only_b = sp_b - common easy_a = {s for s in only_a if gs.all(list(map(s.is_compatible, only_b)))} easy_b = {s for s in only_b if gs.all(list(map(s.is_compatible, only_a)))} total_a = (sp_a | easy_b) - pendants total_b = (sp_b | easy_a) - pendants cut_splits = (common | easy_a | easy_b) - pendants trees_a = self._cut_tree_at_splits(total_a, cut_splits) trees_b = self._cut_tree_at_splits(total_b, cut_splits) supports = { part: self._trees_with_distinct_support( {s: splits_a[s] for s in trees_a[part]}, {s: splits_b[s] for s in trees_b[part]}, ) for part in trees_a.keys() if trees_a[part] and trees_b[part] } common = common | easy_a | easy_b common_a = {s: splits_a[s] if s in sp_a else 0 for s in common} common_b = {s: splits_b[s] if s in sp_b else 0 for s in common} return common_a, common_b, supports def _cut_tree_at_splits(self, splits, cut_splits): """Cut a tree, given by splits, at all edges in cut_splits. Starting with the partition that consists of all labels and is assigned all splits, the tree is successively cut into parts by the splits in cut_splits. Accordingly, the set of labels is cut successively into parts and the set of all splits is also cut successively into the respective parts. Parameters ---------- splits : iterable of split The tree given via its splits. Each split corresponds to an edge. cut_splits : iterable of Split A subset of splits, the edges at which the tree is cut. Returns ------- partition : dict of tuple, tuple A dictionary, where the keys form a partition of the set of labels (0,...,n_labels-1), and each key is assigned the tuple of splits that are part of the subtree that the respective set of labels is spanning. """ partition = {tuple(range(self.n_labels)): splits} for cut in cut_splits: try: labels, subtree = [ (_, subtree) for _, subtree in partition.items() if cut in subtree ][0] except IndexError: continue splits = set(subtree) - {cut} part1 = tuple(set(labels) & set(cut.part1)) part2 = tuple(set(labels) & set(cut.part2)) subtree1 = {s for s in splits if part1 == cut.get_part_towards(s)} subtree2 = splits - subtree1 partition.pop(labels) partition = { **partition, tuple(part1): tuple(subtree1), tuple(part2): tuple(subtree2), } return partition def _trees_with_distinct_support(self, splits_a, splits_b): """Compute the support that corresponds to a geodesic for disjoint split sets. This is essentially the GTP algorithm from [1], starting with a cone path and iteratively updating the support, solving in each iteration an extension problem for each support pair. The Extension Problem gives a minimum cut of a graph and two-set partitions C1 and C2 of A, and D1 and D2 of B, respectively. If the value of the minimum cut is greater or equal to one minus some tolerance, then the support pair (A,B) is split into (C1,D1) and (C2,D2). Parameters ---------- splits_a : dict of Split, float The splits in A and their respective lengths. splits_b : dict of Split, float The splits in B and their respective lengths. Returns ------- support_a : tuple of tuple The support partition of A corresponding to a geodesic. support_b : tuple of tuple The support partition of B corresponding to a geodesic. """ old_support_a = (tuple(splits_a.keys()),) old_support_b = (tuple(splits_b.keys()),) weights_a = {split: splits_a[split] ** 2 for split in splits_a} weights_b = {split: splits_b[split] ** 2 for split in splits_b} while 1: new_support_a, new_support_b = tuple(), tuple() for pair_a, pair_b in zip(old_support_a, old_support_b): pair_a_w = {s: weights_a[s] for s in pair_a} pair_b_w = {s: weights_b[s] for s in pair_b} value, c1, c2, d1, d2 = self._solve_extension_problem( pair_a_w, pair_b_w ) if value >= 1 - self.tol: new_support_a += (pair_a,) new_support_b += (pair_b,) else: new_support_a += (c1, c2) new_support_b += (d1, d2) if len(new_support_a) == len(old_support_a): return new_support_a, new_support_b else: old_support_a, old_support_b = new_support_a, new_support_b @staticmethod def _solve_extension_problem(sq_splits_a, sq_splits_b): """Solve the extension problem in [1] for sets of splits with squared weights. Solving the min weight vertex cover with respect to the incompatibility graph in the Extension Problem in [1] is equivalent to solving the minimum cut problem for the following directed graph with edges that have 'capacities'. The set of vertices are the splits in A, the splits in B, a sink and a source node. The source is connected to all splits in A, each edge has the normalized squared weight of the split it is attached to. Analogously, each split in B is connected to the sink and the corresponding edge has normalized squared weight of the split in B. Finally, each split in A is attached to a split in B whenever the splits are not compatible. The edge is given infinite capacity. The minimum cut returns the two-set partition (V, V_bar) of the set of vertices and its value, that is the sum of all capacities of edges from V to V_bar, such that the source is in V and the sink is in V_bar. If the value is larger or equal than one (possibly with respect to some tolerance), then a geodesic is found and there is no need to update anything. Else, the sets A and B are separated into sets C_1 = A intersection V_bar, C_2 = A intersection V, D_1 = B intersection V_bar, D_2 = B intersection V. Then, the new support is (i.e. A and B are replaced with) (C_1, C_2) and (D_1, D_2) (here, the notation from [1], GTP algorithm is used). Parameters ---------- sq_splits_a : dict of Split, float Dictionary of splits in A with squared length associated to each split. sq_splits_b : dict of Split, float Dictionary of splits in B with squared length associated to each split. Returns ------- value : float The value of the minimum cut. c1 : set of Split First part of A that it is split into. c2 : set of Split Second part of A that it is split into. d1 : set of Split First part of B that it is split into. d2 : set of Split Second part of B that it is split into. """ total_a, total_b = sum(sq_splits_a.values()), sum(sq_splits_b.values()) graph = nx.DiGraph() for split, weight in sq_splits_a.items(): graph.add_edge("source", split, capacity=weight / total_a) for split, weight in sq_splits_b.items(): graph.add_edge(split, "sink", capacity=weight / total_b) for split_a, split_b in it.product(sq_splits_a.keys(), sq_splits_b.keys()): if not split_a.is_compatible(split_b): graph.add_edge(split_a, split_b) min_value, (v, v_bar) = nx.minimum_cut(graph, "source", "sink") a = set(sq_splits_a.keys()) b = set(sq_splits_b.keys()) v = set(v) v_bar = set(v_bar) return min_value, tuple(a & v_bar), tuple(a & v), tuple(b & v_bar), tuple(b & v)