Source code for geomstats.visualization.spd_matrices
"""Visualization for Geometric Statistics."""
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # NOQA
import geomstats.backend as gs
[docs]
class Ellipses:
"""Class used to plot points on the manifold SPD(2).
Elements S of the manifold of 2D Symmetric Positive Definite matrices
can be conveniently represented by ellipses.
We write :math:`S = O D O^T` with :math:`O` an orthogonal matrix (rotation)
and :math:`D` a diagonal matrix. The positive eigenvalues, i.e. the elements of
:math:`D`, are the inverse of the length of the major and minor axes of
the ellipse. The rotation matrix :math:`O` determines the orientation of the
2D ellipse in the 2D plane.
Parameters
----------
k_sampling_points : int
Number of points to sample on the discretized ellipses.
"""
def __init__(self, k_sampling_points=100):
self.k_sampling_points = k_sampling_points
[docs]
@staticmethod
def set_ax(ax=None):
"""Set the axis for the Figure.
Parameters
----------
ax : Axis
Axis of the figure.
Returns
-------
ax : Axis
Axis of the figure.
"""
if ax is None:
ax = plt.subplot()
plt.setp(ax, xlabel="X", ylabel="Y")
return ax
[docs]
def draw_points(self, points=None, ax=None, **plot_kwargs):
"""Draw the ellipses.
Parameters
----------
ax : Axis
Axis of the figure.
points : array-like, shape=[..., 2, 2]
Points on the SPD manifold of 2D symmetric
positive definite matrices.
Optional, default: None.
plot_kwargs : dict
Dictionary of arguments related to plotting.
"""
if ax is None:
ax = self.set_ax()
if points.ndim == 2:
points = [points]
for point in points:
x_coords, y_coords = self.compute_coordinates(point)
ax.plot(x_coords, y_coords, **plot_kwargs)
[docs]
def compute_coordinates(self, point):
"""Compute the ellipse coordinates of a 2D SPD matrix.
Parameters
----------
point : array-like, shape=[2, 2]
SPD matrix.
Returns
-------
x_coords : array-like, shape=[k_sampling_points,]
x_coords coordinates of the sampling points on the discretized ellipse.
Y: array-like, shape = [k_sampling_points,]
y coordinates of the sampling points on the discretized ellipse.
"""
eigvalues, eigvectors = gs.linalg.eigh(point)
eigvalues = gs.where(eigvalues < gs.atol, gs.atol, eigvalues)
[eigvalue1, eigvalue2] = eigvalues
rot_sin = eigvectors[1, 0]
rot_cos = eigvectors[0, 0]
thetas = gs.linspace(0.0, 2 * gs.pi, self.k_sampling_points + 1)
x_coords = eigvalue1 * gs.cos(thetas) * rot_cos
x_coords -= rot_sin * eigvalue2 * gs.sin(thetas)
y_coords = eigvalue1 * gs.cos(thetas) * rot_sin
y_coords += rot_cos * eigvalue2 * gs.sin(thetas)
return x_coords, y_coords