Source code for geomstats.visualization.poincare_polydisk

"""Visualization for Geometric Statistics."""

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # NOQA

import geomstats.backend as gs

# Half-width of the square plotting window configured by
# PoincarePolyDisk.set_ax (both axes span [-AX_SCALE, AX_SCALE]). Being
# larger than 1.0, it leaves a margin around the unit circle drawn by
# PoincarePolyDisk.draw.
AX_SCALE: float = 1.2


class PoincarePolyDisk:
    """Class used to plot points in the Poincare polydisk.

    Points are stored as 2D coordinates and drawn, together with the unit
    circle, on a matplotlib axes.

    Parameters
    ----------
    points : array-like, optional
        Initial points to draw; forwarded to :meth:`add_points`.
    coords_type : str, optional
        Coordinate system of the points given to :meth:`add_points`.
        ``"extrinsic"`` points are converted with
        :meth:`convert_to_poincare_coordinates`; any other value means the
        points are stored unchanged.
        Optional, default: ``"ball"``.
    n_disks : int, optional
        Number of disks of the polydisk. Stored on the instance; not read by
        the methods of this class.
        Optional, default: 2.
    """

    def __init__(self, points=None, coords_type="ball", n_disks=2):
        self.center = gs.array([0.0, 0.0])
        self.points = []
        self.coords_type = coords_type
        self.n_disks = n_disks
        if points is not None:
            self.add_points(points)

    @staticmethod
    def set_ax(ax=None):
        """Define the ax parameters.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axes to configure. If None, a new one is obtained from
            ``plt.subplot()``.

        Returns
        -------
        ax : matplotlib.axes.Axes
            The axes, with x and y limits set to ``(-AX_SCALE, AX_SCALE)``
            and axis labels "X" and "Y".
        """
        if ax is None:
            ax = plt.subplot()
        ax_s = AX_SCALE
        plt.setp(
            ax,
            xlim=(-ax_s, ax_s),
            ylim=(-ax_s, ax_s),
            xlabel="X",
            ylabel="Y",
        )
        return ax

    def add_points(self, points):
        """Add points to draw.

        Parameters
        ----------
        points : array-like or list
            Either a single point (1D array), a batch of points (2D array,
            one point per row) or a list of points. If ``coords_type`` is
            ``"extrinsic"``, they are first converted with
            :meth:`convert_to_poincare_coordinates`.
        """
        if self.coords_type == "extrinsic":
            points = self.convert_to_poincare_coordinates(points)
        if getattr(points, "ndim", None) == 1:
            # A single point: keep it whole. Iterating it would store its
            # scalar coordinates as separate "points", which draw cannot
            # index.
            points = [points]
        elif not isinstance(points, list):
            points = list(points)
        self.points.extend(points)

    def clear_points(self):
        """Clear the points to draw."""
        self.points = []

    @staticmethod
    def convert_to_poincare_coordinates(points):
        """Convert points to poincare coordinates.

        Maps extrinsic coordinates ``(x_0, x_1, ..., x_d)`` to
        ``(x_1, ..., x_d) / (1 + x_0)`` (the hyperboloid-to-ball projection).
        The first coordinate is taken along the last axis, so a single point
        of shape ``(d + 1,)`` and batches of shape ``(..., d + 1)`` are both
        supported.

        Parameters
        ----------
        points : array-like, shape=[..., d + 1]
            Points in extrinsic coordinates.

        Returns
        -------
        poincare_coords : array-like, shape=[..., d]
            Points in Poincare ball coordinates.
        """
        poincare_coords = points[..., 1:] / (1 + points[..., :1])
        return poincare_coords

    def draw(self, ax, **kwargs):
        """Draw the unit circle and the stored points.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes to draw on.
        **kwargs
            Keyword arguments forwarded to ``ax.scatter``.
        """
        circle = plt.Circle((0, 0), radius=1.0, color="black", fill=False)
        ax.add_artist(circle)
        # Only the first two coordinates of each point are plotted.
        points_x = [gs.to_numpy(point[0]) for point in self.points]
        points_y = [gs.to_numpy(point[1]) for point in self.points]
        ax.scatter(points_x, points_y, **kwargs)