"""The Visualization Package."""
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # NOQA
import geomstats.backend as gs
from geomstats.visualization.hyperbolic import (
KleinDisk,
PoincareDisk,
PoincareHalfPlane,
)
from geomstats.visualization.hypersphere import Circle, Sphere
from geomstats.visualization.poincare_polydisk import PoincarePolyDisk
from geomstats.visualization.pre_shape import KendallDisk, KendallSphere
from geomstats.visualization.spd_matrices import Ellipses
from geomstats.visualization.special_euclidean import SpecialEuclidean2
from geomstats.visualization.special_orthogonal import (
Arrow3D,
Trihedron,
convert_to_trihedron,
)
AX_SCALE = 1.2
IMPLEMENTED = [
"SO3_GROUP",
"SE3_GROUP",
"SE2_GROUP",
"S1",
"S2",
"H2_poincare_disk",
"H2_poincare_half_plane",
"H2_klein_disk",
"poincare_polydisk",
"S32",
"M32",
"S33",
"M33",
"SPD2",
]
[docs]
def tutorial_matplotlib():
"""Configure style for matplotlib tutorial."""
fontsize = 12
matplotlib.rc("font", size=fontsize)
matplotlib.rc("text")
matplotlib.rc("legend", fontsize=fontsize)
matplotlib.rc("axes", titlesize=21, labelsize=14)
matplotlib.rc(
"font",
family="times",
serif=["Computer Modern Roman"],
monospace=["Computer Modern Typewriter"],
)
[docs]
def plot(points, ax=None, space=None, coords_type=None, **point_draw_kwargs):
"""Plot points in one of the implemented manifolds.
The implemented manifolds are:
- the special orthogonal group SO(3)
- the special Euclidean group SE(3)
- the circle S1 and the sphere S2
- the hyperbolic plane (the Poincare disk, the Poincare
half plane and the Klein disk)
- the Poincare polydisk
- the Kendall shape space of 2D triangles
- the Kendall shape space of 3D triangles
Parameters
----------
points : array-like, shape=[..., dim]
Points to be plotted.
space: str, optional, {'SO3_GROUP', 'SE3_GROUP', 'S1', 'S2',
'H2_poincare_disk', 'H2_poincare_half_plane', 'H2_klein_disk',
'poincare_polydisk', 'S32', 'M32', 'S33', 'M33', 'SPD2'}
coords_type: str, optional, {'extrinsic', 'ball', 'half-space', 'pre-shape'}
"""
if space not in IMPLEMENTED:
raise NotImplementedError(
"The plot function is not implemented"
" for space {}. The spaces available for visualization"
" are: {}.".format(space, IMPLEMENTED)
)
if points is None:
raise ValueError("No points given for plotting.")
if points.ndim < 2:
points = gs.expand_dims(points, 0)
if space in ("SO3_GROUP", "SE3_GROUP"):
if ax is None:
ax = plt.subplot(111, projection="3d")
if space == "SE3_GROUP":
ax_s = AX_SCALE * gs.amax(gs.abs(points[:, 3:6]))
elif space == "SO3_GROUP":
ax_s = AX_SCALE * gs.amax(gs.abs(points[:, :3]))
ax_s = float(ax_s)
bounds = (-ax_s, ax_s)
plt.setp(
ax,
xlim=bounds,
ylim=bounds,
zlim=bounds,
xlabel="X",
ylabel="Y",
zlabel="Z",
)
trihedrons = convert_to_trihedron(points, space=space)
for t in trihedrons:
t.draw(ax, **point_draw_kwargs)
elif space == "S1":
circle = Circle()
ax = circle.set_ax(ax=ax)
circle.add_points(points)
circle.draw(ax, **point_draw_kwargs)
elif space == "S2":
sphere = Sphere()
ax = sphere.set_ax(ax=ax)
sphere.add_points(points)
sphere.draw(ax, **point_draw_kwargs)
elif space == "H2_poincare_disk":
if coords_type is None:
coords_type = "extrinsic"
poincare_disk = PoincareDisk(coords_type=coords_type)
ax = poincare_disk.set_ax(ax=ax)
poincare_disk.add_points(points)
poincare_disk.draw(ax, **point_draw_kwargs)
plt.axis("off")
elif space == "poincare_polydisk":
if coords_type is None:
coords_type = "extrinsic"
n_disks = points.shape[1]
poincare_poly_disk = PoincarePolyDisk(coords_type=coords_type, n_disks=n_disks)
n_columns = int(gs.ceil(n_disks**0.5))
n_rows = int(gs.ceil(n_disks / n_columns))
axis_list = []
for i_disk in range(n_disks):
axis_list.append(ax.add_subplot(n_rows, n_columns, i_disk + 1))
for i_disk, one_ax in enumerate(axis_list):
ax = poincare_poly_disk.set_ax(ax=one_ax)
poincare_poly_disk.clear_points()
poincare_poly_disk.add_points(points[:, i_disk, ...])
poincare_poly_disk.draw(ax, **point_draw_kwargs)
elif space == "H2_poincare_half_plane":
if coords_type is None:
coords_type = "half-space"
poincare_half_plane = PoincareHalfPlane(coords_type=coords_type)
ax = poincare_half_plane.set_ax(ax=ax)
poincare_half_plane.add_points(points)
poincare_half_plane.draw(ax, **point_draw_kwargs)
elif space == "H2_klein_disk":
klein_disk = KleinDisk()
ax = klein_disk.set_ax(ax=ax)
klein_disk.add_points(points)
klein_disk.draw(ax, **point_draw_kwargs)
elif space == "SE2_GROUP":
plane = SpecialEuclidean2()
ax = plane.set_ax(ax=ax)
plane.add_points(points)
plane.draw_points(ax, **point_draw_kwargs)
elif space == "S32":
sphere = KendallSphere()
sphere.add_points(points)
sphere.draw()
sphere.draw_points()
ax = sphere.ax
elif space == "M32":
sphere = KendallSphere(coords_type="extrinsic")
sphere.add_points(points)
sphere.draw()
sphere.draw_points()
ax = sphere.ax
elif space == "S33":
disk = KendallDisk()
disk.add_points(points)
disk.draw()
disk.draw_points()
ax = disk.ax
elif space == "M33":
disk = KendallDisk(coords_type="extrinsic")
disk.add_points(points)
disk.draw()
disk.draw_points()
ax = disk.ax
elif space == "SPD2":
ellipses = Ellipses()
ellipses.draw_points(points=points)
return ax