# Source code for geomstats.geometry.stratified.trees

```
r"""Helper classes for tree spaces.
Class ``Split``.
Essentially, a ``Split`` is a two-set partition of a subset of :math:`\{0,\dots,n-1\}`.
This class is designed such that one or both parts of the partition can be empty.
Splits correspond uniquely to edges in a phylogenetic forest, where, if one cuts
the edge in the forest, the resulting two-set partition of the labels of the respective
component of the forest is the corresponding split.
Lead author: Jonas Lueg
"""
import functools
import itertools
import random
import geomstats.backend as gs
from geomstats.exceptions import NotPartialOrder
def _pop_random_elem(ls):
"""Pops a random element from a list.
Parameters
----------
ls : list
"""
random_index = random.randint(0, len(ls) - 1)
return ls.pop(random_index)
[docs]
def generate_splits(labels):
    """Generate a random maximal set of compatible splits of the set ``labels``.

    The construction is inductive on the number of labels. It starts with the
    single split separating two randomly chosen labels. Then, one at a time, a
    randomly chosen remaining label is attached as a new leaf onto a randomly
    chosen existing split. That split is replaced by two splits (the new label
    on either side of it), a new split isolating the new label is added, and
    every other split gets the new label added to the part that points toward
    the divided split.

    Parameters
    ----------
    labels : list[int]
        A list of integers, the set of labels that we generate splits for.

    Returns
    -------
    splits : list[Split]
        A list of splits of the set of labels, maximal number of splits. Empty
        if ``labels`` has fewer than two elements.
    """
    if len(labels) <= 1:
        return []

    # Work on a copy so the caller's list is not consumed.
    remaining = labels.copy()
    first = _pop_random_elem(remaining)
    second = _pop_random_elem(remaining)
    placed = [first, second]
    splits = [Split(part1={first}, part2={second})]

    while remaining:
        new_label = _pop_random_elem(remaining)
        # The split that gets divided is removed from ``splits`` here.
        cut = _pop_random_elem(splits)

        grown = [
            Split(part1={new_label}, part2=placed),
            Split(part1=cut.part1 | {new_label}, part2=cut.part2),
            Split(part1=cut.part1, part2=cut.part2 | {new_label}),
        ]
        # All remaining splits see the new leaf on the side facing ``cut``.
        grown.extend(
            Split(
                part1=other.get_part_away_from(cut),
                part2=other.get_part_towards(cut) | {new_label},
            )
            for other in splits
        )

        placed.append(new_label)
        splits = grown

    return splits
[docs]
def check_if_separated(labels, splits):
    """Check that every pair of labels is separated by at least one split.

    Parameters
    ----------
    labels : list[int]
        A list of integers, the set of labels that we generate splits for.
    splits : list[Split]
        A list of splits of the set of labels.

    Returns
    -------
    are_separated : bool
        True if the labels are pair-wise separated by a split else False.
    """
    pair_is_separated = []
    for u, v in itertools.combinations(labels, 2):
        # One flag per unordered pair: does any split put u and v apart?
        pair_is_separated.append(gs.any([sp.separates(u, v) for sp in splits]))
    return gs.all(pair_is_separated)
[docs]
def delete_splits(splits, labels, p_keep, check=True):
    """Randomly delete splits from a list of splits.

    Each split is considered once, from the last index to the first. A split
    becomes a deletion candidate when a fresh random draw exceeds ``p_keep``.
    With ``check`` enabled, a candidate is only deleted if the remaining
    splits still separate every pair of labels.

    Parameters
    ----------
    splits : list[Split]
        A list of splits of the set of labels.
    labels : list[int]
        A list of integers, the set of labels that we generate splits for.
    p_keep : float
        A float between 0 and 1 determining the probability with which a split
        is kept and not deleted.
    check : bool
        If True, a split is not deleted when deleting it would leave some pair
        of labels unseparated. If False, any split can be randomly deleted.

    Returns
    -------
    left_over_splits : list[Split]
        The list of splits that are not deleted. The input list itself is never
        modified; deletions are applied to copies.
    """
    if p_keep == 1:
        return splits

    # Walk indices from the back so earlier indices stay valid after a removal.
    for idx in range(len(splits) - 1, -1, -1):
        if gs.random.rand(1) > p_keep:
            candidate = splits.copy()
            del candidate[idx]
            if not check or check_if_separated(splits=candidate, labels=labels):
                splits = candidate
    return splits
[docs]
@functools.total_ordering
class Split:
    r"""Two-set partitions of sets.

    Two-set partitions of a smaller subset of :math:`\{0,...,n-1\}` are also allowed,
    where :math:`n` is a positive integer, which is not passed as an argument as it is
    nowhere needed.

    The parameters ``part1`` and ``part2`` are assigned to the attributes
    ``self.part1`` and ``self.part2`` in a unique sorted way: the one that contains
    the smallest minimal value is assigned to ``self.part1`` for consistency. If at
    most one of the parts is non-empty, it is stored in ``self.part1`` and
    ``self.part2`` is the empty set.

    Equality and hashing depend only on the contents of the two parts, never on
    the order in which the labels were supplied.

    Parameters
    ----------
    part1 : iterable
        The first part of the split, an iterable that is a subset of
        :math:`\{0,\dots,n-1\}`. It may be empty, but must have empty intersection
        with ``part2``.
    part2 : iterable
        The second part of the split, an iterable that is a subset of
        :math:`\{0,\dots,n-1\}`. It may be empty, but must have empty intersection
        with ``part1``.

    Raises
    ------
    ValueError
        If ``part1`` and ``part2`` share an element.
    """

    def __init__(self, part1, part2):
        part1, part2 = set(part1), set(part2)
        if part1 & part2:
            raise ValueError(
                f"A split consists of disjoint sets, those are not: {part1}, {part2}."
            )
        if part1 and part2:
            # Canonical orientation: the part holding the smallest label is first.
            self.part1, self.part2 = (
                (part1, part2) if min(part1) < min(part2) else (part2, part1)
            )
        else:
            self.part1 = part1 or part2
            self.part2 = set()

    def _sorted_parts(self):
        """Return both parts as sorted tuples, a canonical form of the split.

        Iterating a ``set`` directly can yield equal sets in different orders,
        so sorting is what makes this form independent of insertion history.

        Returns
        -------
        parts : tuple of tuple
            The pair ``(sorted part1, sorted part2)``.
        """
        return tuple(sorted(self.part1)), tuple(sorted(self.part2))

    def __bool__(self):
        """Return True if and only if both parts are non-empty.

        We use the boolean representation to indicate whether both parts of a
        split are non-empty.

        Returns
        -------
        boolean_of_split : bool
            Returns the boolean representation of a split.
        """
        return bool(self.part1) and bool(self.part2)

    def __eq__(self, other):
        """Check whether two splits consist of the same two parts.

        Parameters
        ----------
        other : Split
            The other split.

        Returns
        -------
        is_equal : bool
            Return ``True`` if the splits are equal, else ``False``. For an
            ``other`` that is not a ``Split``, ``NotImplemented`` is returned so
            that Python falls back to its default comparison.
        """
        if not isinstance(other, Split):
            return NotImplemented
        # Compare contents, not hashes: distinct splits may share a hash.
        return self.part1 == other.part1 and self.part2 == other.part2

    def __hash__(self):
        """Compute the hash of a split.

        The hash is the tuple hash of the two parts, each converted to a sorted
        tuple so that equal splits always produce equal hashes.

        Returns
        -------
        hash_of_split : int
            Return the hash of the split.
        """
        return hash(self._sorted_parts())

    def __lt__(self, other):
        """Check if this split is ordered before the other split.

        Splits are ordered by their hash; splits with equal hashes are ordered
        by their sorted parts. Note that this ordering does not have a
        mathematical background, it is introduced in order to have a unique
        ordering for each set of splits at hand.

        Parameters
        ----------
        other : Split
            The other split.

        Returns
        -------
        is_strictly_less_than : bool
            Return ``True`` if this split is ordered strictly before ``other``,
            else ``False``. ``NotImplemented`` if ``other`` is not a ``Split``.
        """
        if not isinstance(other, Split):
            return NotImplemented
        mine, theirs = self._sorted_parts(), other._sorted_parts()
        # The content tie-break keeps the order total when two hashes collide.
        return (hash(mine), mine) < (hash(theirs), theirs)

    def __repr__(self):
        """Return the string representation of the split.

        This string representation requires that one can retrieve all necessary
        information from the string.

        Returns
        -------
        string_of_split : str
            Return the string representation of the split.
        """
        return str((self.part1, self.part2))

    def __str__(self):
        """Return the fancy printable string representation of the split.

        This string representation does NOT require that one can retrieve all
        necessary information from the string, but this string representation is
        required to be readable by a human.

        Returns
        -------
        string_of_split : str
            Return the fancy readable string representation of the split.
        """
        return f"{self.part1}|{self.part2}"

    def is_compatible(self, other):
        """Check whether this split is compatible with another split.

        Two splits are compatible, if at least one intersection of the respective
        parts of the splits is empty.

        Parameters
        ----------
        other : Split
            The other split.

        Returns
        -------
        is_compatible_with : bool
            Return ``True`` if the splits are compatible, else ``False``.
        """
        p1, p2 = self.part1, self.part2
        o1, o2 = other.part1, other.part2
        # Incompatible exactly when all four pairwise intersections are non-empty.
        return not all((p1 & o1, p1 & o2, p2 & o1, p2 & o2))

    def get_part_away_from(self, other):
        """Return the part of this split that is directed away from other split.

        Parameters
        ----------
        other : Split
            The other split.

        Returns
        -------
        part_that_does_not_point : iterable
            Return the part of the split ``self`` that does not point toward
            ``other``: ``self.part1`` if it lies entirely within one part of
            ``other``, else ``self.part2``. See ``self.get_part_towards`` for
            further explanation.
        """
        if other.part_contains(self.part1):
            return self.part1
        return self.part2

    def get_part_towards(self, other):
        """Return the part of this split that is directed toward the other split.

        Each split contains part1 and part2, the parts that the corresponding edge
        in the graph splits the set of labels into. Thus, one can think of the
        split as an edge, where part1 points in the direction of the part of the
        tree where the labels of part1 are contained, and part2 points in the
        other direction.

        If ``self.part1`` lies entirely within one part of ``other``, then
        ``self.part2`` is the part pointing toward ``other``; otherwise
        ``self.part1`` is.

        Parameters
        ----------
        other : Split
            The other split.

        Returns
        -------
        part_towards : iterable
            Return the part of the split ``self`` that points toward ``other``.
        """
        if other.part_contains(self.part1):
            return self.part2
        return self.part1

    def part_contains(self, subset):
        """Determine if a subset is contained in either part of a split.

        Parameters
        ----------
        subset : set
            The subset containing labels.

        Returns
        -------
        is_contained : bool
            A boolean that is true if the subset is contained in ``self.part1`` or
            ``self.part2``.
        """
        return subset.issubset(self.part1) or subset.issubset(self.part2)

    def restrict_to(self, subset):
        r"""Return the restriction of a split to a subset.

        Parameters
        ----------
        subset : set
            The subset that the split is restricted to.

        Returns
        -------
        restr_split : Split
            The restricted split, if the split is :math:`A\vert B`, then the split
            restricted to the subset :math:`C` is :math:`A\cap C\vert B\cap C`.
        """
        return Split(
            part1=self.part1 & subset,
            part2=self.part2 & subset,
        )

    def separates(self, u, v):
        """Determine whether the labels (or label sets) are separated by the split.

        Parameters
        ----------
        u : list of int, int
            Either an integer or a set of labels.
        v : list of int, int
            Either an integer or a set of labels.

        Returns
        -------
        is_separated : bool
            A boolean determining whether u and v are separated by the split (i.e.
            all of u lies in one part and all of v lies in the other part).
        """
        u = {u} if isinstance(u, int) else set(u)
        v = {v} if isinstance(v, int) else set(v)
        b1 = u.issubset(self.part1) and v.issubset(self.part2)
        b2 = v.issubset(self.part1) and u.issubset(self.part2)
        return b1 or b2
[docs]
class ForestTopology:
    r"""The topology of a forest, using a split-based graph-structure representation.

    A forest topology is a partition into non-empty sets of the set
    :math:`\{0,\dots,n-1\}`, together with a set of splits for each element of the
    partition, where every split is a two-set partition of the respective element.

    A structure basically describes a phylogenetic forest, where each set of splits
    gives the structure of the tree with the labels of the corresponding element of
    the partition.

    On construction, the labels within each component are sorted, the components
    are ordered by their smallest label, and the splits of each component are
    sorted.

    Parameters
    ----------
    partition : tuple
        A tuple of tuples that is a partition of the set :math:`\{0,\dots,n-1\}`,
        representing the label sets of each connected component of the forest
        topology.
    split_sets : tuple
        A tuple of tuples containing splits, where each set of splits contains only
        splits of the respective label set in the partition, so their order
        is related. The splits are the edges of the connected components of the
        forest, respectively, and thus the union of all sets of splits yields all
        edges of the forest topology.

    Attributes
    ----------
    n_labels : int
        Number of labels, the set of labels is then :math:`\{0,\dots,n-1\}`.
    n_splits : int
        Number of splits.
    where : dict
        Give the index of a split in the flattened list of all splits.
    sep : list of int
        An increasing list of numbers between 0 and m, where m is the total number
        of splits in ``self.split_sets``, starting with 0, where each number
        indicates that a new connected component starts at that index.
        Useful for example for unraveling the tuple of all splits into
        ``self.split_sets``.
    paths : list of dict
        A list of dictionaries, each dictionary is for the respective connected
        component of the forest, and the items of each dictionary are for each pair
        of labels u, v, u < v in the respective component, a list of the splits on
        the unique path between the labels u and v.
    support : array-like, shape=[n_splits, n, n]
        For each split, an :math:`n\times n` dimensional matrix, where the
        uv-th entry is non-zero if the split separates the labels u and v, else
        zero.

    Raises
    ------
    ValueError
        If the number of split sets differs from the number of components, or if
        a split does not cover exactly the labels of its component.
    """

    def __init__(self, partition, split_sets):
        self._check_init(partition, split_sets)

        self.n_labels = len(set().union(*partition))

        # Canonical form: sorted labels per component, components ordered by
        # their smallest label, split sets permuted accordingly and sorted.
        partition = [tuple(sorted(part)) for part in partition]
        smallest = [part[0] for part in partition]
        order = sorted(range(len(smallest)), key=smallest.__getitem__)
        self.partition = tuple(partition[k] for k in order)
        self.split_sets = tuple(tuple(sorted(split_sets[k])) for k in order)

        self.where = {s: i for i, s in enumerate(self._flatten(self.split_sets))}

        lengths = [len(splits) for splits in self.split_sets]
        # Running totals of the component sizes mark the component boundaries.
        self.sep = [0] + list(itertools.accumulate(lengths))

        self.paths = [
            {
                (u, v): [s for s in splits if s.separates(u, v)]
                for u, v in itertools.combinations(part, r=2)
            }
            for part, splits in zip(self.partition, self.split_sets)
        ]

        _support = [
            gs.zeros((self.n_labels, self.n_labels), dtype=int)
            for _ in range(sum(lengths))
        ]
        for path_dict in self.paths:
            for (u, v), path in path_dict.items():
                for split in path:
                    _support[self.where[split]][u][v] = True
                    _support[self.where[split]][v][u] = True
        # Stack the per-split matrices row-wise, then restore one matrix per split.
        self.support = gs.reshape(
            gs.array(self._flatten(_support)),
            (-1, self.n_labels, self.n_labels),
        )

        self._chart_gradient = None
        self.n_splits = gs.sum(gs.array(lengths), dtype=int)

    def _check_init(self, partition, split_sets):
        """Validate the constructor arguments.

        Parameters
        ----------
        partition : tuple
            The label sets of the components.
        split_sets : tuple
            The splits of each component, in the same order as ``partition``.

        Raises
        ------
        ValueError
            If the two arguments differ in length, or if the union of the two
            parts of some split is not equal to the label set of its component.
        """
        if len(split_sets) != len(partition):
            raise ValueError(
                "Number of split sets is not equal to number of components."
            )
        for _part, _splits in zip(partition, split_sets):
            for _sp in _splits:
                if (_sp.part1 | _sp.part2) != set(_part):
                    raise ValueError(
                        f"The split {_sp} is not a split of component {_part}."
                    )

    def __eq__(self, other):
        """Check if ``self`` is equal to ``other``.

        Parameters
        ----------
        other : ForestTopology
            The other topology.

        Returns
        -------
        is_equal : bool
            Return ``True`` if the topologies are equal, else ``False``.
            ``NotImplemented`` if ``other`` is not a ``ForestTopology``.
        """
        if not isinstance(other, ForestTopology):
            return NotImplemented
        return (
            self.n_labels == other.n_labels
            and self.partition == other.partition
            and self.split_sets == other.split_sets
        )

    def __ge__(self, other):
        """Check if ``self`` is greater than or equal to ``other``.

        Parameters
        ----------
        other : ForestTopology
            The other topology.

        Returns
        -------
        is_greater_than_or_equal : bool
            Return ``True`` if ``self`` is greater or equal to ``other``, else
            ``False``. ``NotImplemented`` if ``other`` is not a ``ForestTopology``.
        """
        # Without this guard, ``other <= self`` on a foreign type is reflected
        # straight back into this method and recurses without bound.
        if not isinstance(other, ForestTopology):
            return NotImplemented
        return other <= self

    def __gt__(self, other):
        """Check if ``self`` is strictly greater than ``other``.

        Parameters
        ----------
        other : ForestTopology
            The other topology.

        Returns
        -------
        is_greater_than : bool
            Return ``True`` if this topology is greater than the other, else
            ``False``. ``NotImplemented`` if ``other`` is not a ``ForestTopology``.
        """
        # Same guard as in ``__ge__`` against unbounded reflection.
        if not isinstance(other, ForestTopology):
            return NotImplemented
        return other < self

    def __hash__(self):
        """Compute the hash of a topology.

        Note that this hash simply uses the hash function for tuples.

        Returns
        -------
        hash_of_topology : int
            Return the hash of the topology.
        """
        return hash((self.n_labels, self.partition, self.split_sets))

    def __le__(self, other):
        """Check if ``self`` is less than or equal to ``other``.

        This partial ordering is the one defined in [1] and to show if
        self <= other is True, three things must be satisfied.

        (i) ``self.partition`` must be a refinement of ``other.partition`` in the
            sense of partitions.
        (ii) The splits of each component in ``self`` must be contained in the
            set of splits of ``other`` restricted to the component of ``self``.
        (iii) Whenever two components of ``self`` are contained in a component of
            ``other``, there needs to exist a split in ``other`` separating those
            two components.

        If one of those three conditions are not fulfilled, this method returns
        False.

        Parameters
        ----------
        other : ForestTopology
            The structure to which self is compared to.

        Returns
        -------
        is_less_than_or_equal : bool
            Return ``True`` if (i), (ii) and (iii) are satisfied, else ``False``.
            ``NotImplemented`` if ``other`` is not a ``ForestTopology``.
        """
        if not isinstance(other, ForestTopology):
            return NotImplemented

        x_parts = [set(x) for x in self.partition]
        y_parts = [set(y) for y in other.partition]

        # (i) Map each component of self to the first component of other that
        # contains it; fail if some component is not contained in any.
        cover = {}
        for i, x in enumerate(x_parts):
            j = next((j for j, y in enumerate(y_parts) if x.issubset(y)), None)
            if j is None:
                return False
            cover[i] = j

        # (ii) Every split of self must be a restriction of a split of other.
        for i, x in enumerate(x_parts):
            y_splits_restricted = {
                split_y.restrict_to(subset=x) for split_y in other.split_sets[cover[i]]
            }
            if not set(self.split_sets[i]).issubset(y_splits_restricted):
                return False

        # (iii) Components of self merged in other must be separated there.
        for j in range(len(y_parts)):
            xs_in_y = [x for i, x in enumerate(x_parts) if cover[i] == j]
            for x1, x2 in itertools.combinations(xs_in_y, r=2):
                if not any(sp.separates(x1, x2) for sp in other.split_sets[j]):
                    return False
        return True

    def __lt__(self, other):
        """Check if ``self`` is less than ``other``.

        Parameters
        ----------
        other : ForestTopology
            The other topology.

        Returns
        -------
        is_less_than : bool
            Return ``True`` if ``self`` less than ``other``, else ``False``.
            ``NotImplemented`` if ``other`` is not a ``ForestTopology``.
        """
        if not isinstance(other, ForestTopology):
            return NotImplemented
        return self <= other and self != other

    def __repr__(self):
        """Return the string representation of the topology.

        This string representation requires that one can retrieve all necessary
        information from the string.

        Returns
        -------
        string_of_topology : str
            Return the string representation of the topology.
        """
        return str((self.n_labels, self.partition, self.split_sets))

    def __str__(self):
        """Return the fancy printable string representation of the topology.

        This string representation does NOT require that one can retrieve all
        necessary information from the string, but this string representation is
        required to be readable by a human.

        Returns
        -------
        string_of_topology : str
            Return the fancy readable string representation of the topology.
        """
        comps = [", ".join(str(sp) for sp in splits) for splits in self.split_sets]
        return "(" + "; ".join(comps) + ")"

    def corr(self, weights):
        """Compute the correlation matrix of the topology with edge weights.

        For two labels in the same component, the entry is the product of
        ``1 - weight`` over all splits on the path between them. Labels in
        different components have no path and keep the entry 0. The diagonal
        is 1.

        Parameters
        ----------
        weights : array-like, [n_splits]
            Edge weights.

        Returns
        -------
        corr : array-like, shape=[n, n]
            Returns the corresponding correlation matrix.
        """
        corr = gs.zeros((self.n_labels, self.n_labels))
        for path_dict in self.paths:
            for (u, v), path in path_dict.items():
                corr[u][v] = gs.prod(
                    gs.array([1 - weights[self.where[split]] for split in path])
                )
                corr[v][u] = corr[u][v]
        corr = gs.array(corr)
        return corr + gs.eye(corr.shape[0])

    def corr_gradient(self, weights):
        """Compute the gradient of the correlation matrix, differentiated by weights.

        For the k-th split, the correlation matrix is evaluated with the k-th
        weight replaced by 0, negated, and masked by the support of that split.

        Parameters
        ----------
        weights : array-like, [n_splits]
            The vector weights at which the gradient is computed.

        Returns
        -------
        gradient : array-like, shape=[n_splits, n, n]
            The gradient of the correlation matrix, differentiated by weights.
        """
        # One copy of the weights per split, with that split's weight zeroed.
        x_list = [
            [y if i != k else 0 for i, y in enumerate(weights)]
            for k in range(len(weights))
        ]
        gradient = gs.array(
            [-supp * self.corr(x) for supp, x in zip(self.support, x_list)]
        )
        return gradient

    def _unflatten(self, ls):
        """Transform list into list of lists according to separators, ``self.sep``.

        The separators are a list of integers, increasing. Then, all elements
        between two consecutive indices in separators will be put into a list, and
        together, all lists give a nested list.

        Parameters
        ----------
        ls : iterable
            The flat list that will be nested.

        Returns
        -------
        ls_nested : list[list]
            The nested list of lists.
        """
        return [ls[i:j] for i, j in zip(self.sep[:-1], self.sep[1:])]

    @staticmethod
    def _flatten(ls):
        """Flatten a list of lists into a single list by concatenation.

        Parameters
        ----------
        ls : nested list
            The nested list to flatten.

        Returns
        -------
        ls_flat : list
            The flattened list.
        """
        return [y for z in ls for y in z]
```