"""Visualization for Geometric Statistics."""
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # NOQA
import geomstats.backend as gs
from geomstats.geometry.matrices import Matrices
from geomstats.geometry.pre_shape import PreShapeSpace
# Module-level spaces shared by the visualization classes below.
#
# Triangles in the plane: `M32` is the ambient space of 3x2 matrices (used to
# validate 'extrinsic' inputs) and `S32` is the pre-shape space of 3 landmarks
# in dimension 2 (used for alignment, log, dist and inner products).
M32 = Matrices(m=3, n=2)
S32 = PreShapeSpace(k_landmarks=3, ambient_dim=2)
# The classes below access `S32.fiber_bundle` and `S32.quotient`; presumably
# these two calls are what make those attributes available -- TODO confirm
# against the geomstats version in use.
S32.equip_with_group_action("rotations")
S32.equip_with_quotient()
# Triangles in 3D space: same construction with 3x3 matrices, used by the
# disk visualization.
M33 = Matrices(m=3, n=3)
S33 = PreShapeSpace(k_landmarks=3, ambient_dim=3)
S33.equip_with_group_action("rotations")
S33.equip_with_quotient()
class KendallSphere:
    """Class used to plot points in Kendall shape space of 2D triangles.

    David G. Kendall showed that the shape space of 2D triangles is isometric
    to the 2-sphere of radius 1/2 [K1984]_. This class encodes this isometric
    representation, offering a 3D visualization of Kendall shape space of order
    (3,2), and its related objects.

    Parameters
    ----------
    points : array-like, shape=[..., 3, 2]
        Points registered at construction through `add_points`.
        Optional, default: None.
    coords_type : str
        Type of the points. Can be either 'pre-shape' (for points in Kendall
        pre-shape space) or 'extrinsic' (for points given as 3x2 matrices).
        Optional, default: 'pre-shape'.

    Attributes
    ----------
    points : list
        List of points to plot on the Kendall sphere, stored as 3D
        coordinates on the sphere.
    coords_type : str
        Type of the points. Can be either 'pre-shape' (for points in Kendall
        pre-shape space) or 'extrinsic' (for points given as 3x2 matrices).
        Optional, default: 'pre-shape'.
    ax : matplotlib axes or None
        Axes used for drawing; None until `set_ax` (or a drawing method)
        has been called.
    elev, azim : float or None
        Elevation and azimuth of the view, in radians; None until
        `set_view` has been called.
    pole : array-like, shape=[3,2]
        Equilateral triangle (north pole).
    ua : array-like, shape=[3,2]
        Tangent vector toward isosceles triangle at vertex A.
    ub : array-like, shape=[3,2]
        Tangent vector toward isosceles triangle at vertex B.
    na : array-like, shape=[3,2]
        Tangent vector such that (ua,na) is a positively oriented
        orthonormal basis of the horizontal space at north pole.

    References
    ----------
    .. [K1984] David G. Kendall. "Shape Manifolds, Procrustean Metrics, and
       Complex Projective Spaces." Bulletin of the London Mathematical
       Society, Volume 16, Issue 2, March 1984, Pages 81–121.
       https://doi.org/10.1112/blms/16.2.81
    """

    def __init__(self, points=None, coords_type="pre-shape"):
        self.points = []
        self.coords_type = coords_type
        self.ax = None
        self.elev, self.azim = None, None

        self.pole = gs.array(
            [[1.0, 0.0], [-0.5, gs.sqrt(3.0) / 2.0], [-0.5, -gs.sqrt(3.0) / 2.0]]
        ) / gs.sqrt(3.0)

        self.ua = gs.array(
            [[-1.0, 0.0], [0.5, gs.sqrt(3.0) / 2.0], [0.5, -gs.sqrt(3.0) / 2.0]]
        ) / gs.sqrt(3.0)

        self.ub = gs.array(
            [[0.5, gs.sqrt(3.0) / 2.0], [0.5, -gs.sqrt(3.0) / 2], [-1.0, 0.0]]
        ) / gs.sqrt(3.0)

        # Gram-Schmidt step: remove the `ua` component from `ub`, then
        # normalize, so that (ua, na) is orthonormal.
        self.na = self.ub - S32.metric.inner_product(self.ub, self.ua) * self.ua
        self.na = self.na / S32.metric.norm(self.na)

        if points is not None:
            self.add_points(points)

    def _ensure_ax(self):
        """Create the default 3D axes if no axes have been set yet."""
        if self.ax is None:
            self.set_ax()

    def set_ax(self, ax=None):
        """Set axis.

        Parameters
        ----------
        ax : matplotlib 3D axes
            Axes to draw on.
            Optional, default: None, in which case
            `plt.subplot(111, projection="3d")` is used.
        """
        if ax is None:
            ax = plt.subplot(111, projection="3d")

        # The sphere has radius 1/2, so the limits fit it exactly.
        ax_s = 0.5
        plt.setp(
            ax,
            xlim=(-ax_s, ax_s),
            ylim=(-ax_s, ax_s),
            zlim=(-ax_s, ax_s),
            xlabel="X",
            ylabel="Y",
            zlabel="Z",
        )
        self.ax = ax

    def set_view(self, elev=60.0, azim=0.0):
        """Set azimuth and elevation angle.

        Parameters
        ----------
        elev : float
            Elevation angle, in degrees.
            Optional, default: 60.
        azim : float
            Azimuth angle, in degrees.
            Optional, default: 0.
        """
        self._ensure_ax()
        # Stored in radians for `draw`; matplotlib receives the raw values.
        self.elev, self.azim = gs.pi * elev / 180, gs.pi * azim / 180
        self.ax.view_init(elev, azim)

    def convert_to_polar_coordinates(self, points):
        """Assign polar coordinates to given pre-shapes.

        Parameters
        ----------
        points : array-like, shape=[..., 3, 2]
            Pre-shapes of 2D triangles.

        Returns
        -------
        coords_theta : array-like
            Angle of the log at the pole, measured in the basis (ua, na).
        coords_phi : array-like
            Twice the distance between the pole and the aligned points.
        """
        aligned_points = S32.fiber_bundle.align(points, self.pole)
        speeds = S32.metric.log(aligned_points, self.pole)

        coords_theta = gs.arctan2(
            S32.metric.inner_product(speeds, self.na),
            S32.metric.inner_product(speeds, self.ua),
        )
        coords_phi = 2.0 * S32.metric.dist(self.pole, aligned_points)

        return coords_theta, coords_phi

    def convert_to_spherical_coordinates(self, points):
        """Convert pre-shapes to 3D coordinates on the sphere of radius 1/2.

        Parameters
        ----------
        points : array-like, shape=[..., 3, 2]
            Pre-shapes of 2D triangles.

        Returns
        -------
        spherical_coords : array-like, shape=[..., 3]
            Cartesian (x, y, z) coordinates on the sphere.
        """
        coords_theta, coords_phi = self.convert_to_polar_coordinates(points)
        coords_x = 0.5 * gs.cos(coords_theta) * gs.sin(coords_phi)
        coords_y = 0.5 * gs.sin(coords_theta) * gs.sin(coords_phi)
        coords_z = 0.5 * gs.cos(coords_phi)

        spherical_coords = gs.transpose(gs.stack((coords_x, coords_y, coords_z)))
        return spherical_coords

    def add_points(self, points):
        """Add points to draw on the Kendall sphere.

        Parameters
        ----------
        points : array-like, shape=[..., 3, 2]
            Points given according to `coords_type`.

        Raises
        ------
        ValueError
            If the points fail the membership test matching `coords_type`.
        """
        if self.coords_type == "extrinsic":
            if not gs.all(M32.belongs(points)):
                raise ValueError("Points do not belong to Matrices(3, 2).")
            points = S32.projection(points)
        elif self.coords_type == "pre-shape" and not gs.all(S32.belongs(points)):
            raise ValueError("Points do not belong to the pre-shape space.")

        points = self.convert_to_spherical_coordinates(points)
        if not isinstance(points, list):
            if points.shape == (3,):
                # A single point: wrap it so that `extend` adds one entry.
                points = [gs.array(points)]
            else:
                points = list(points)
        self.points.extend(points)

    def clear_points(self):
        """Clear the points to draw."""
        self.points = []

    def draw(self, n_theta=25, n_phi=13, scale=0.05, elev=60.0, azim=0.0, ax=None):
        """Draw the sphere regularly sampled with corresponding triangles.

        Parameters
        ----------
        n_theta : int
            Number of samples of the angle theta for the sphere grid.
            Optional, default: 25.
        n_phi : int
            Number of samples of the angle phi.
            Optional, default: 13.
        scale : float
            Scaling factor applied to the drawn triangles.
            Optional, default: 0.05.
        elev : float
            Elevation angle of the view, in degrees.
            Optional, default: 60.
        azim : float
            Azimuth angle of the view, in degrees.
            Optional, default: 0.
        ax : matplotlib 3D axes
            Axes to draw on.
            Optional, default: None, in which case default axes are
            requested from pyplot, as `set_ax()` does.
        """
        self.set_ax(ax)
        self.set_view(elev=elev, azim=azim)
        self.ax.set_axis_off()
        plt.tight_layout()

        coords_theta = gs.linspace(0.0, 2.0 * gs.pi, n_theta)
        coords_phi = gs.linspace(0.0, gs.pi, n_phi)

        coords_x = gs.to_numpy(0.5 * gs.outer(gs.sin(coords_phi), gs.cos(coords_theta)))
        coords_y = gs.to_numpy(0.5 * gs.outer(gs.sin(coords_phi), gs.sin(coords_theta)))
        coords_z = gs.to_numpy(
            0.5 * gs.outer(gs.cos(coords_phi), gs.ones_like(coords_theta))
        )

        self.ax.plot_surface(
            coords_x,
            coords_y,
            coords_z,
            rstride=1,
            cstride=1,
            color="grey",
            linewidth=0,
            alpha=0.1,
            zorder=-1,
        )
        self.ax.plot_wireframe(
            coords_x,
            coords_y,
            coords_z,
            linewidths=0.6,
            color="grey",
            alpha=0.6,
            zorder=-1,
        )

        def lim(theta):
            # Upper bound on phi for a triangle at angle theta to be drawn.
            # NOTE(review): looks like a heuristic for the part of the sphere
            # facing the camera, built from elev/azim -- confirm.
            return (
                gs.pi
                - self.elev
                + (2.0 * self.elev - gs.pi) / gs.pi * abs(self.azim - theta)
            )

        # Triangles are drawn on a coarser theta grid than the wireframe.
        for theta in gs.linspace(0.0, 2.0 * gs.pi, n_theta // 2 + 1):
            for phi in gs.linspace(0.0, gs.pi, n_phi):
                if theta <= self.azim + gs.pi and phi <= lim(theta):
                    self.draw_triangle(theta, phi, scale)
                if theta > self.azim + gs.pi and phi < lim(
                    2.0 * self.azim + 2.0 * gs.pi - theta
                ):
                    self.draw_triangle(theta, phi, scale)

    def draw_triangle(self, theta, phi, scale):
        """Draw the corresponding triangle on the sphere at theta, phi.

        Parameters
        ----------
        theta : float
            Angle theta of the location on the sphere.
        phi : float
            Angle phi of the location on the sphere.
        scale : float
            Scaling factor applied to the triangle.
        """
        self._ensure_ax()

        u_theta = gs.cos(theta) * self.ua + gs.sin(theta) * self.na
        triangle = gs.cos(phi / 2.0) * self.pole + gs.sin(phi / 2.0) * u_theta
        triangle = scale * triangle

        # Put the 2D triangle in the plane z = 1/2 (top of the sphere), then
        # rotate it to the location (theta, phi).
        triangle3d = gs.transpose(
            gs.stack((triangle[:, 0], triangle[:, 1], 0.5 * gs.ones(3)))
        )
        triangle3d = self.rotation(theta, phi) @ gs.transpose(triangle3d)

        # Repeat the first vertex to close the outline.
        x = list(triangle3d[0]) + [triangle3d[0, 0]]
        y = list(triangle3d[1]) + [triangle3d[1, 0]]
        z = list(triangle3d[2]) + [triangle3d[2, 0]]

        self.ax.plot3D(x, y, z, "grey", zorder=1)
        c = ["red", "green", "blue"]
        for i in range(3):
            self.ax.scatter(x[i], y[i], z[i], color=c[i], s=10, alpha=1, zorder=1)

    @staticmethod
    def rotation(theta, phi):
        """Rotation sending a triangle at pole to location theta, phi.

        Parameters
        ----------
        theta : float
            Angle of the rotation about the z axis.
        phi : float
            Angle of the rotation about the y axis.

        Returns
        -------
        rot : array-like, shape=[3, 3]
            Product `rot_th @ rot_phi @ rot_th.T`.
        """
        rot_th = gs.array(
            [
                [gs.cos(theta), -gs.sin(theta), 0.0],
                [gs.sin(theta), gs.cos(theta), 0.0],
                [0.0, 0.0, 1.0],
            ]
        )
        rot_phi = gs.array(
            [
                [gs.cos(phi), 0.0, gs.sin(phi)],
                [0.0, 1.0, 0.0],
                [-gs.sin(phi), 0, gs.cos(phi)],
            ]
        )
        return rot_th @ rot_phi @ gs.transpose(rot_th)

    def draw_points(self, alpha=1, zorder=0, **kwargs):
        """Draw points on the Kendall sphere.

        Parameters
        ----------
        alpha : float
            Opacity forwarded to `scatter`.
            Optional, default: 1.
        zorder : int
            Drawing order forwarded to `scatter`.
            Optional, default: 0.
        **kwargs
            Extra keyword arguments forwarded to `scatter`.
        """
        self._ensure_ax()
        coords = [gs.to_numpy(point) for point in self.points]
        points_x = [coord[0] for coord in coords]
        points_y = [coord[1] for coord in coords]
        points_z = [coord[2] for coord in coords]
        self.ax.scatter(
            points_x, points_y, points_z, alpha=alpha, zorder=zorder, **kwargs
        )

    def draw_curve(self, alpha=1, zorder=0, **kwargs):
        """Draw a curve on the Kendall sphere.

        The stored points are joined in the order they were added.

        Parameters
        ----------
        alpha : float
            Opacity forwarded to `plot3D`.
            Optional, default: 1.
        zorder : int
            Drawing order forwarded to `plot3D`.
            Optional, default: 0.
        **kwargs
            Extra keyword arguments forwarded to `plot3D`.
        """
        self._ensure_ax()
        coords = [gs.to_numpy(point) for point in self.points]
        points_x = [coord[0] for coord in coords]
        points_y = [coord[1] for coord in coords]
        points_z = [coord[2] for coord in coords]
        self.ax.plot3D(
            points_x, points_y, points_z, alpha=alpha, zorder=zorder, **kwargs
        )

    def draw_vector(self, tangent_vec, base_point, **kwargs):
        """Draw one vector in the tangent space to sphere at a base point.

        Parameters
        ----------
        tangent_vec : array-like
            Tangent vector at `base_point`.
        base_point : array-like
            Pre-shape of a 2D triangle at which the vector is attached.
        **kwargs
            Extra keyword arguments forwarded to `quiver`.
        """
        self._ensure_ax()

        norm = S32.quotient.metric.norm(tangent_vec, base_point)
        exp = S32.quotient.metric.exp(tangent_vec, base_point)

        bp = self.convert_to_spherical_coordinates(base_point)
        exp = self.convert_to_spherical_coordinates(exp)

        # `bp` lies on the sphere of radius 1/2, so `2 * bp` is its unit
        # normal: subtracting the normal component projects `exp` onto the
        # tangent plane at `bp`.
        tv = exp - gs.dot(exp, 2.0 * bp) * 2.0 * bp
        # The arrow length is the quotient-metric norm of the tangent vector.
        tv = tv / gs.linalg.norm(tv) * norm

        self.ax.quiver(bp[0], bp[1], bp[2], tv[0], tv[1], tv[2], **kwargs)
class KendallDisk:
    """Class used to plot points in Kendall shape space of 3D triangles.

    The shape space of 2D triangles is isometric to the 2-sphere of radius 1/2
    [K1984]_. This isometry induces another isometry between the shape space
    of 3D triangles and the 2D disk of radius pi/4 [LK1993]_. Following the
    first visualization class "KendallSphere" for 2D triangles, this class
    encodes the 2D isometric representation of Kendall shape space of order
    (3,3).

    Parameters
    ----------
    points : array-like, shape=[..., 3, 3]
        Points registered at construction through `add_points`.
        Optional, default: None.
    coords_type : str
        Type of the points. Can be either 'pre-shape' (for points in Kendall
        pre-shape space) or 'extrinsic' (for points given as 3x3 matrices).
        Optional, default: 'pre-shape'.

    Attributes
    ----------
    points : list
        List of points to plot on the Kendall disk, stored as 2D planar
        coordinates.
    coords_type : str
        Type of the points. Can be either 'pre-shape' (for points in Kendall
        pre-shape space) or 'extrinsic' (for points given as 3x3 matrices).
        Optional, default: 'pre-shape'.
    ax : matplotlib axes or None
        Axes used for drawing; None until `set_ax` (or a drawing method)
        has been called.
    pole : array-like, shape=[3,2]
        Equilateral triangle in 2D (north pole).
    centre : array-like, shape=[3,3]
        Equilateral triangle in 3D (centre).
    ua : array-like, shape=[3,2]
        Tangent vector at north pole toward isosceles triangle at vertex A.
    ub : array-like, shape=[3,2]
        Tangent vector at north pole toward isosceles triangle at vertex B.
    na : array-like, shape=[3,2]
        Tangent vector such that (ua,na) is a positively oriented
        orthonormal basis of the horizontal space at north pole.

    References
    ----------
    .. [K1984] David G. Kendall. "Shape Manifolds, Procrustean Metrics, and
       Complex Projective Spaces." Bulletin of the London Mathematical
       Society, Volume 16, Issue 2, March 1984, Pages 81–121.
       https://doi.org/10.1112/blms/16.2.81
    .. [LK1993] Huiling Le and David G. Kendall. "The Riemannian structure of
       Euclidean shape spaces: a novel environment for statistics." Annals of
       statistics, 1993, vol. 21, no 3, p. 1225-1271.
       https://doi.org/10.1214/aos/1176349259
    """

    # Below this length a direction is treated as degenerate (base point at
    # the centre of the disk) in `draw_vector`.
    _CENTRE_TOL = 1e-12

    def __init__(self, points=None, coords_type="pre-shape"):
        self.points = []
        self.coords_type = coords_type
        self.ax = None

        self.pole = gs.array(
            [[1.0, 0.0], [-0.5, gs.sqrt(3.0) / 2.0], [-0.5, -gs.sqrt(3.0) / 2.0]]
        ) / gs.sqrt(3.0)

        # Same equilateral triangle as `pole`, embedded in the plane z = 0.
        self.centre = gs.array(
            [
                [1.0, 0.0, 0.0],
                [-0.5, gs.sqrt(3.0) / 2.0, 0.0],
                [-0.5, -gs.sqrt(3.0) / 2.0, 0.0],
            ]
        ) / gs.sqrt(3.0)

        self.ua = gs.array(
            [[-1.0, 0.0], [0.5, gs.sqrt(3.0) / 2.0], [0.5, -gs.sqrt(3.0) / 2.0]]
        ) / gs.sqrt(3.0)

        self.ub = gs.array(
            [[0.5, gs.sqrt(3.0) / 2.0], [0.5, -gs.sqrt(3.0) / 2], [-1.0, 0.0]]
        ) / gs.sqrt(3.0)

        # Gram-Schmidt step: remove the `ua` component from `ub`, then
        # normalize, so that (ua, na) is orthonormal.
        self.na = self.ub - S32.metric.inner_product(self.ub, self.ua) * self.ua
        self.na = self.na / S32.metric.norm(self.na)

        if points is not None:
            self.add_points(points)

    def _ensure_ax(self):
        """Create the default axes if no axes have been set yet."""
        if self.ax is None:
            self.set_ax()

    @staticmethod
    def _to_unit_sphere(coords_r, coords_theta):
        """Map polar disk coordinates to a point of the unit 2-sphere.

        The radius `r` of the disk is sent to the polar angle `2 * r`.
        """
        return gs.array(
            [
                gs.cos(coords_theta) * gs.sin(2 * coords_r),
                gs.sin(coords_theta) * gs.sin(2 * coords_r),
                gs.cos(2 * coords_r),
            ]
        )

    def set_ax(self, ax=None):
        """Set axis.

        Parameters
        ----------
        ax : matplotlib axes
            Axes to draw on.
            Optional, default: None, in which case `plt.subplot()` is used.
        """
        if ax is None:
            ax = plt.subplot()

        # Disk radius pi/4 plus a small margin.
        ax_s = gs.pi / 4 + 0.05
        plt.setp(ax, xlim=(-ax_s, ax_s), ylim=(-ax_s, ax_s), xlabel="X", ylabel="Y")
        self.ax = ax

    def convert_to_polar_coordinates(self, points):
        """Assign polar coordinates to given pre-shapes.

        Parameters
        ----------
        points : array-like, shape=[..., 3, 3]
            Pre-shapes of 3D triangles.

        Returns
        -------
        coords_r : array-like
            Distance between the pole and the aligned 2D points.
        coords_theta : array-like
            Angle of the log at the pole, measured in the basis (ua, na).
        """
        aligned_points = S33.fiber_bundle.align(points, self.centre)
        # Keep the first two columns of the aligned triangles and work in
        # the pre-shape space of 2D triangles from here on.
        aligned_points2d = aligned_points[..., :, :2]
        speeds = S32.metric.log(aligned_points2d, self.pole)

        coords_r = S32.metric.dist(self.pole, aligned_points2d)
        coords_theta = gs.arctan2(
            S32.metric.inner_product(speeds, self.na),
            S32.metric.inner_product(speeds, self.ua),
        )

        return coords_r, coords_theta

    def convert_to_planar_coordinates(self, points):
        """Convert pre-shapes to planar (x, y) coordinates on the disk.

        Parameters
        ----------
        points : array-like, shape=[..., 3, 3]
            Pre-shapes of 3D triangles.

        Returns
        -------
        planar_coords : array-like, shape=[..., 2]
            Cartesian coordinates on the disk.
        """
        coords_r, coords_theta = self.convert_to_polar_coordinates(points)
        coords_x = coords_r * gs.cos(coords_theta)
        coords_y = coords_r * gs.sin(coords_theta)

        planar_coords = gs.transpose(gs.stack((coords_x, coords_y)))
        return planar_coords

    def add_points(self, points):
        """Add points to draw on the Kendall disk.

        Parameters
        ----------
        points : array-like, shape=[..., 3, 3]
            Points given according to `coords_type`.

        Raises
        ------
        ValueError
            If the points fail the membership test matching `coords_type`.
        """
        if self.coords_type == "extrinsic":
            if not gs.all(M33.belongs(points)):
                raise ValueError("Points do not belong to Matrices(3, 3).")
            points = S33.projection(points)
        elif self.coords_type == "pre-shape" and not gs.all(S33.belongs(points)):
            raise ValueError("Points do not belong to the pre-shape space.")

        points = self.convert_to_planar_coordinates(points)
        if not isinstance(points, list):
            if points.shape == (2,):
                # A single point: wrap it so that `extend` adds one entry.
                points = [gs.array(points)]
            else:
                points = list(points)
        self.points.extend(points)

    def clear_points(self):
        """Clear the points to draw."""
        self.points = []

    def draw(self, n_r=7, n_theta=25, scale=0.05, ax=None):
        """Draw the disk regularly sampled with corresponding triangles.

        Parameters
        ----------
        n_r : int
            Number of samples of the radius.
            Optional, default: 7.
        n_theta : int
            Number of samples of the angle theta for the disk grid.
            Optional, default: 25.
        scale : float
            Scaling factor applied to the drawn triangles.
            Optional, default: 0.05.
        ax : matplotlib axes
            Axes to draw on.
            Optional, default: None, in which case default axes are
            requested from pyplot, as `set_ax()` does.
        """
        self.set_ax(ax)
        self.ax.set_axis_off()
        plt.tight_layout()

        coords_r = gs.linspace(0.0, gs.pi / 4.0, n_r)
        coords_theta = gs.linspace(0.0, 2.0 * gs.pi, n_theta)

        coords_x = gs.to_numpy(gs.outer(coords_r, gs.cos(coords_theta)))
        coords_y = gs.to_numpy(gs.outer(coords_r, gs.sin(coords_theta)))

        # Filled outer circle as background.
        self.ax.fill(
            list(coords_x[-1, :]),
            list(coords_y[-1, :]),
            color="grey",
            alpha=0.1,
            zorder=-1,
        )

        # Concentric circles of the polar grid.
        for i_r in range(n_r):
            self.ax.plot(
                coords_x[i_r, :],
                coords_y[i_r, :],
                linewidth=0.6,
                color="grey",
                alpha=0.6,
                zorder=-1,
            )

        # Rays of the polar grid.
        for i_t in range(n_theta):
            self.ax.plot(
                coords_x[:, i_t],
                coords_y[:, i_t],
                linewidth=0.6,
                color="grey",
                alpha=0.6,
                zorder=-1,
            )

        # Triangles are drawn on a coarser theta grid than the polar grid.
        # The centre triangle is drawn once; the first angle (0) is skipped
        # because the last angle (2 * pi) covers the same ray, and the first
        # radius (0) is skipped because it is the centre for every angle.
        triangle_thetas = gs.linspace(0.0, 2.0 * gs.pi, n_theta // 2 + 1)
        self.draw_triangle(0.0, 0.0, scale)
        for r in coords_r[1:]:
            for theta in triangle_thetas[1:]:
                self.draw_triangle(r, theta, scale)

    def draw_triangle(self, r, theta, scale):
        """Draw the corresponding triangle on the disk at r, theta.

        Parameters
        ----------
        r : float
            Radius of the location on the disk.
        theta : float
            Angle of the location on the disk.
        scale : float
            Scaling factor applied to the triangle.
        """
        self._ensure_ax()

        u_theta = gs.cos(theta) * self.ua + gs.sin(theta) * self.na
        triangle = gs.cos(r) * self.pole + gs.sin(r) * u_theta
        triangle = scale * triangle

        # Translate the triangle to (r, theta) and repeat the first vertex
        # to close the outline.
        x = list(r * gs.cos(theta) + triangle[:, 0])
        x = x + [x[0]]
        y = list(r * gs.sin(theta) + triangle[:, 1])
        y = y + [y[0]]

        self.ax.plot(x, y, "grey", zorder=1)
        c = ["red", "green", "blue"]
        for i in range(3):
            self.ax.scatter(x[i], y[i], color=c[i], s=10, alpha=1, zorder=1)

    def draw_points(self, alpha=1, zorder=0, **kwargs):
        """Draw points on the Kendall disk.

        Parameters
        ----------
        alpha : float
            Opacity forwarded to `scatter`.
            Optional, default: 1.
        zorder : int
            Drawing order forwarded to `scatter`.
            Optional, default: 0.
        **kwargs
            Extra keyword arguments forwarded to `scatter`.
        """
        self._ensure_ax()
        coords = [gs.to_numpy(point) for point in self.points]
        points_x = [coord[0] for coord in coords]
        points_y = [coord[1] for coord in coords]
        self.ax.scatter(points_x, points_y, alpha=alpha, zorder=zorder, **kwargs)

    def draw_curve(self, alpha=1, zorder=0, **kwargs):
        """Draw a curve on the Kendall disk.

        The stored points are joined in the order they were added.

        Parameters
        ----------
        alpha : float
            Opacity forwarded to `plot`.
            Optional, default: 1.
        zorder : int
            Drawing order forwarded to `plot`.
            Optional, default: 0.
        **kwargs
            Extra keyword arguments forwarded to `plot`.
        """
        self._ensure_ax()
        coords = [gs.to_numpy(point) for point in self.points]
        points_x = [coord[0] for coord in coords]
        points_y = [coord[1] for coord in coords]
        self.ax.plot(points_x, points_y, alpha=alpha, zorder=zorder, **kwargs)

    def draw_vector(self, tangent_vec, base_point, tol=1e-03, **kwargs):
        """Draw one vector in the tangent space to disk at a base point.

        The direction is obtained by shooting a geodesic of length `tol`
        from `base_point`, lifting both points to the unit sphere, and
        expressing the resulting tangent direction in the radial/angular
        frame, which is then transported to the disk.

        Parameters
        ----------
        tangent_vec : array-like
            Tangent vector at `base_point`.
        base_point : array-like
            Pre-shape of a 3D triangle at which the vector is attached.
        tol : float
            Length of the geodesic step used to estimate the direction.
            Optional, default: 1e-3.
        **kwargs
            Extra keyword arguments forwarded to `quiver`.
        """
        self._ensure_ax()

        norm = S33.quotient.metric.norm(tangent_vec, base_point)

        r_bp, th_bp = self.convert_to_polar_coordinates(base_point)
        bp_sphere = self._to_unit_sphere(r_bp, th_bp)

        r_exp, th_exp = self.convert_to_polar_coordinates(
            S33.quotient.metric.exp(tol * tangent_vec / norm, base_point)
        )
        exp_sphere = self._to_unit_sphere(r_exp, th_exp)

        # Unit tangent direction at `bp_sphere` pointing toward `exp_sphere`.
        tv = exp_sphere - gs.dot(exp_sphere, bp_sphere) * bp_sphere
        u_tv = tv / gs.linalg.norm(tv)

        pole = gs.array([0.0, 0.0, 1.0])
        radial = gs.dot(pole, bp_sphere) * bp_sphere - pole
        radial_norm = gs.linalg.norm(radial)

        bp = self.convert_to_planar_coordinates(base_point)
        bp_norm = gs.linalg.norm(bp)

        if radial_norm < self._CENTRE_TOL or bp_norm < self._CENTRE_TOL:
            # Base point at the centre of the disk: the radial direction is
            # undefined (0 / 0). The tangent plane at the pole of the sphere
            # is the horizontal plane, so the planar direction is given by
            # the first two components of `u_tv`; this is also the limit of
            # the general formula below when the radius goes to zero.
            direction = u_tv[:2]
        else:
            # Radial/angular frame on the sphere and components of `u_tv`.
            u_r = radial / radial_norm
            u_th = gs.cross(bp_sphere, u_r)
            x_r, x_th = gs.dot(u_tv, u_r), gs.dot(u_tv, u_th)

            # Matching radial/angular frame on the disk.
            u_r = bp / bp_norm
            u_th = gs.array([[0.0, -1.0], [1.0, 0.0]]) @ u_r
            direction = x_r * u_r + x_th * u_th

        # The arrow length is the quotient-metric norm of the tangent vector.
        tv = norm * direction

        self.ax.quiver(bp[0], bp[1], tv[0], tv[1], **kwargs)