"""Group actions."""
import itertools
import math
from abc import ABC, abstractmethod

import geomstats.backend as gs
from geomstats.geometry.matrices import Matrices
from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from geomstats.vectorization import get_batch_shape
[docs]
class GroupAction(ABC):
    """Abstract interface for the action of a group on a space.

    Concrete actions implement ``__call__`` so that an instance can be
    applied directly to a group element and a point.
    """

    @abstractmethod
    def __call__(self, group_elem, point):
        """Apply a group element to a point.

        Parameters
        ----------
        group_elem : array-like, shape=[..., *group.shape]
            Element of the acting group.
        point : array-like, shape=[..., *space.shape]
            Point of the space being acted upon.

        Returns
        -------
        orbit_point : array-like
            Point lying on the orbit of `point`.
        """
[docs]
class ComposeAction(GroupAction):
    """Action of a group on itself by composition.

    Parameters
    ----------
    group : LieGroup
        Group acting on itself.
    lie_algebra_repr : bool
        If True, group elements are given as the vector representation
        of Lie algebra elements and are mapped to the group before
        composing. This is amenable to perform gradient-based
        optimizations on the Lie algebra.
    """

    def __init__(self, group, lie_algebra_repr=True):
        self._group = group
        self.lie_algebra_repr = lie_algebra_repr

    @property
    def group_elem_shape(self):
        """Shape of the group element representation."""
        # Vector of Lie algebra coordinates, or the group's own point shape.
        return (
            (self._group.lie_algebra.dim,)
            if self.lie_algebra_repr
            else self._group.shape
        )

    def __call__(self, group_elem, point):
        """Compose a point with a group element.

        Parameters
        ----------
        group_elem : array-like, shape=[..., *group_elem_shape]
            Group element in chosen representation.
        point : array-like, shape=[..., *group.shape]
            Point on the group.

        Returns
        -------
        orbit_point : array-like, shape=[..., *group.shape]
            Result of ``group.compose(point, group_elem)``.
        """
        if not self.lie_algebra_repr:
            return self._group.compose(point, group_elem)

        # Lie algebra vector -> Lie algebra matrix -> group element.
        algebra_matrix = self._group.lie_algebra.matrix_representation(group_elem)
        acting_elem = self._group.exp(algebra_matrix)
        return self._group.compose(point, acting_elem)
[docs]
class SpecialOrthogonalComposeAction(ComposeAction):
    """Action of the special orthogonal group on itself by composition.

    Parameters
    ----------
    n : int
        Integer representing the shapes of the matrices : n x n.
    lie_algebra_repr : bool
        If True, group elements are represented by the vector
        representation of Lie algebra elements.
        This is amenable to perform gradient-based optimizations
        on the Lie algebra.
    """

    def __init__(self, n, lie_algebra_repr=True):
        # The docstring has always advertised `lie_algebra_repr`; accept it
        # and forward it. The default matches the parent's default, so
        # existing single-argument calls behave as before.
        super().__init__(
            SpecialOrthogonal(n, equip=False),
            lie_algebra_repr=lie_algebra_repr,
        )
[docs]
class CongruenceAction(GroupAction):
    """Action of matrices on matrices by congruence."""

    def __call__(self, group_elem, point):
        """Apply a group element to a matrix by congruence.

        The result is the product of `group_elem`, `point` and the
        transpose of `group_elem`, in that order.

        Parameters
        ----------
        group_elem : array-like, shape=[..., n, n]
            Element of the acting group.
        point : array-like, shape=[..., n, n]
            Matrix being acted upon.

        Returns
        -------
        orbit_point : array-like, shape=[..., n, n]
            Point lying on the orbit of `point`.
        """
        group_elem_t = Matrices.transpose(group_elem)
        return Matrices.mul(group_elem, point, group_elem_t)
[docs]
class PermutationAction(CongruenceAction):
    """Permutation group acting on matrices by congruence."""

    def __call__(self, group_elem, point):
        """Apply a permutation to a matrix by congruence.

        The permutation vector is first turned into a permutation matrix
        with the dtype of `point`, then the parent congruence action is
        applied.

        Parameters
        ----------
        group_elem : array-like, shape=[..., n]
            Permutation vectors: the value j at position i means that
            node i should be permuted with node j.
        point : array-like, shape=[..., n, n]
            Matrix being acted upon.

        Returns
        -------
        orbit_point : array-like, shape=[..., n, n]
            Point lying on the orbit of `point`.
        """
        return super().__call__(
            permutation_matrix_from_vector(group_elem, dtype=point.dtype),
            point,
        )
[docs]
class RowPermutationAction(GroupAction):
    """Permutation group acting on matrices by multiplication."""

    def __call__(self, group_elem, point):
        """Permute a matrix by left multiplication.

        `point` is multiplied on the left by the transpose of the
        permutation matrix built from `group_elem`.

        Parameters
        ----------
        group_elem : array-like, shape=[..., n]
            Permutation vectors: the value j at position i means that
            node i should be permuted with node j.
        point : array-like, shape=[..., n, n]
            Matrices to be permuted.

        Returns
        -------
        permuted_point : array-like, shape=[..., n, n]
            Permuted matrices.
        """
        perm_mat_t = Matrices.transpose(
            permutation_matrix_from_vector(group_elem, dtype=point.dtype)
        )
        return gs.matmul(perm_mat_t, point)
[docs]
def permutation_matrix_from_vector(group_elem, dtype=gs.int64):
    """Transform a permutation vector into a matrix.

    A one is stored at index ``(..., i, group_elem[..., i])`` for every
    position ``i``; the result is assembled with ``gs.array_from_sparse``.
    Any number of leading batch dimensions is handled.

    Parameters
    ----------
    group_elem : array-like, shape=[..., n]
        The element of a group.
    dtype : gs.dtype
        Array dtype.

    Returns
    -------
    mat_group_element : array-like, shape=[..., n, n]
        Matrix representation of group element.
    """
    batch_shape = get_batch_shape(1, group_elem)
    n = group_elem.shape[-1]

    # One (*batch_index, row, col) entry per stored one. Iterating over the
    # full product of batch ranges keeps the index arity equal to
    # len(batch_shape) + 2, matching `target_shape` for any batch rank.
    # Without batch dimensions the product yields a single empty index, and
    # `group_elem[()]` is the permutation vector itself.
    indices = gs.array(
        [
            (*batch_index, row, col)
            for batch_index in itertools.product(*map(range, batch_shape))
            for row, col in enumerate(group_elem[batch_index])
        ]
    )

    return gs.array_from_sparse(
        data=gs.ones(math.prod(batch_shape) * n, dtype=dtype),
        indices=indices,
        target_shape=batch_shape + (n, n),
    )