"""Visualization for Geometric Statistics."""
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # NOQA
import geomstats.backend as gs
from geomstats.geometry.hypersphere import Hypersphere
# Manifold used by Circle.add_points to check that points lie on the circle.
S1 = Hypersphere(dim=1)
# Manifold used by Sphere.add_points to check that points lie on the sphere.
S2 = Hypersphere(dim=2)
# Half-width of the plot window: set_ax uses (-AX_SCALE, AX_SCALE) on every axis.
AX_SCALE = 1.2
class Circle:
    """Draw the unit circle together with points lying on it.

    Parameters
    ----------
    n_angles : int
        Number of segments of the circle outline; the outline is sampled
        at ``n_angles + 1`` angles from 0 to 2 * pi inclusive, so that the
        drawn curve is closed.
    points : array-like, optional
        Initial points, validated and stored through `add_points`.

    Attributes
    ----------
    circle_x, circle_y : array-like
        Coordinates of the sampled circle outline.
    points : list
        Points accumulated so far, one entry per point.
    """

    def __init__(self, n_angles=100, points=None):
        angles = gs.linspace(0.0, 2 * gs.pi, n_angles + 1)
        self.circle_x = gs.cos(angles)
        self.circle_y = gs.sin(angles)
        self.points = []
        if points is not None:
            self.add_points(points)

    @staticmethod
    def set_ax(ax=None):
        """Set axis limits and labels, creating the axis when needed.

        Parameters
        ----------
        ax : matplotlib axis, optional
            Axis to configure. When None, one is created with
            ``plt.subplot()``.

        Returns
        -------
        ax : matplotlib axis
            The configured axis.
        """
        if ax is None:
            ax = plt.subplot()
        plt.setp(
            ax,
            xlim=(-AX_SCALE, AX_SCALE),
            ylim=(-AX_SCALE, AX_SCALE),
            xlabel="X",
            ylabel="Y",
        )
        return ax

    def add_points(self, points):
        """Validate points and append them to the stored points.

        Parameters
        ----------
        points : array-like
            Either several points (iterable of 2-coordinate points) or a
            single point given as a 1-D array.

        Raises
        ------
        ValueError
            If ``S1.belongs`` rejects at least one of the points.
        """
        if not gs.all(S1.belongs(points)):
            raise ValueError("Points do not belong to the circle.")
        if not isinstance(points, list):
            # A single point given as a 1-D array must be stored whole:
            # list() would split it into its scalar coordinates.
            if getattr(points, "ndim", None) == 1:
                points = [points]
            else:
                points = list(points)
        self.points.extend(points)

    def draw(self, ax, **plot_kwargs):
        """Plot the circle outline and, if any, the stored points.

        Parameters
        ----------
        ax : matplotlib axis
            Axis to draw on.
        **plot_kwargs
            Forwarded to `draw_points`.
        """
        ax.plot(self.circle_x, self.circle_y, color="black")
        if self.points:
            self.draw_points(ax, **plot_kwargs)

    def draw_points(self, ax, points=None, **plot_kwargs):
        """Plot points as unconnected markers.

        Parameters
        ----------
        ax : matplotlib axis
            Axis to draw on.
        points : array-like, optional
            Points to draw, indexed as ``points[:, 0]`` and ``points[:, 1]``
            after conversion to an array. Defaults to the stored points.
        **plot_kwargs
            Forwarded to ``ax.plot``; ``marker`` and ``linestyle`` default
            to ``"o"`` and ``"None"`` but may be overridden.
        """
        if points is None:
            points = self.points
        if len(points) == 0:
            # Nothing to draw; an empty array cannot be indexed by column.
            return
        points = gs.array(points)
        # Caller-supplied kwargs win over the defaults instead of clashing
        # with them as duplicate keyword arguments.
        style = {"marker": "o", "linestyle": "None", **plot_kwargs}
        ax.plot(points[:, 0], points[:, 1], **style)

    def plot(self, points, ax=None, **point_draw_kwargs):
        """Add points to the circle and draw everything.

        The given points are appended to the points already stored, so
        previously added points are drawn as well.

        Parameters
        ----------
        points : array-like
            Points to add, validated through `add_points`.
        ax : matplotlib axis, optional
            Axis to draw on; created by `set_ax` when None.
        **point_draw_kwargs
            Forwarded to `draw`.
        """
        ax = self.set_ax(ax=ax)
        self.add_points(points)
        self.draw(ax, **point_draw_kwargs)
class Sphere:
    """Draw a wireframe sphere together with points lying on it.

    Create the arrays sphere_x, sphere_y, sphere_z of values used to plot
    the wireframe of a sphere of radius 1 centered at the origin. The grid
    is sampled at ``n_meridians + 1`` longitudes in [0, 2 * pi] and
    ``n_circles_latitude + 1`` colatitudes in [0, pi]; with NumPy-style
    ``meshgrid`` indexing the arrays presumably have shape
    (n_circles_latitude + 1, n_meridians + 1).

    Parameters
    ----------
    n_meridians : int
        Number of longitude intervals of the wireframe.
    n_circles_latitude : int, optional
        Number of latitude intervals. Defaults to
        ``max(n_meridians // 2, 4)``.
    points : array-like, optional
        Initial points, validated and stored through `add_points`.
    """

    def __init__(self, n_meridians=40, n_circles_latitude=None, points=None):
        if n_circles_latitude is None:
            n_circles_latitude = max(n_meridians // 2, 4)
        u, v = gs.meshgrid(
            gs.linspace(0.0, 2 * gs.pi, n_meridians + 1),
            gs.linspace(0.0, gs.pi, n_circles_latitude + 1),
        )
        self.center = gs.zeros(3)
        self.radius = 1
        self.sphere_x = self.center[0] + self.radius * gs.cos(u) * gs.sin(v)
        self.sphere_y = self.center[1] + self.radius * gs.sin(u) * gs.sin(v)
        self.sphere_z = self.center[2] + self.radius * gs.cos(v)
        self.points = []
        if points is not None:
            self.add_points(points)

    @staticmethod
    def set_ax(ax=None):
        """Set 3D axis limits, labels and aspect, creating the axis if needed.

        Parameters
        ----------
        ax : matplotlib 3D axis, optional
            Axis to configure. When None, one is created with
            ``plt.subplot(111, projection="3d")``.

        Returns
        -------
        ax : matplotlib 3D axis
            The configured axis.
        """
        if ax is None:
            ax = plt.subplot(111, projection="3d")
        plt.setp(
            ax,
            xlim=(-AX_SCALE, AX_SCALE),
            ylim=(-AX_SCALE, AX_SCALE),
            zlim=(-AX_SCALE, AX_SCALE),
            xlabel="X",
            ylabel="Y",
            zlabel="Z",
        )
        ax.set_box_aspect([1.0, 1.0, 1.0])
        return ax

    def add_points(self, points):
        """Validate points and append them to the stored points.

        Parameters
        ----------
        points : array-like
            Either several points (iterable of 3-coordinate points) or a
            single point given as a 1-D array.

        Raises
        ------
        ValueError
            If ``S2.belongs`` rejects at least one of the points.
        """
        if not gs.all(S2.belongs(points)):
            raise ValueError("Points do not belong to the sphere.")
        if not isinstance(points, list):
            # A single point given as a 1-D array must be stored whole:
            # list() would split it into its scalar coordinates.
            if getattr(points, "ndim", None) == 1:
                points = [points]
            else:
                points = list(points)
        self.points.extend(points)

    def draw(self, ax, **scatter_kwargs):
        """Plot the sphere wireframe and, if any, the stored points.

        Parameters
        ----------
        ax : matplotlib 3D axis
            Axis to draw on.
        **scatter_kwargs
            Forwarded to `draw_points`.
        """
        ax.plot_wireframe(
            self.sphere_x, self.sphere_y, self.sphere_z, color="grey", alpha=0.2
        )
        ax.set_box_aspect([1.0, 1.0, 1.0])
        if self.points:
            self.draw_points(ax, **scatter_kwargs)

    def draw_points(self, ax, points=None, **scatter_kwargs):
        """Scatter points and optionally annotate each of them.

        Parameters
        ----------
        ax : matplotlib 3D axis
            Axis to draw on.
        points : iterable of points, optional
            Points to draw, each indexable by 0, 1, 2. Defaults to the
            stored points.
        **scatter_kwargs
            Forwarded to ``ax.scatter``. When ``label`` is a non-string
            sequence with one entry per point, each entry is also written
            next to its point.
        """
        if points is None:
            points = self.points
        points = [gs.to_numpy(point) for point in points]
        points_x = [point[0] for point in points]
        points_y = [point[1] for point in points]
        points_z = [point[2] for point in points]
        ax.scatter(points_x, points_y, points_z, **scatter_kwargs)

        # Decide once whether labels are per-point. A plain string is a
        # single legend label, not one character per point, and a value
        # without a length (e.g. None) cannot be per-point either.
        labels = scatter_kwargs.get("label")
        has_point_labels = (
            labels is not None
            and not isinstance(labels, str)
            and hasattr(labels, "__len__")
            and len(labels) == len(points)
        )
        if not has_point_labels:
            return
        for point, text in zip(points, labels):
            ax.text(
                point[0],
                point[1],
                point[2],
                text,
                size=10,
                zorder=1,
                color="k",
            )

    def get_fibonnaci_points(self, n_points=16000):
        """Get a spherical Fibonacci point set.

        Such point sets yield nearly uniform point distributions on the
        sphere. Coordinates are scaled by ``self.radius``.

        Parameters
        ----------
        n_points : int
            Number of points to generate.

        Returns
        -------
        points : array-like, shape=[3, n_points]
            Rows hold the x, y and z coordinates. The array is empty
            (shape [3, 0]) when ``n_points`` is zero or negative.
        """
        if n_points <= 0:
            # Avoid dividing by zero below; negative counts already
            # produced this same empty result.
            return gs.array([[], [], []])

        offset = 2.0 / n_points
        # Golden angle: successive points are rotated by this amount.
        increment = gs.pi * (3.0 - gs.sqrt(5.0))

        x_vals, y_vals, z_vals = [], [], []
        for i in range(n_points):
            # Heights are the midpoints of n_points equal slices of [-1, 1].
            y = ((i * offset) - 1) + (offset / 2)
            r = gs.sqrt(1 - y**2)
            phi = ((i + 1) % n_points) * increment
            x_vals.append(self.radius * (gs.cos(phi) * r))
            y_vals.append(self.radius * y)
            z_vals.append(self.radius * (gs.sin(phi) * r))
        return gs.array([x_vals, y_vals, z_vals])

    def plot_heatmap(self, ax, scalar_function, n_points=16000, alpha=0.2, cmap="jet"):
        """Plot a heatmap of a scalar function on the sphere.

        Parameters
        ----------
        ax : matplotlib 3D axis
            Axis to draw on.
        scalar_function : callable
            Called once per Fibonacci point (a 3-coordinate array); its
            value sets the color of that point.
        n_points : int
            Number of Fibonacci points sampled on the sphere.
        alpha : float
            Opacity passed to ``ax.scatter``.
        cmap : str
            Name of the matplotlib colormap.
        """
        points = self.get_fibonnaci_points(n_points)
        # points has one column per point, so iterate over the transpose.
        intensity = gs.array([scalar_function(x) for x in points.T])
        ax.scatter(
            points[0, :],
            points[1, :],
            points[2, :],
            c=intensity,
            alpha=alpha,
            marker=".",
            cmap=plt.get_cmap(cmap),
        )

    def plot(self, points, ax=None, **point_draw_kwargs):
        """Replace the stored points with the given ones and draw everything.

        Parameters
        ----------
        points : array-like
            Points to draw, validated through `add_points`. Previously
            stored points are discarded first.
        ax : matplotlib 3D axis, optional
            Axis to draw on; created by `set_ax` when None.
        **point_draw_kwargs
            Forwarded to `draw`.
        """
        ax = self.set_ax(ax=ax)
        self.points = []
        self.add_points(points)
        self.draw(ax, **point_draw_kwargs)