_q\n",
"\n",
" where :math:`<\\cdot, \\cdot>_q` is the cometric at :math:`q`.\n",
"\n",
" Parameters\n",
" ----------\n",
" state : tuple of arrays\n",
" Position and momentum variables. The position is a point on the\n",
" manifold, while the momentum is a cotangent vector.\n",
"\n",
" Returns\n",
" -------\n",
" energy : float\n",
" Hamiltonian energy at `state`.\n",
" \"\"\"\n",
" position, momentum = state\n",
" return 1.0 / 2 * self.inner_coproduct(momentum, momentum, position)\n",
"\n",
" def squared_norm(self, vector, base_point=None):\n",
" \"\"\"Compute the square of the norm of a vector.\n",
"\n",
" Squared norm of a vector associated with the inner product\n",
" on the tangent space at a base point.\n",
"\n",
" Parameters\n",
" ----------\n",
" vector : array-like, shape=[..., dim]\n",
" Vector.\n",
" base_point : array-like, shape=[..., dim]\n",
" Base point.\n",
" Optional, default: None.\n",
"\n",
" Returns\n",
" -------\n",
" sq_norm : array-like, shape=[...,]\n",
" Squared norm.\n",
" \"\"\"\n",
" return self.inner_product(vector, vector, base_point)\n",
"\n",
" def norm(self, vector, base_point=None):\n",
" \"\"\"Compute norm of a vector.\n",
"\n",
" Norm of a vector associated with the inner product\n",
" on the tangent space at a base point.\n",
"\n",
" Note: This only works for positive-definite\n",
" Riemannian metrics and inner products.\n",
"\n",
" Parameters\n",
" ----------\n",
" vector : array-like, shape=[..., dim]\n",
" Vector.\n",
" base_point : array-like, shape=[..., dim]\n",
" Base point.\n",
" Optional, default: None.\n",
"\n",
" Returns\n",
" -------\n",
" norm : array-like, shape=[...,]\n",
" Norm.\n",
" \"\"\"\n",
" sq_norm = self.squared_norm(vector, base_point)\n",
" return gs.sqrt(sq_norm)\n",
"\n",
" def normalize(self, vector, base_point):\n",
" \"\"\"Normalize tangent vector at a given point.\n",
"\n",
" Parameters\n",
" ----------\n",
" vector : array-like, shape=[..., dim]\n",
" Tangent vector at base_point.\n",
" base_point : array-like, shape=[..., dim]\n",
" Point.\n",
"\n",
" Returns\n",
" -------\n",
" normalized_vector : array-like, shape=[..., dim]\n",
" Unit tangent vector at base_point.\n",
" \"\"\"\n",
" norm = self.norm(vector, base_point)\n",
" norm = gs.where(norm == 0, gs.ones(norm.shape), norm)\n",
" indices = \"ijk\"[: self._space.point_ndim]\n",
" return gs.einsum(f\"...{indices},...->...{indices}\", vector, 1 / norm)\n",
"\n",
" def random_unit_tangent_vec(self, base_point, n_vectors=1):\n",
" \"\"\"Generate a random unit tangent vector at a given point.\n",
"\n",
" Parameters\n",
" ----------\n",
" base_point : array-like, shape=[..., dim]\n",
" Point.\n",
" n_vectors : int\n",
" Number of vectors to be generated at base_point.\n",
" For vectorization purposes, n_vectors can be greater than 1 only\n",
" if base_point consists of a single point.\n",
" Optional, default: 1.\n",
"\n",
" Returns\n",
" -------\n",
" normalized_vector : array-like, shape=[..., dim] or [n_vectors, dim]\n",
" Random unit tangent vector(s) at base_point.\n",
" \"\"\"\n",
" is_batch = check_is_batch(self._space.point_ndim, base_point)\n",
" if is_batch and n_vectors > 1:\n",
" raise ValueError(\n",
" \"Several tangent vectors is only applicable to a single base point.\"\n",
" )\n",
" point_shape = self._space.shape\n",
" vec_shape = (n_vectors, *point_shape) if n_vectors > 1 else point_shape\n",
" random_vector = gs.random.rand(*vec_shape)\n",
" return self.normalize(random_vector, base_point)\n",
"\n",
" def squared_dist(self, point_a, point_b, **kwargs):\n",
" \"\"\"Squared geodesic distance between two points.\n",
"\n",
" Parameters\n",
" ----------\n",
" point_a : array-like, shape=[..., dim]\n",
" Point.\n",
" point_b : array-like, shape=[..., dim]\n",
" Point.\n",
"\n",
" Returns\n",
" -------\n",
" sq_dist : array-like, shape=[...,]\n",
" Squared distance.\n",
" \"\"\"\n",
" log = self.log(point=point_b, base_point=point_a, **kwargs)\n",
"\n",
" return self.squared_norm(vector=log, base_point=point_a)\n",
"\n",
" def dist(self, point_a, point_b, **kwargs):\n",
" \"\"\"Geodesic distance between two points.\n",
"\n",
" Note: It only works for positive definite\n",
" Riemannian metrics.\n",
"\n",
" Parameters\n",
" ----------\n",
" point_a : array-like, shape=[..., dim]\n",
" Point.\n",
" point_b : array-like, shape=[..., dim]\n",
" Point.\n",
"\n",
" Returns\n",
" -------\n",
" dist : array-like, shape=[...,]\n",
" Distance.\n",
" \"\"\"\n",
" sq_dist = self.squared_dist(point_a, point_b, **kwargs)\n",
" return gs.sqrt(sq_dist)\n",
"\n",
" def dist_broadcast(self, point_a, point_b):\n",
" \"\"\"Compute the geodesic distance between points.\n",
"\n",
" If n_samples_a == n_samples_b, dist is the element-wise distance\n",
" between each point in point_a and the point of point_b with the\n",
" same index. If n_samples_a is not equal to n_samples_b, dist is\n",
" the result of applying the geodesic distance from each point in\n",
" point_a to all points in point_b.\n",
"\n",
" Parameters\n",
" ----------\n",
" point_a : array-like, shape=[n_samples_a, dim]\n",
" Set of points in the manifold.\n",
" point_b : array-like, shape=[n_samples_b, dim]\n",
" Second set of points in the manifold.\n",
"\n",
" Returns\n",
" -------\n",
" dist : array-like, \\\n",
" shape=[n_samples_a,] or [n_samples_a, n_samples_b]\n",
" Geodesic distances between the two sets of points.\n",
" \"\"\"\n",
" ndim = len(self._space.shape)\n",
"\n",
" if point_a.shape[-ndim:] != point_b.shape[-ndim:]:\n",
" raise ValueError(\"Manifold dimensions not equal\")\n",
"\n",
" if ndim in (point_a.ndim, point_b.ndim) or (point_a.shape == point_b.shape):\n",
" return self.dist(point_a, point_b)\n",
"\n",
" n_samples = point_a.shape[0] * point_b.shape[0]\n",
" point_a_broadcast, point_b_broadcast = gs.broadcast_arrays(\n",
" point_a[:, None], point_b[None, ...]\n",
" )\n",
"\n",
" point_a_flatten = gs.reshape(\n",
" point_a_broadcast, (n_samples,) + point_a.shape[-ndim:]\n",
" )\n",
" point_b_flatten = gs.reshape(\n",
" point_b_broadcast, (n_samples,) + point_a.shape[-ndim:]\n",
" )\n",
"\n",
" dist = self.dist(point_a_flatten, point_b_flatten)\n",
" dist = gs.reshape(dist, (point_a.shape[0], point_b.shape[0]))\n",
" return gs.squeeze(dist)\n",
"\n",
" def dist_pairwise(self, points, n_jobs=1, **joblib_kwargs):\n",
" \"\"\"Compute the pairwise distance between points.\n",
"\n",
" Parameters\n",
" ----------\n",
" points : array-like, shape=[n_samples, dim]\n",
" Set of points in the manifold.\n",
" n_jobs : int\n",
" Number of jobs to run in parallel, using joblib. Note that a\n",
" higher number of jobs may not be beneficial when one computation\n",
" of a geodesic distance is cheap.\n",
" Optional. Default: 1.\n",
" **joblib_kwargs : dict\n",
" Keyword arguments to joblib.Parallel\n",
"\n",
" Returns\n",
" -------\n",
" dist : array-like, shape=[n_samples, n_samples]\n",
" Pairwise distance matrix between all the points.\n",
"\n",
" See Also\n",
" --------\n",
" `joblib documentations