# Source code for geomstats.visualization.special_euclidean

"""Visualization for Geometric Statistics."""

import logging

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # NOQA

import geomstats.backend as gs
from geomstats.geometry.special_euclidean import SpecialEuclidean

# Shared SE(2) group instances used by the plotting class below: the matrix
# representation is used for membership checks, the vector representation
# for converting vector inputs into matrices.
SE2_GROUP = SpecialEuclidean(point_type="matrix", n=2)
SE2_VECT = SpecialEuclidean(point_type="vector", n=2)


class SpecialEuclidean2:
    """Class used to plot points in the 2d special euclidean group.

    Points are stored as SE(2) matrices. When drawn, the first two columns
    (restricted to the first two rows) are shown as the rotation frame and
    the first two entries of the last column as the translation.

    Parameters
    ----------
    points : array-like, optional
        Initial points, forwarded to ``add_points``.
    point_type : str, optional
        ``"matrix"`` (default) or ``"vector"``. With ``"vector"``, points
        given to ``add_points`` are converted to matrices before storage.
    """

    def __init__(self, points=None, point_type="matrix"):
        self.points = []
        self.point_type = point_type
        if points is not None:
            self.add_points(points)

    @staticmethod
    def set_ax(ax=None, x_lim=None, y_lim=None):
        """Return a matplotlib axis, creating it and setting limits if asked.

        Parameters
        ----------
        ax : matplotlib axis, optional
            Axis to configure. A new one is created with ``plt.subplot()``
            when omitted.
        x_lim, y_lim : pair of floats, optional
            Limits passed to ``ax.set_xlim`` / ``ax.set_ylim`` when given.

        Returns
        -------
        ax : matplotlib axis
            The (possibly newly created) configured axis.
        """
        if ax is None:
            ax = plt.subplot()
        if x_lim is not None:
            ax.set_xlim(x_lim)
        if y_lim is not None:
            ax.set_ylim(y_lim)
        return ax

    def add_points(self, points):
        """Append points to the collection to be plotted.

        Parameters
        ----------
        points : array-like
            A single point or a batch of points, in the representation given
            by ``self.point_type``. Vector points are converted to matrices
            first. A warning is logged if any point does not belong to SE(2);
            the points are stored regardless.
        """
        if self.point_type == "vector":
            points = SE2_VECT.matrix_from_vector(points)
        if not gs.all(SE2_GROUP.belongs(points)):
            logging.warning("Some points do not belong to SE2.")
        if isinstance(points, (list, tuple)):
            points = list(points)
        elif len(points.shape) == 2:
            # A single matrix: iterating over it would store its rows as
            # separate "points", so wrap it as a batch of one instead.
            points = [points]
        else:
            points = list(points)
        self.points.extend(points)

    def draw_points(self, ax, points=None, **kwargs):
        """Draw points as frames anchored at their translations.

        For each point, two arrows start at its translation: the first
        column of its rotation block in blue and the second in red. A
        scatter marker is drawn at the translation itself.

        Parameters
        ----------
        ax : matplotlib axis
            Axis to draw on.
        points : array-like, shape=[..., 3, 3], optional
            SE(2) matrices to draw. A single matrix or a batch with any
            number of leading dimensions is accepted.
            Defaults to the points stored in ``self.points``.
        **kwargs
            Forwarded to ``ax.scatter``.
        """
        if points is None:
            if not self.points:
                # Nothing stored: gs.array([]) cannot be indexed as matrices.
                return
            points = gs.array(self.points)
        # Flatten any batch shape (including a single matrix) to a flat
        # batch so that the indexing below applies uniformly.
        matrix_shape = tuple(points.shape[-2:])
        points = gs.reshape(points, (-1,) + matrix_shape)

        translation = points[:, :2, 2]
        frame_1 = points[:, :2, 0]
        frame_2 = points[:, :2, 1]
        ax.quiver(
            translation[:, 0],
            translation[:, 1],
            frame_1[:, 0],
            frame_1[:, 1],
            width=0.005,
            color="b",
        )
        ax.quiver(
            translation[:, 0],
            translation[:, 1],
            frame_2[:, 0],
            frame_2[:, 1],
            width=0.005,
            color="r",
        )
        ax.scatter(translation[:, 0], translation[:, 1], s=16, **kwargs)