Source code for geomstats.geometry.lower_triangular_matrices
"""The vector space of lower triangular matrices.
Lead author: Saiteja Utpala.
"""
import geomstats.backend as gs
from geomstats.geometry.base import MatrixVectorSpace
from geomstats.geometry.matrices import Matrices, MatricesMetric
[docs]
class LowerTriangularMatrices(MatrixVectorSpace):
    """Class for the vector space of lower triangular matrices of size n.

    Parameters
    ----------
    n : int
        Integer representing the shapes of the matrices: n x n.
    equip : bool
        If True, equip the space with the metric returned by
        ``default_metric``.
        Optional, default: True.
    """

    def __init__(self, n, equip=True):
        # Set ``n`` first so it is already available to anything the parent
        # constructor may call (e.g. while equipping the metric).
        self.n = n
        # An n x n lower triangular matrix has n(n+1)/2 free entries; integer
        # division keeps the dimension exact (no float round-trip).
        super().__init__(dim=int(n * (n + 1) // 2), shape=(n, n), equip=equip)

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return MatricesMetric

    def _create_basis(self):
        """Compute the basis of the vector space of lower triangular matrices.

        Each basis element is an n x n matrix holding a single one at one of
        the positions returned by ``gs.ravel_tril_indices`` and zeros
        elsewhere.

        Returns
        -------
        basis : array-like, shape=[dim, n, n]
            Basis matrices of the space.
        """
        # Flattened indices of the lower triangular entries; presumably in
        # the same (row-major) order as the reshape below -- TODO confirm.
        tril_idxs = gs.ravel_tril_indices(self.n)
        # One row of length n * n per lower triangular entry.
        vector_bases = gs.cast(
            gs.one_hot(tril_idxs, self.n * self.n),
            dtype=gs.get_default_dtype(),
        )
        return gs.reshape(vector_bases, [-1, self.n, self.n])

    def belongs(self, point, atol=gs.atol):
        """Evaluate if a matrix is lower triangular.

        Parameters
        ----------
        point : array-like, shape=[..., n, n]
            Point to test.
        atol : float
            Absolute tolerance of the lower-triangularity check, forwarded
            to ``Matrices.is_lower_triangular``.
            Optional, default: gs.atol.

        Returns
        -------
        belongs : array-like, shape=[...,]
            Boolean evaluating if point belongs to the space.
        """
        belongs = super().belongs(point)
        # Nothing passes the parent's check: skip the triangularity test.
        if not gs.any(belongs):
            return belongs
        is_lower_triangular = Matrices.is_lower_triangular(point, atol)
        return gs.logical_and(belongs, is_lower_triangular)

    @staticmethod
    def basis_representation(matrix_representation):
        """Convert a lower triangular matrix into a vector.

        Parameters
        ----------
        matrix_representation : array-like, shape=[..., n, n]
            Lower triangular matrix.

        Returns
        -------
        vec : array-like, shape=[..., n(n+1)/2]
            Vector of the lower triangular entries (computed by
            ``gs.tril_to_vec``).
        """
        return gs.tril_to_vec(matrix_representation)

    def projection(self, point):
        """Make a square matrix lower triangular by zeroing out other elements.

        Parameters
        ----------
        point : array-like, shape=[..., n, n]
            Matrix.

        Returns
        -------
        tril : array-like, shape=[..., n, n]
            Lower triangular matrix: the entries of ``point`` on and below the
            diagonal, with the entries above the diagonal set to zero.
        """
        return gs.tril(point)