"""Stiefel manifold St(n,p).
A set of all orthonormal p-frames in n-dimensional space, where p <= n.
Lead author: Oleg Kachan.
"""
import warnings
import geomstats.backend as gs
import geomstats.errors
from geomstats import algebra_utils
from geomstats.geometry.base import LevelSet
from geomstats.geometry.hermitian_matrices import powermh
from geomstats.geometry.matrices import Matrices
from geomstats.geometry.riemannian_metric import RiemannianMetric
from geomstats.numerics.geodesic import LogSolver
from geomstats.vectorization import repeat_out
class Stiefel(LevelSet):
    """Class for Stiefel manifolds St(n,p).

    A set of all orthonormal p-frames in n-dimensional space,
    where p <= n.

    Parameters
    ----------
    n : int
        Dimension of the ambient vector space.
    p : int
        Number of basis vectors in the orthonormal frame.
    equip : bool
        Forwarded to the parent constructor. If True, the space is equipped
        with the metric returned by `default_metric`.
        Optional, default: True.

    Raises
    ------
    ValueError
        If p > n.
    """

    def __init__(self, n, p, equip=True):
        geomstats.errors.check_integer(n, "n")
        geomstats.errors.check_integer(p, "p")
        if p > n:
            # p == n is accepted, so the constraint is "smaller than or equal".
            raise ValueError("p needs to be smaller than or equal to n.")

        self.n = n
        self.p = p
        # Value taken by U^T U on the manifold; subtracted in `submersion`.
        self._value = gs.eye(p)

        # dim St(n, p) = n * p - p * (p + 1) / 2. The product p * (p + 1) is
        # always even, so integer division is exact and avoids floats.
        dim = int(p * n - p * (p + 1) // 2)
        super().__init__(dim=dim, equip=equip)

    @staticmethod
    def default_metric():
        """Metric to equip the space with if equip is True."""
        return StiefelCanonicalMetric

    def _define_embedding_space(self):
        """Return the embedding space: n x p matrices."""
        return Matrices(self.n, self.p)

    def submersion(self, point):
        r"""Submersion that defines the manifold.

        Computes :math:`U^T U - I_p`, which vanishes exactly when the
        columns of :math:`U` are orthonormal.

        Parameters
        ----------
        point : array-like, shape=[..., n, p]
            Point in the embedding space.

        Returns
        -------
        submersed_point : array-like, shape=[..., p, p]
        """
        return Matrices.mul(Matrices.transpose(point), point) - self._value

    def tangent_submersion(self, vector, point):
        r"""Tangent submersion.

        Computes ``2 * Matrices.to_symmetric(point^T vector)``.

        Parameters
        ----------
        vector : array-like, shape=[..., n, p]
        point : array-like, shape=[..., n, p]

        Returns
        -------
        submersed_vector : array-like, shape=[..., p, p]
        """
        return 2 * Matrices.to_symmetric(
            Matrices.mul(Matrices.transpose(point), vector)
        )

    @staticmethod
    def to_grassmannian(point):
        r"""Project a point of St(n, p) to Gr(n, p).

        If :math:`U \in St(n, p)` is an orthonormal frame,
        return the orthogonal projector :math:`P = U U^T`
        onto the subspace of :math:`\mathbb{R}^n` spanned by
        :math:`U`.

        Parameters
        ----------
        point : array-like, shape=[..., n, p]
            Point.

        Returns
        -------
        projected : array-like, shape=[..., n, n]
            Projected point.
        """
        return Matrices.mul(point, Matrices.transpose(point))

    def random_uniform(self, n_samples=1):
        r"""Sample on St(n,p) from the uniform distribution.

        If the entries of the n x p matrix :math:`Z` are drawn from
        :math:`N(0,1)`, then :math:`Z(Z^TZ)^{-1/2}` is uniformly distributed
        on St(n,p) according to the Haar measure.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.

        Returns
        -------
        samples : array-like, shape=[..., n, p]
            Samples on the Stiefel manifold. There is no leading batch axis
            when n_samples is 1.
        """
        n, p = self.n, self.p
        size = (n_samples, n, p) if n_samples != 1 else (n, p)

        std_normal = gs.random.normal(size=size)
        gram = Matrices.mul(Matrices.transpose(std_normal), std_normal)
        # Right-multiplying by (Z^T Z)^{-1/2} orthonormalizes the columns.
        inv_sqrt_gram = powermh(gram, -1.0 / 2)
        return Matrices.mul(std_normal, inv_sqrt_gram)

    def random_point(self, n_samples=1, bound=1.0):
        r"""Sample on St(n,p) from the uniform distribution.

        Delegates to `random_uniform`, i.e. samples
        :math:`Z(Z^TZ)^{-1/2}` with :math:`Z` a standard normal n x p matrix.

        Parameters
        ----------
        n_samples : int
            Number of samples.
            Optional, default: 1.
        bound : float
            Unused here.

        Returns
        -------
        samples : array-like, shape=[..., n, p]
            Samples on the Stiefel manifold.
        """
        return self.random_uniform(n_samples)

    def to_tangent(self, vector, base_point):
        """Project a vector to a tangent space of the manifold.

        Inspired by the method of Pymanopt. The result is
        ``vector - base_point @ Matrices.to_symmetric(base_point^T vector)``.

        Parameters
        ----------
        vector : array-like, shape=[..., n, p]
            Vector.
        base_point : array-like, shape=[..., n, p]
            Point on the manifold.

        Returns
        -------
        tangent_vec : array-like, shape=[..., n, p]
            Tangent vector at base point.
        """
        aux = Matrices.mul(Matrices.transpose(base_point), vector)
        sym_aux = Matrices.to_symmetric(aux)
        return vector - Matrices.mul(base_point, sym_aux)

    def projection(self, point):
        """Project a close enough matrix to the Stiefel manifold.

        A singular value decomposition is used, and all singular values are
        set to 1 [Absil]_

        Parameters
        ----------
        point : array-like, shape=[..., n, p]
            Point in embedding manifold.

        Returns
        -------
        projected : array-like, shape=[..., n, p]

        References
        ----------
        .. [Absil] Absil, Pierre-Antoine, and Jérôme Malick.
            “Projection-like Retractions on Matrix Manifolds.”
            SIAM Journal on Optimization 22, no. 1 (January 2012):
            135–58. https://doi.org/10.1137/100802529.
        """
        mat_u, _, mat_v = gs.linalg.svd(point)
        # Only the first p columns of the left factor are used, so the
        # product has p columns; the singular values are dropped.
        return Matrices.mul(mat_u[..., :, : self.p], mat_v)
class StiefelCanonicalMetric(RiemannianMetric):
    """Class that defines the canonical metric for Stiefel manifolds.

    Parameters
    ----------
    space : Stiefel
        Stiefel manifold to equip. Its `dim` is used for the signature, and
        it is handed to the log solver stored in `log_solver`.
    """

    def __init__(self, space):
        super().__init__(space=space, signature=(space.dim, 0, 0))
        self.log_solver = _StiefelLogSolver(space)

    def inner_product(self, tangent_vec_a, tangent_vec_b, base_point):
        r"""Compute the inner-product of two tangent vectors at a base point.

        Canonical inner-product on the tangent space at `base_point`,
        which is different from the inner-product induced by the embedding
        (see [RLSMRZ2017]_).

        .. math::

            \langle\Delta, \tilde{\Delta}\rangle_{U}=\operatorname{tr}
            \left(\Delta^{T}\left(I-\frac{1}{2} U U^{T}\right)
            \tilde{\Delta}\right)

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., n, p]
            First tangent vector at base point.
        tangent_vec_b : array-like, shape=[..., n, p]
            Second tangent vector at base point.
        base_point : array-like, shape=[..., n, p]
            Point in the Stiefel manifold.

        Returns
        -------
        inner_prod : array-like, shape=[..., 1]
            Inner-product of the two tangent vectors.

        References
        ----------
        .. [RLSMRZ2017] R Zimmermann. A matrix-algebraic algorithm for the
            Riemannian logarithm on the Stiefel manifold under the canonical
            metric. SIAM Journal on Matrix Analysis and Applications 38 (2),
            322-342, 2017. https://epubs.siam.org/doi/pdf/10.1137/16M1074485
        """
        base_point_transpose = Matrices.transpose(base_point)
        # aux = Delta^T (I_n - 1/2 U U^T)
        aux = gs.matmul(
            Matrices.transpose(tangent_vec_a),
            gs.eye(self._space.n) - 0.5 * gs.matmul(base_point, base_point_transpose),
        )
        return Matrices.trace_product(aux, tangent_vec_b)

    def exp(self, tangent_vec, base_point):
        """Compute the Riemannian exponential of a tangent vector.

        With A = U^T Delta and the QR decomposition Q R of Delta - U A, the
        2p x 2p block matrix [[A, -R^T], [R, 0]] is exponentiated and its
        first p columns are mapped back through [U, Q].

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., n, p]
            Tangent vector at a base point.
        base_point : array-like, shape=[..., n, p]
            Point in the Stiefel manifold.

        Returns
        -------
        exp : array-like, shape=[..., n, p]
            Point in the Stiefel manifold equal to the Riemannian exponential
            of tangent_vec at the base point.
        """
        p = self._space.p

        matrix_a = Matrices.mul(Matrices.transpose(base_point), tangent_vec)
        matrix_k = tangent_vec - Matrices.mul(base_point, matrix_a)
        matrix_q, matrix_r = gs.linalg.qr(matrix_k)

        # Assemble [[A, -R^T], [R, 0]]. The zero block is sliced out of a
        # zeros_like so that it inherits the batch shape of the input.
        matrix_ar = gs.concatenate([matrix_a, -Matrices.transpose(matrix_r)], axis=-1)
        zeros = gs.zeros_like(tangent_vec)[..., :p, :p]
        matrix_rz = gs.concatenate([matrix_r, zeros], axis=-1)
        block = gs.concatenate([matrix_ar, matrix_rz], axis=-2)
        matrix_mn_e = gs.linalg.expm(block)

        exp = Matrices.mul(base_point, matrix_mn_e[..., :p, :p]) + Matrices.mul(
            matrix_q, matrix_mn_e[..., p:, :p]
        )
        return exp

    @staticmethod
    def retraction(tangent_vec, base_point):
        """Compute the retraction of a tangent vector.

        This computation is based on the QR-decomposition.
        e.g. :math:`P_x(V) = qf(X + V)`.

        Parameters
        ----------
        tangent_vec : array-like, shape=[..., n, p]
            Tangent vector at a base point.
        base_point : array-like, shape=[..., n, p]
            Point in the Stiefel manifold.

        Returns
        -------
        exp : array-like, shape=[..., n, p]
            Point in the Stiefel manifold equal to the retraction
            of tangent_vec at the base point.
        """
        matrix_q, matrix_r = gs.linalg.qr(base_point + tangent_vec)

        diagonal = gs.diagonal(matrix_r, axis1=-2, axis2=-1)
        # sign(sign(d) + 0.5) is +1 for d >= 0 and -1 for d < 0, so a zero
        # diagonal entry never zeroes out a column of Q.
        sign = gs.sign(gs.sign(diagonal) + 0.5)
        diag = algebra_utils.from_vector_to_diagonal_matrix(sign)
        return Matrices.mul(matrix_q, diag)

    @staticmethod
    def _matrix_r_single(matrix_m):
        """Build, column by column, the matrix R used by `lifting`.

        Column 0 is ``[1 / m_00]``. For i > 0, column i has i + 1 entries and
        solves ``M_i r_i = b_i``, where ``M_i`` is the leading
        (i + 1) x (i + 1) block of `matrix_m` and ``b_i`` has entries
        ``-dot(matrix_m[i, :j + 1], r_j)`` for j < i followed by 1. Entries
        below the diagonal of the result stay zero.

        Parameters
        ----------
        matrix_m : array-like, shape=[p, p]
            Single (non-batched) square matrix.

        Returns
        -------
        matrix_r : array-like, shape=[p, p]

        Raises
        ------
        ValueError
            If the diagonal entry of a column with index i > 0 is not
            positive.
        """
        n = matrix_m.shape[-1]
        columns_list = []
        matrix_r = gs.zeros((n, n))

        for i in range(n):
            if i == 0:
                column_r_i = gs.array([1.0 / matrix_m[0, 0]])
            else:
                b_i = gs.array(
                    [
                        -gs.dot(matrix_m[i, : j + 1], columns_list[j])
                        for j in range(i)
                    ]
                    + [1.0]
                )
                inv_matrix_m_i = gs.linalg.inv(matrix_m[: i + 1, : i + 1])
                column_r_i = gs.matvec(inv_matrix_m_i, b_i)
                if column_r_i[i] <= 0:
                    raise ValueError("(r_i)_i <= 0")

            columns_list.append(column_r_i)
            matrix_r[: len(column_r_i), i] = column_r_i

        return matrix_r

    def lifting(self, point, base_point):
        """Compute the lifting of a point.

        This computation is based on the QR-decomposition.
        e.g. :math:`P_x^{-1}(Q) = QR - X`.

        Parameters
        ----------
        point : array-like, shape=[..., n, p]
            Point in the Stiefel manifold.
        base_point : array-like, shape=[..., n, p]
            Point in the Stiefel manifold.

        Returns
        -------
        log : array-like, shape=[..., n, p]
            Tangent vector at the base point equal to the lifting
            of point at the base point.

        Raises
        ------
        ValueError
            If the top-left entry of ``base_point^T point`` is not positive,
            or if a diagonal entry computed for R is not positive.
        """
        point, base_point = gs.broadcast_arrays(point, base_point)
        matrix_m = gs.matmul(Matrices.transpose(base_point), point)

        # m11 == 0 must be rejected as well: the first column of R is
        # 1 / m11, which would otherwise divide by zero.
        if gs.any(matrix_m[..., 0, 0] <= 0.0):
            raise ValueError("Algorithm does no work if m11 <= 0.")

        # R is built one matrix at a time; batches are stacked on axis 0.
        if point.ndim == 2:
            matrix_r = self._matrix_r_single(matrix_m)
        else:
            matrix_r = gs.stack([self._matrix_r_single(matrix) for matrix in matrix_m])

        return gs.matmul(point, matrix_r) - base_point

    def injectivity_radius(self, base_point=None):
        """Compute the radius of the injectivity domain.

        This is the supremum of radii r for which the exponential map is a
        diffeomorphism from the open ball of radius r centered at the base
        point onto its image.
        In this case the exact injectivity radius is not known, and we use here
        a lower bound given by [Rentmeesters2015]_.

        Parameters
        ----------
        base_point : array-like, shape=[..., n, p]
            Point on the manifold.

        Returns
        -------
        radius : array-like, shape=[...,]
            Injectivity radius.

        References
        ----------
        .. [Rentmeesters2015] Rentmeesters, Quentin. “Algorithms for Data
            Fitting on Some Common Homogeneous Spaces.” UCL - Université
            Catholique de Louvain, 2013.
            https://dial.uclouvain.be/pr/boreal/object/boreal:132587.
        """
        # The value does not depend on the base point; `repeat_out` is given
        # the base point together with the space's point_ndim to shape it.
        radius = gs.array(0.89 * gs.pi)
        return repeat_out(self._space.point_ndim, radius, base_point)
class _StiefelLogSolver(LogSolver):
    """Riemannian logarithm solver for the Stiefel canonical metric.

    Parameters
    ----------
    space : Stiefel
        Stiefel manifold; its `n` and `p` attributes are read by `log`.
    max_iter : int
        Maximum number of iterations of the inner loop of `log`.
        Optional, default: 500.
    tol : float
        The inner loop stops once the norm of the lower-right block of the
        matrix logarithm is at most this value.
        Optional, default: 1e-8.
    imag_tol : float
        A complex matrix logarithm is cast to real when the largest absolute
        value of its imaginary part is below this value; otherwise an error
        is raised.
        Optional, default: 1e-6.
    """

    def __init__(self, space, max_iter=500, tol=1e-8, imag_tol=1e-6):
        super().__init__()
        self._space = space
        self.max_iter = max_iter
        self.tol = tol
        self.imag_tol = imag_tol

    @staticmethod
    def _normal_component_qr(point, base_point, matrix_m):
        """QR-decompose the part of `point` left after removing `base_point @ matrix_m`.

        Parameters
        ----------
        point : array-like, shape=[..., n, p]
        base_point : array-like, shape=[..., n, p]
        matrix_m : array-like
            Matrix multiplied on the right of `base_point`; `log` passes
            ``base_point^T point``.

        Returns
        -------
        matrix_q : array-like
            First factor returned by `gs.linalg.qr`.
        matrix_n : array-like
            Second factor returned by `gs.linalg.qr`.
        """
        residual = point - gs.matmul(base_point, matrix_m)
        q_factor, r_factor = gs.linalg.qr(residual)
        return q_factor, r_factor

    @staticmethod
    def _orthogonal_completion(matrix_m, matrix_n):
        """Complete the stacked matrix [M; N] to a square Q factor.

        Parameters
        ----------
        matrix_m : array-like
        matrix_n : array-like

        Returns
        -------
        matrix_v : array-like
            Q factor of the "complete"-mode QR decomposition of `matrix_m`
            stacked on top of `matrix_n`.
        """
        stacked = gs.concatenate([matrix_m, matrix_n], axis=-2)
        return gs.linalg.qr(stacked, mode="complete")[0]

    @staticmethod
    def _procrustes_preprocessing(p, matrix_v, matrix_m, matrix_n):
        """Rebuild the completion so that its determinant is positive.

        The first p columns of the result are `matrix_m` stacked on
        `matrix_n`. The remaining columns are ``matrix_v[..., :, p:]``
        multiplied by a matrix built from the SVD factors of the lower-right
        block of `matrix_v`. While some determinant is not positive, an
        increasing number of trailing signs is flipped (only for the batch
        entries with negative determinant) and the result is rebuilt, at
        most p times.

        Parameters
        ----------
        p : int
            Number of columns of a point of the manifold.
        matrix_v : array-like
        matrix_m : array-like
        matrix_n : array-like

        Returns
        -------
        matrix_v : array-like
        """
        left, _, right = gs.linalg.svd(matrix_v[..., p:, p:])

        # Loop invariants: the leading columns and the all-ones sign vector.
        leading_columns = gs.concatenate([matrix_m, matrix_n], axis=-2)
        ones = gs.ones(p)

        result = gs.copy(matrix_v)
        for n_flips in range(1, p + 1):
            mixing = Matrices.mul(right, Matrices.transpose(left))
            trailing_columns = gs.matmul(matrix_v[..., :, p:], mixing)
            result = gs.concatenate([leading_columns, trailing_columns], axis=-1)

            det = gs.linalg.det(result)
            if gs.all(det > 0):
                break

            # -1 on the last `n_flips` entries, applied only where det < 0.
            flipped = gs.concatenate(
                [ones[:-n_flips], gs.array([-1.0] * n_flips)], axis=0
            )
            is_negative = gs.cast(det < 0, matrix_v.dtype)
            sign = (
                is_negative[..., None] * flipped
                + (1.0 - is_negative)[..., None] * ones
            )
            left = gs.einsum("...ij,...i->...ij", Matrices.transpose(left), sign)

        return result

    def log(self, point, base_point):
        """Compute the Riemannian logarithm of a point.

        When p=n, the space St(n,n)~O(n) has two non connected sheets: the
        log is only defined for data from the same sheet.
        For p<n, the space St(n,p)~O(n)/O(n-p)~SO(n)/SO(n-p) is connected.

        Based on [ZR2017]_. The iteration budget and the tolerances are the
        `max_iter`, `tol` and `imag_tol` attributes of the solver.

        Parameters
        ----------
        point : array-like, shape=[..., n, p]
            Point in the Stiefel manifold.
        base_point : array-like, shape=[..., n, p]
            Point in the Stiefel manifold.

        Returns
        -------
        log : array-like, shape=[..., n, p]
            Tangent vector at the base point equal to the Riemannian logarithm
            of point at the base point.

        Raises
        ------
        ValueError
            If p == n and the determinants of `point` and `base_point` do not
            have the same sign, or if the matrix logarithm has a
            non-negligible imaginary part.

        References
        ----------
        .. [ZR2017] Zimmermann, Ralf. "A Matrix-Algebraic Algorithm for the
            Riemannian Logarithm on the Stiefel Manifold under the Canonical
            Metric" SIAM J. Matrix Anal. & Appl., 38(2), 322–342, 2017.
            https://arxiv.org/pdf/1604.05054.pdf
        """
        n, p = self._space.n, self._space.p

        if p == n:
            same_sheet = gs.linalg.det(point) * gs.linalg.det(base_point) > 0.0
            if not gs.all(same_sheet):
                raise ValueError("Points from different sheets in log")

        matrix_m = gs.matmul(Matrices.transpose(base_point), point)
        matrix_q, matrix_n = self._normal_component_qr(point, base_point, matrix_m)

        completion = self._orthogonal_completion(matrix_m, matrix_n)
        completion = self._procrustes_preprocessing(p, completion, matrix_m, matrix_n)

        # The iteration handles one matrix at a time; batches are stacked.
        if gs.ndim(completion) == 2:
            log_completion = self._iter_log(p, completion)
        else:
            log_completion = gs.stack(
                [self._iter_log(p, single) for single in completion]
            )

        along_base = gs.matmul(base_point, log_completion[..., :p, :p])
        along_q = gs.matmul(matrix_q, log_completion[..., p:, :p])
        return along_base + along_q

    def _iter_log(self, p, matrix_v):
        """Iterate on a single matrix until the lower-right log block vanishes.

        Each step takes the matrix logarithm of `matrix_v`, and stops if the
        norm of its lower-right block (rows and columns from p on) is at most
        `self.tol`. Otherwise the columns of `matrix_v` from p on are
        multiplied by the matrix exponential of minus the skew-symmetrized
        block. A warning is emitted if `self.max_iter` steps run without
        reaching the tolerance.

        Parameters
        ----------
        p : int
            Size of the upper-left block.
        matrix_v : array-like
            Single (non-batched) square matrix.

        Returns
        -------
        matrix_lv : array-like
            Matrix logarithm computed in the last step, cast to real if it is
            complex with a small enough imaginary part.

        Raises
        ------
        ValueError
            If the result is complex and the largest absolute value of its
            imaginary part is not below `self.imag_tol`.
        """
        log_v = gs.zeros_like(matrix_v)
        for _ in range(self.max_iter):
            log_v = gs.linalg.logm(matrix_v)
            lower_right = log_v[..., p:, p:]
            if gs.linalg.norm(lower_right) <= self.tol:
                break

            correction = gs.linalg.expm(-Matrices.to_skew_symmetric(lower_right))
            corrected_columns = gs.matmul(matrix_v[..., :, p:], correction)
            matrix_v = gs.concatenate(
                [matrix_v[..., :, :p], corrected_columns], axis=-1
            )
        else:
            warnings.warn("`log` hasn't converged.")

        if not gs.is_complex(log_v):
            return log_v

        max_imag = gs.amax(gs.abs(gs.imag(log_v)))
        if max_imag < self.imag_tol:
            return gs.real(log_v)
        raise ValueError(f"Non-neglible imaginary part. max is {max_imag}")